-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhtsat_wrapper.py
More file actions
91 lines (75 loc) · 3.25 KB
/
htsat_wrapper.py
File metadata and controls
91 lines (75 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import sys
import torch
import numpy as np
import pandas as pd
# Locate the HTS-AT source tree and the default checkpoint.
# Both default to the author's local layout; set the HTSAT_PATH /
# HTSAT_CKPT_PATH environment variables to point elsewhere without editing
# this file.
DEFAULT_HTSAT_PATH = os.environ.get(
    "HTSAT_PATH", "/Users/corlin/2026/HTS-Audio-Transformer"
)
DEFAULT_CKPT_PATH = os.environ.get(
    "HTSAT_CKPT_PATH",
    "/Users/corlin/2026/usb_4_mic_array/model/HTSAT_AudioSet_Saved_1.ckpt",
)
# Put HTS-AT at the FRONT of sys.path: the generic module names imported
# right after this (`config`, `model`) must resolve to HTS-AT's modules, not
# to same-named modules found earlier on the path (this project also has its
# own `model` directory).
if DEFAULT_HTSAT_PATH not in sys.path:
    sys.path.insert(0, DEFAULT_HTSAT_PATH)
# 注意:这些导入必须在 sys.path 修改之后
import config
from model.htsat import HTSAT_Swin_Transformer
class HTSAT_Inference:
    """Inference wrapper around a pretrained HTS-AT audio tagging model.

    On construction it loads the label map from the HTS-AT source tree and
    the network weights from a checkpoint. `predict` returns per-class
    sigmoid scores for a waveform; `get_label_name` maps a class index to its
    display name.
    """

    def __init__(self, checkpoint_path=None, device="cpu"):
        """Load the label map and the model weights.

        Args:
            checkpoint_path: Path to the checkpoint file. Defaults to
                DEFAULT_CKPT_PATH when None.
            device: Anything accepted by `torch.device` (e.g. "cpu").

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        self.device = torch.device(device)
        self.hts_path = DEFAULT_HTSAT_PATH
        if checkpoint_path is None:
            checkpoint_path = DEFAULT_CKPT_PATH
        self.idx_2_label = self._load_labels()
        self.model = self._load_model(checkpoint_path)
        print(f"* HTS-AT Model loaded from {checkpoint_path} on {self.device}")

    def _load_labels(self):
        """Return a {class index: display name} map (AudioSet, 527 classes).

        Reads `class_label_indice.csv` from the HTS-AT source directory,
        using its `index` and `display_name` columns (CSV format:
        index,mid,display_name).
        """
        csv_path = os.path.join(self.hts_path, "class_label_indice.csv")
        df = pd.read_csv(csv_path)
        # Zip the two columns directly instead of building a Series per row
        # with iterrows().
        return {
            int(idx): name
            for idx, name in zip(df["index"], df["display_name"])
        }

    @staticmethod
    def _strip_prefix(state_dict, prefix="sed_model."):
        """Return a copy of `state_dict` with a leading `prefix` removed from keys.

        Only a leading occurrence is stripped, so a key that happens to
        contain the prefix further inside is not rewritten. Keys without the
        prefix are kept as-is.
        """
        n = len(prefix)
        return {
            (key[n:] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_model(self, checkpoint_path):
        """Build the HTS-AT network from `config` and load checkpoint weights.

        Args:
            checkpoint_path: Path to the checkpoint file.

        Returns:
            The model, moved to `self.device` and switched to eval mode.

        Raises:
            FileNotFoundError: If `checkpoint_path` does not exist.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Weights are normally nested under "state_dict"; also accept a file
        # that already is a bare state dict.
        state_dict = checkpoint.get("state_dict", checkpoint)

        # Model hyper-parameters must match the config.py the weights were
        # trained with.
        model = HTSAT_Swin_Transformer(
            spec_size=config.htsat_spec_size,
            patch_size=config.htsat_patch_size,
            in_chans=1,
            num_classes=config.classes_num,
            window_size=config.htsat_window_size,
            config=config,
            depths=config.htsat_depth,
            embed_dim=config.htsat_dim,
            patch_stride=config.htsat_stride,
            num_heads=config.htsat_num_head,
        )
        # Checkpoint keys carry a "sed_model." prefix that the bare network's
        # parameter names do not have.
        model.load_state_dict(self._strip_prefix(state_dict))
        model.to(self.device)
        model.eval()
        return model

    def predict(self, waveform, rescale=True):
        """Return per-class sigmoid scores for a waveform.

        Args:
            waveform: Array-like of samples, shape [N,]. A 2-D input is passed
                to the model unchanged, but only the first row's scores are
                returned.
            rescale: If True, peak-normalize the samples to 0.8 before
                inference (skipped when the peak is <= 0.001).

        Returns:
            numpy array of per-class probabilities for the first (or only)
            clip.

        Raises:
            ValueError: If `waveform` contains no samples.
        """
        # Convert to float up front. This accepts plain lists, and avoids
        # integer overflow in np.abs for integer PCM (abs(int16 -32768) wraps
        # back to -32768).
        samples = np.asarray(waveform, dtype=np.float32)
        if samples.size == 0:
            raise ValueError("waveform is empty")
        if rescale:
            # Peak-normalize to 0.8 so very small values don't leave the
            # sigmoid outputs stuck near 0.5.
            max_val = np.max(np.abs(samples))
            if max_val > 0.001:
                samples = samples / max_val * 0.8
        with torch.no_grad():
            x = torch.from_numpy(samples).to(self.device)
            if x.ndim == 1:
                x = x[None, :]  # add a batch dimension
            output_dict = self.model(x, None, True)
            pred = torch.sigmoid(output_dict["clipwise_output"])
        return pred[0].cpu().numpy()

    def get_label_name(self, index):
        """Return the display name for class `index`, or "Unknown_<index>"."""
        return self.idx_2_label.get(index, f"Unknown_{index}")