-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
56 lines (46 loc) · 2.74 KB
/
Copy pathmain.py
File metadata and controls
56 lines (46 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import torch
import numpy as np
import soundfile as sf
from PIL import Image
from scipy.signal import fftconvolve
from torchvision import transforms
from train import RGBD_RIR_Network, device
from trainDepthModel import predict_depth_from_rgb
from config import RIR_LEN
# ─── CONFIG ───
MODEL_PATH = "best_rgbd_rir_model.pth"  # Trained RGBD → RIR network weights
IMAGE_PATH = "testDataForReal/Room4.jpg"  # Input RGB image
DRY_WAV = "testDataForReal/anechoic/sample1.wav"  # Dry audio input
OUTPUT_WAV = "testDataForReal/predictedRIR/predicted_room4-1.wav"  # Output simulated audio
# RIR_LEN (number of RIR samples kept from the model output) comes from
# `from config import RIR_LEN`; rebinding it to itself here was a no-op.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─── 1. Predict Depth Map ─────────────────────────────────────────────
print("Predicting depth from RGB...")
# Original author notes a (224, 224) float32 map here — confirm against
# predict_depth_from_rgb, since the stacking below relies on matching H, W.
depth_map = predict_depth_from_rgb(IMAGE_PATH)
# ─── 2. Load RGB image and combine with Depth ─────────────────────────
print("Preprocessing RGBD input...")
rgb_img = Image.open(IMAGE_PATH).convert("RGB").resize((224, 224))
# 8-bit colour → float32 in [0, 1]; shape (224, 224, 3).
rgb = np.asarray(rgb_img, dtype=np.float32) / 255.0
# Give depth a trailing channel axis and the same /255 scaling as the colour
# channels (assumes the depth map is on a 0-255 scale — TODO confirm).
depth = depth_map[..., None] / 255.0
# Colour + depth stacked along the channel axis → (224, 224, 4).
rgbd = np.dstack((rgb, depth)).astype(np.float32)
# HWC → CHW, then a leading batch axis: (1, 4, 224, 224).
rgbd_tensor = torch.tensor(rgbd).permute(2, 0, 1)[None].to(DEVICE)
# ─── 3. Load model and predict RIR ────────────────────────────────────
print("Predicting Room Impulse Response...")
model = RGBD_RIR_Network()
model = model.to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
# Single forward pass with autograd disabled.
with torch.no_grad():
    batch_out = model(rgbd_tensor)
# [0] picks the only item in the batch. Presumably that item is a 1-D RIR,
# which fftconvolve below requires — TODO confirm the network's output shape.
pred_rir = batch_out[0].cpu().numpy()[:RIR_LEN]
# ─── 4. Convolve dry audio with predicted RIR ─────────────────────────
print("Applying RIR to dry audio...")
dry, sr = sf.read(DRY_WAV)
print(f"Dry shape: {dry.shape}, RIR shape: {pred_rir.shape}")
# sf.read yields a 2-D (frames, channels) array for multi-channel files;
# average the channels to get mono (keeping only dry[:, 0] would also work).
if dry.ndim == 2:
    dry = np.mean(dry, axis=1)
# Full linear convolution, then cut back to the dry length. This keeps input
# and output the same length but discards the reverb tail past the end.
wet = fftconvolve(dry, pred_rir)
wet = wet[:len(dry)]
# ─── 5. Save output audio ─────────────────────────────────────────────
# sf.write does not create missing directories, and the configured output
# lives in a sub-folder that may not exist yet.
out_dir = os.path.dirname(OUTPUT_WAV)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# A .wav written without an explicit subtype is 16-bit PCM. Convolving with a
# predicted RIR can push samples outside [-1, 1]. libsndfile only clips
# float→int when clipping is enabled, so out-of-range samples may wrap into
# loud distortion. Rescale only when the peak actually exceeds full scale.
# initial=0.0 keeps np.max well-defined for an empty signal.
peak = float(np.max(np.abs(wet), initial=0.0))
if peak > 1.0:
    print(f"Output peak {peak:.3f} exceeds full scale; normalizing to avoid clipping.")
    wet = wet / peak
sf.write(OUTPUT_WAV, wet, sr)
print(f"Done! Predicted audio saved to {OUTPUT_WAV}")