-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpanns_wrapper.py
More file actions
87 lines (74 loc) · 3.04 KB
/
panns_wrapper.py
File metadata and controls
87 lines (74 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import sys
import torch
import numpy as np
# Dynamically add the PANNs source directory to the module search path
PANNS_PATH = "/Users/corlin/2026/audioset_tagging_cnn/pytorch"
if PANNS_PATH not in sys.path:
sys.path.append(PANNS_PATH)
from models import Cnn14
class PANNs_Inference:
    """Inference wrapper around the PANNs ``Cnn14`` audio-tagging model.

    The constructor builds a ``Cnn14`` network with the hyper-parameters set
    below, loads its weights from a checkpoint file, and loads an
    index -> display-name label table from a CSV file.
    """

    # Default label table. The HTS-AT label file is reused because the
    # underlying class list is the same.
    DEFAULT_LABELS_CSV = "/Users/corlin/2026/HTS-Audio-Transformer/class_label_indice.csv"

    # Clips whose absolute peak is at or below this value are left unscaled.
    _MIN_PEAK = 0.001
    # Peak amplitude after normalization.
    _TARGET_PEAK = 0.8

    def __init__(self, model_path, device='cpu', labels_csv=None):
        """Build the model, load its weights and the label table.

        :param model_path: path to a checkpoint whose ``'model'`` entry is the
            state dict for ``Cnn14``
        :param device: torch device the model and inputs are moved to
        :param labels_csv: optional path to the label CSV (columns ``index``
            and ``display_name``); defaults to ``DEFAULT_LABELS_CSV``
        """
        self.device = device

        # Cnn14 default parameters (AudioSet training standard)
        self.sample_rate = 32000
        self.window_size = 1024
        self.hop_size = 320
        self.mel_bins = 64
        self.fmin = 50
        self.fmax = 14000
        self.classes_num = 527

        print("[*] 正在初始化 PANNs (Cnn14) 模型...")
        self.model = Cnn14(
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            hop_size=self.hop_size,
            mel_bins=self.mel_bins,
            fmin=self.fmin,
            fmax=self.fmax,
            classes_num=self.classes_num,
        )

        # NOTE(review): depending on the PyTorch version's ``weights_only``
        # default, torch.load may unpickle arbitrary objects. Only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.to(self.device)
        self.model.eval()

        # Load the index -> label-name mapping
        if labels_csv is None:
            labels_csv = self.DEFAULT_LABELS_CSV
        self.idx_2_label = self._load_labels(labels_csv)
        print(f"[*] PANNs 权重加载成功: {os.path.basename(model_path)}")

    def _load_labels(self, csv_path):
        """Return ``{index: display_name}`` read from ``csv_path``.

        Returns an empty dict when the file does not exist, so a missing
        label table degrades to ``Unknown_<index>`` names instead of failing.
        """
        # Local import: pandas is not imported at module level in this file.
        import pandas as pd

        if not os.path.exists(csv_path):
            return {}
        df = pd.read_csv(csv_path)
        return {
            int(idx): name
            for idx, name in zip(df['index'].tolist(), df['display_name'].tolist())
        }

    def get_label_name(self, index):
        """Return the display name for ``index``, or ``"Unknown_<index>"``."""
        return self.idx_2_label.get(index, f"Unknown_{index}")

    def predict(self, audio_data):
        """Run the model on one clip and return its clip-level output.

        :param audio_data: 1-D array-like of samples, shape (N,); expected to
            be float32 normalized to [-1, 1]
        :return: np.ndarray, the first row of the model's ``clipwise_output``
            (one value per class; 527 for the model built in ``__init__``)
        :raises ValueError: if ``audio_data`` contains no samples
        """
        samples = np.asarray(audio_data)
        if samples.size == 0:
            raise ValueError("audio_data must contain at least one sample")
        # Integer PCM is converted up front so np.abs cannot overflow on the
        # most negative value; floating input keeps its dtype.
        if not np.issubdtype(samples.dtype, np.floating):
            samples = samples.astype(np.float32)

        # Peak normalization, aligned with HTS-AT
        peak = np.max(np.abs(samples))
        if peak > self._MIN_PEAK:
            samples = samples / peak * self._TARGET_PEAK

        # torch.from_numpy rejects negative-stride views (e.g. ``x[::-1]``),
        # which reach this point unchanged when normalization is skipped.
        samples = np.ascontiguousarray(samples)

        with torch.no_grad():
            # Convert to a tensor with batch_size=1
            audio_tensor = torch.from_numpy(samples).float().to(self.device)
            audio_tensor = audio_tensor.unsqueeze(0)

            # Inference
            output_dict = self.model(audio_tensor, None)
            probs = output_dict['clipwise_output'].cpu().numpy()[0]
        return probs
if __name__ == "__main__":
    # Smoke test: run a single inference on random noise when the checkpoint
    # file is available; otherwise report the missing file.
    MODEL_FILE = "/Users/corlin/2026/usb_4_mic_array/model/Cnn14_mAP=0.431.pth"
    if not os.path.exists(MODEL_FILE):
        print(f"错误:找不到模型文件 {MODEL_FILE}")
    else:
        tagger = PANNs_Inference(MODEL_FILE)
        # 32000 * 2 uniformly distributed samples in [-1, 1)
        n_samples = 32000 * 2
        noise = np.random.uniform(-1, 1, n_samples).astype(np.float32)
        scores = tagger.predict(noise)
        print(f"测试推理完成,前10个类别的概率:\n{scores[:10]}")
        print(f"分类索引 0 的名称: {tagger.get_label_name(0)}")