-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainResult.py
More file actions
48 lines (37 loc) · 1.66 KB
/
Copy pathtrainResult.py
File metadata and controls
48 lines (37 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import torch
import numpy as np
from scipy.io.wavfile import write as wav_write
from train import RGBD_RIR_Network, device
from rirReader import compare_rirs
def predict_rir_from_rgbd(rgbd_file, model_path="best_rgbd_rir_model.pth", output_wav_path="predicted_rir.wav", sample_rate=44100):
    """Predict a Room Impulse Response (RIR) from one RGBD file and save it as a WAV.

    Loads the trained ``RGBD_RIR_Network`` weights, runs a single forward pass
    on the RGBD array stored in ``rgbd_file`` and writes the first element of
    the model output to ``output_wav_path``.

    Args:
        rgbd_file: Path to a ``.npy`` file holding one RGBD image. The array is
            permuted with ``(2, 0, 1)``, so it must be 3-D; presumably it is
            stored channels-last (H, W, C) -- TODO confirm against the
            dataset writer.
        model_path: Path to the saved ``state_dict`` of ``RGBD_RIR_Network``.
        output_wav_path: Destination ``.wav`` file. Missing parent directories
            are created.
        sample_rate: Sample rate in Hz written into the WAV header.

    Returns:
        None. The prediction is written to disk and a confirmation is printed.
    """
    # Create the destination folder before doing any expensive work, so a bad
    # output path fails fast instead of after model loading and inference.
    # dirname is "" for a bare filename (the default), which makedirs rejects.
    output_dir = os.path.dirname(output_wav_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model = RGBD_RIR_Network().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()

    rgbd = np.load(rgbd_file).astype(np.float32)
    # Move the last axis to the front and add a leading batch axis of size 1.
    rgbd_tensor = torch.tensor(rgbd).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        # [0] selects the only item of the size-1 batch built above.
        pred = model(rgbd_tensor)[0].cpu().numpy()

    wav_write(output_wav_path, sample_rate, pred)
    print(f"RIR prediction saved to {output_wav_path}")
# === Predict RIR from a single RGBD image and visualize comparison with ground truth (based on padded RIRs) ===
rgbd_file = "dataset_2d/combined_rgbd/hm3d_qyAac8rV8Zk_453_c03.npy"
true_rir_npy = "dataset_2d/padded_rirs/hm3d_qyAac8rV8Zk_453_c03.npy"
# NOTE(review): the output names say "Adairsville" while the inputs are the
# hm3d_qyAac8rV8Zk scene -- confirm the naming is intentional.
pred_rir_file = "eval_wavs/pred_Adairsville.wav"
true_rir_file = "eval_wavs/true_Adairsville.wav"
model_path = "best_rgbd_rir_model.pth"
# Single sample rate shared by both WAV files; writing them at different rates
# would make the waveform comparison below meaningless.
sample_rate = 44100

# Make sure the folder of *each* output file exists, not only the prediction's,
# so the two paths can live in different directories.
for wav_path in (pred_rir_file, true_rir_file):
    wav_dir = os.path.dirname(wav_path)
    if wav_dir:  # "" for a bare filename; os.makedirs("") would raise
        os.makedirs(wav_dir, exist_ok=True)

# 1. Save the predicted RIR
predict_rir_from_rgbd(
    rgbd_file=rgbd_file,
    model_path=model_path,
    output_wav_path=pred_rir_file,
    sample_rate=sample_rate,
)

# 2. Save the ground truth RIR as .wav (loaded from padded RIRs)
true_rir = np.load(true_rir_npy).astype(np.float32)
wav_write(true_rir_file, sample_rate, true_rir)

# 3. Compare waveforms
compare_rirs(true_rir_file, pred_rir_file)