-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
82 lines (71 loc) · 2.15 KB
/
data.py
File metadata and controls
82 lines (71 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from tqdm import tqdm
from functools import partial
from torch.utils.data import DataLoader
from modelscope.msdatasets import MsDataset
from torchvision.transforms import *
def transform(example_batch, input_size=300):
    """Convert every image under ``example_batch["mel"]`` into a tensor.

    Each item is converted to RGB and passed through a torchvision pipeline:
    resize to ``input_size`` x ``input_size``, ``RandomAffine(5)``,
    ``ToTensor`` and a per-channel ``Normalize`` with 0.5 for both arguments.

    Args:
        example_batch: Mapping with a ``"mel"`` entry holding an iterable of
            images exposing ``.convert("RGB")`` (presumably PIL images --
            confirm against the dataset schema).
        input_size: Side length, in pixels, passed to ``Resize``.

    Returns:
        The same ``example_batch`` object, with ``"mel"`` replaced in place
        by the list of transformed items.
    """
    half = (0.5, 0.5, 0.5)
    steps = [
        Resize([input_size, input_size]),
        RandomAffine(5),
        ToTensor(),
        Normalize(half, half),
    ]
    pipeline = Compose(steps)

    # Build the full list first so "mel" is only overwritten once every
    # image has been processed successfully.
    tensors = []
    for image in example_batch["mel"]:
        tensors.append(pipeline(image.convert("RGB")))

    example_batch["mel"] = tensors
    return example_batch
def prepare_data(use_fl: bool):
    """Load the ``ccmusic-database/pianos`` dataset and its class names.

    Args:
        use_fl: When true, also count the training samples per class (the
            progress-bar text indicates these counts feed a focal loss).

    Returns:
        A ``(ds, classes, sizes)`` tuple: the loaded dataset, the label names
        read from the ``test`` split, and the per-class training sample
        counts in ``classes`` order. ``sizes`` is an empty list when
        ``use_fl`` is false.
    """
    print("Preparing & loading data...")
    ds = MsDataset.load(
        "ccmusic-database/pianos",
        subset_name="eval",
        cache_dir="./__pycache__",
    )
    classes = ds["test"].features["label"].names

    if not use_fl:
        return ds, classes, []

    # Seed every class with zero so classes absent from the train split
    # still contribute an entry, and the output order follows ``classes``.
    per_class_counts = dict.fromkeys(classes, 0)
    progress = tqdm(ds["train"], desc="Statistics by category for focal loss...")
    for sample in progress:
        label_name = classes[sample["label"]]
        per_class_counts[label_name] += 1

    return ds, classes, list(per_class_counts.values())
def load_data(
    ds: MsDataset,
    insize,
    has_bn=False,
    batch_size=4,
    shuffle=True,
    num_workers=2,
):
    """Wrap the train/validation/test splits of ``ds`` in ``DataLoader``s.

    Every split gets ``transform`` attached (bound to ``input_size=insize``)
    and is loaded with the same batch size, shuffle flag, worker count and
    ``drop_last`` setting.

    Args:
        ds: Dataset exposing ``"train"``, ``"validation"`` and ``"test"``
            splits, each supporting ``with_transform``.
        insize: Image side length forwarded to ``transform``.
        has_bn: Whether the model has a bn layer. When true, the batch size
            is raised to at least 2 and incomplete final batches are dropped.
        batch_size: Requested batch size.
        shuffle: Shuffle flag passed to all three loaders.
        num_workers: Worker count passed to all three loaders.

    Returns:
        ``(train_loader, validation_loader, test_loader)``.

    NOTE(review): ``shuffle`` and ``drop_last`` are applied to the validation
    and test loaders as well, and ``transform`` includes ``RandomAffine`` for
    every split -- confirm this is intended for evaluation.
    """
    effective_batch_size = batch_size
    if has_bn:
        print("The model has bn layer")
        if effective_batch_size < 2:
            print("Switch batch_size >= 2")
            effective_batch_size = 2

    per_batch_transform = partial(transform, input_size=insize)

    # Resolve all three splits before building any loader, so a missing
    # split fails before any DataLoader is constructed.
    split_sets = [
        ds[split_name].with_transform(per_batch_transform)
        for split_name in ("train", "validation", "test")
    ]

    loader_options = {
        "batch_size": effective_batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "drop_last": has_bn,
    }
    train_loader, valid_loader, test_loader = [
        DataLoader(split_set, **loader_options) for split_set in split_sets
    ]
    return train_loader, valid_loader, test_loader