# api.py — forked from a1ex90/MusicalKeyCNN
# (Web-page chrome and the 1-108 line-number gutter from the GitHub file view removed.)
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel
from audio_utils import compute_waveform_basic, compute_waveform_hmb, compute_waveform_rainbow, load_audio, preprocess_from_waveform
from eval import load_model
from predict_bpm import detect_bpm
from predict_keys import SUPPORTED_EXTENSIONS, camelot_output
# Checkpoint loaded at startup; the path is relative, so it resolves against
# the process's working directory.
MODEL_PATH = Path("checkpoints") / "keynet.pt"

# Both are placeholders until the lifespan handler assigns the loaded model
# and the torch device it was placed on.
_model = _device = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the key-detection model once at application startup.

    Selects CUDA when ``torch.cuda.is_available()`` reports it and CPU
    otherwise, loads the checkpoint at ``MODEL_PATH`` onto that device, and
    switches the model to eval mode. The results are stored in the
    module-level ``_model`` and ``_device``. Nothing follows the ``yield``,
    so no shutdown-time cleanup is performed.

    Args:
        app: The application this handler is attached to; not used here.
    """
    global _model, _device

    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
    _device = torch.device(device_name)

    checkpoint = str(MODEL_PATH)
    _model = load_model(checkpoint, _device)
    _model.eval()

    yield
# The application object; `lifespan` is wired in so the model is loaded
# during startup rather than on the first request.
app = FastAPI(
    title="MusicalKeyCNN API",
    version="0.1.0",
    lifespan=lifespan,
)
class WaveformBasic(BaseModel):
    """Schema for the ``waveform_basic`` part of a ``/predict`` response.

    ``predict`` builds it by unpacking the dict returned by
    ``compute_waveform_basic``, so that dict must provide at least these keys.
    """

    # NOTE(review): presumably one time position per point, in seconds —
    # confirm against compute_waveform_basic.
    times: List[float]
    # NOTE(review): presumably the amplitude value at each entry of `times` —
    # confirm range/normalization against compute_waveform_basic.
    amplitudes: List[float]
class WaveformHMB(BaseModel):
    """Schema for the ``waveform_hmb`` part of a ``/predict`` response.

    ``predict`` builds it by unpacking the dict returned by
    ``compute_waveform_hmb``, so that dict must provide at least these keys.
    """

    # NOTE(review): presumably one time position per point, shared by the
    # three band series below — confirm against compute_waveform_hmb.
    times: List[float]
    # NOTE(review): the three fields below look like per-band levels
    # (low / middle / high frequencies); band edges are not visible here.
    bass: List[float]
    mid: List[float]
    high: List[float]
class WaveformRGB(BaseModel):
    """Schema for the ``waveform_rainbow`` part of a ``/predict`` response.

    ``predict`` builds it by unpacking the dict returned by
    ``compute_waveform_rainbow``, so that dict must provide at least these
    keys.
    """

    # NOTE(review): presumably one time position per point — confirm against
    # compute_waveform_rainbow.
    times: List[float]
    # NOTE(review): presumably red/green/blue color channels per point; the
    # value range (0-1 vs 0-255) is not visible here — confirm.
    r: List[float]
    g: List[float]
    b: List[float]
class PredictResponse(BaseModel):
    """Response body of ``POST /predict``."""

    # Name of the uploaded file, echoed back from the request.
    filename: str
    # Index of the highest-scoring model output (argmax over dim 1).
    class_id: int
    # First and second values returned by camelot_output(class_id).
    # NOTE(review): presumably Camelot-wheel notation and a readable key
    # name — confirm against predict_keys.camelot_output.
    camelot: str
    key: str
    # Tempo as returned by detect_bpm for the decoded waveform.
    bpm: float
    # Waveform summaries computed from the same decoded audio.
    waveform_basic: WaveformBasic
    waveform_hmb: WaveformHMB
    waveform_rainbow: WaveformRGB
@app.get("/health")
def health():
    """Report service liveness and the device the model was loaded on.

    Returns:
        A dict with constant ``status`` and ``model`` entries plus the string
        form of the module-level ``_device`` (``"None"`` if startup has not
        populated it).
    """
    device_label = str(_device)
    return {"status": "ok", "model": "keynet", "device": device_label}
@app.post("/predict", response_model=PredictResponse)
async def predict(file: UploadFile = File(...)):
    """Analyze an uploaded audio file: musical key, tempo and waveform data.

    The upload is written to a temporary file (``load_audio`` is given a
    path), decoded once, and the same waveform feeds key classification, BPM
    detection and the three waveform summaries. The temporary file is removed
    before returning, whether or not processing succeeded.

    Args:
        file: Multipart upload. Its filename suffix (compared in lower case)
            must be in ``SUPPORTED_EXTENSIONS``.

    Returns:
        A ``PredictResponse`` carrying the predicted class index, the two
        strings from ``camelot_output``, the detected BPM and the waveform
        summaries.

    Raises:
        HTTPException: 400 if the filename is missing or its suffix is not
            supported, or if the upload is empty; 503 if the model has not
            been loaded.
    """
    # UploadFile.filename may be None; treat that like an empty name so the
    # request gets the normal 400 below instead of a TypeError from Path().
    filename = file.filename or ""
    suffix = Path(filename).suffix.lower()
    if suffix not in SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format '{suffix}'. Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}",
        )

    # _model stays None when the lifespan handler has not run (or failed);
    # answer with a clear 503 rather than "'NoneType' is not callable".
    if _model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # NOTE(review): everything below is synchronous work inside an
    # ``async def`` handler, so the event loop is blocked for the duration of
    # decoding and inference. Consider offloading to a worker thread if
    # concurrent requests matter.
    tmp_path = None
    try:
        # delete=False because the file is reopened by path after this handle
        # is closed. Creation and the write both sit inside the try so a
        # failed write cannot leave the file behind.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(contents)

        waveform, sr = load_audio(tmp_path)

        # Key classification: add a leading batch dimension and run the model
        # without gradient tracking.
        spec_tensor = preprocess_from_waveform(waveform, sr)
        spec_tensor = spec_tensor.unsqueeze(0).to(_device)
        with torch.no_grad():
            outputs = _model(spec_tensor)
        pred = int(torch.argmax(outputs, dim=1).cpu().item())
        camelot_str, key_text = camelot_output(pred)

        bpm = detect_bpm(waveform, sr)

        waveform_basic = compute_waveform_basic(waveform, sr)
        waveform_hmb = compute_waveform_hmb(waveform, sr)
        waveform_rainbow = compute_waveform_rainbow(waveform, sr)
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)

    return PredictResponse(
        filename=filename,
        class_id=pred,
        camelot=camelot_str,
        key=key_text,
        bpm=bpm,
        waveform_basic=WaveformBasic(**waveform_basic),
        waveform_hmb=WaveformHMB(**waveform_hmb),
        waveform_rainbow=WaveformRGB(**waveform_rainbow),
    )