-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathutil.py
More file actions
158 lines (139 loc) · 4.71 KB
/
util.py
File metadata and controls
158 lines (139 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
'''
This is written by Jiyuan Liu, Dec. 21, 2021.
Homepage: https://liujiyuan13.github.io.
Email: liujiyuan13@163.com.
All rights reserved.
'''
import os
import numpy as np
import random
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
def set_seed(seed=0):
    """
    Seed every random number generator used during training and switch
    cuDNN into deterministic mode.
    @param seed: int, default 0; the same value is fed to every generator
    """
    # Exported for any child processes; the running interpreter reads
    # PYTHONHASHSEED only at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN speed for reproducibility: no autotuning, deterministic
    # kernels only, and cuDNN disabled altogether.
    backend = torch.backends.cudnn
    backend.benchmark = False
    backend.deterministic = True
    backend.enabled = False
def load_data(data_dir, data_name, is_train, image_size, batch_size, n_worker):
    """
    load data.
    @param data_dir: data dir, data folder (ImageNet is read from
                     <data_dir>/ImageNet1K/{train,val})
    @param data_name: data name, one of 'cifar10', 'cifar100', 'stl10', 'imagenet'
    @param is_train: train data or test data
    @param image_size: output size of RandomResizedCrop
    @param batch_size: batch size
    @param n_worker: number of workers to load data
    @return: data_loader: shuffling loader over the selected split
                          (the last incomplete batch is dropped);
             n_class: number of distinct labels in the selected split
    @raise ValueError: if data_name is not one of the supported names
    """
    # NOTE(review): RandomResizedCrop is applied to the test split as well —
    # confirm this augmentation is intended for evaluation data.
    # Names are compared by value: `is` tests object identity, which fails for
    # strings built at runtime (e.g. parsed from the command line).
    if data_name == 'cifar10':
        MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD)
        ])
        data = datasets.CIFAR10(data_dir, transform=transform, train=is_train, download=True)
    elif data_name == 'cifar100':
        MEAN, STD = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
        transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD)
        ])
        data = datasets.CIFAR100(data_dir, transform=transform, train=is_train, download=True)
    elif data_name == 'stl10':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.ToTensor()
        ])
        data = datasets.STL10(data_dir, transform=transform, split='train' if is_train else 'test', download=True)
    elif data_name == 'imagenet':
        MEAN, STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)  # constants in timm.data.constants
        transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD)
        ])
        data = datasets.ImageFolder(os.path.join(data_dir, 'ImageNet1K', 'train' if is_train else 'val'), transform=transform)
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError(f'{data_name}: not supported yet.')

    # Count classes in the loaded split. torchvision's STL10 keeps its labels in
    # `labels`; the other datasets expose them as `targets`.
    targets = data.labels if data_name == 'stl10' else data.targets
    n_class = len(set(targets))

    data_loader = DataLoader(data,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=n_worker,
                             pin_memory=True,
                             drop_last=True)
    return data_loader, n_class
def save_ckpt(model, optimizer, args, epoch, save_file):
    '''
    Write a training checkpoint to disk with torch.save.
    :param model: model whose state_dict() is stored under 'model'
    :param optimizer: optimizer whose state_dict() is stored under 'optimizer'
    :param args: training parameters, stored unchanged under 'args'
    :param epoch: epoch number, stored under 'epoch'
    :param save_file: destination path handed to torch.save
    :return: None
    '''
    torch.save(
        dict(
            args=args,
            model=model.state_dict(),
            optimizer=optimizer.state_dict(),
            epoch=epoch,
        ),
        save_file,
    )
def load_ckpt(model, load_file, map_location='cpu'):
    '''
    Load the 'model' weights stored in a checkpoint file into `model`.
    :param model: target model; its state_dict is overwritten in place
    :param load_file: file path of a checkpoint dict containing a 'model' entry
    :param map_location: where torch.load places the stored tensors,
                         default 'cpu'
    :return: the loaded model (the same object that was passed in)
    '''
    # Map to CPU by default. Without this, a checkpoint saved on a GPU is
    # restored onto that GPU, which fails on CPU-only machines and briefly
    # holds a second copy of the weights in GPU memory. load_state_dict()
    # copies the values into the model's existing parameters, so they stay on
    # whatever device the model already uses.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True; if the stored 'args' is a non-primitive object (e.g.
    # an argparse.Namespace), loading may be refused — confirm for your version.
    ckpt = torch.load(load_file, map_location=map_location)
    model.load_state_dict(ckpt['model'])
    return model
def accuracy(y_true, y_pred):
    """
    compute classification accuracy.
    # Arguments
        y_true: true labels, numpy.array with `n_samples` elements
                (e.g. shape `(n_samples,)` or `(n_samples, 1)`)
        y_pred: predicted labels, numpy.array with `n_samples` elements
    # Return
        fraction of positions where the labels agree, in [0,1]
    # Raises
        AssertionError if the two inputs differ in size;
        ZeroDivisionError if both are empty
    """
    assert y_pred.size == y_true.size
    # Labels are cast to int64 (float labels are truncated) and flattened, so an
    # `(n, 1)` column against an `(n,)` vector is compared element-wise instead
    # of being broadcast into an `(n, n)` grid.
    y_true = y_true.astype(np.int64).ravel()
    y_pred = y_pred.astype(np.int64).ravel()
    # np.count_nonzero counts matches in C; the builtin sum() walked the array
    # one element at a time in Python.
    return np.count_nonzero(y_true == y_pred) / y_true.size
class AverageMeter:
    '''
    Keep the most recent value of a metric together with its running,
    sample-weighted total, count and mean.
    '''

    def __init__(self):
        self.reset()

    def reset(self):
        '''Set the latest value, total, count and mean back to zero.'''
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        '''
        Record a new value.
        :param val: latest value of the metric (e.g. a batch average)
        :param n: weight of `val`, typically the number of samples it covers
        '''
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count