-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgtsrb.py
More file actions
73 lines (59 loc) · 2.75 KB
/
gtsrb.py
File metadata and controls
73 lines (59 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import tensorflow as tf
from deep_models import street_sign_model
from utils import split_data, order_test_set, create_generators
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# Pipeline stage switches. The two preparation stages are one-off steps that
# only need to run (with appropriate paths) when the dataset on disk has not
# been prepared yet; TRAIN and TEST select what to do with the model.
SPLIT_DATA = False
ORDER_TEST_SET = False
TRAIN = False
TEST = True


def main():
    """Prepare the data, then train and/or evaluate the street-sign model.

    Which stages run is controlled by the module-level flags SPLIT_DATA,
    ORDER_TEST_SET, TRAIN and TEST. When TRAIN is enabled, the best model
    (highest validation accuracy) is checkpointed to ``path_to_save_model``;
    when TEST is enabled, the model is loaded back from that same location
    and evaluated on the validation and test generators.
    """
    path_to_train = "./archive/training_data/train"
    path_to_val = "./archive/training_data/val"
    path_to_test = "./archive/Test"
    # Defined once so the checkpoint callback and the evaluation stage can
    # never point at different locations.
    path_to_save_model = "./ModelsData"
    batch_size = 64
    epochs = 15

    if SPLIT_DATA:
        # Split the raw training dataset into training and validation folders
        # (90% / 10% according to the original author's note; the actual
        # ratio is decided inside utils.split_data).
        path_to_data = "./archive/Train"
        split_data(path_to_data, path_to_save_train=path_to_train,
                   path_to_save_val=path_to_val)

    if ORDER_TEST_SET:
        # Organize the test images based on the info in the csv file.
        path_to_csv = "./archive/Test.csv"
        order_test_set(path_to_test, path_to_csv)

    train_generator, val_generator, test_generator = create_generators(
        batch_size, path_to_train, path_to_val, path_to_test)
    num_classes = train_generator.num_classes

    if TRAIN:
        # Keep only the model with the best validation accuracy seen so far.
        checkpoint_saver = ModelCheckpoint(
            path_to_save_model,
            monitor='val_accuracy',
            mode='max',
            save_best_only=True,
            save_freq='epoch',
            verbose=1,
        )
        # Stop training when validation accuracy has not improved for
        # `patience` consecutive epochs.
        early_stop = EarlyStopping(
            monitor='val_accuracy',
            patience=10,
        )

        model = street_sign_model(num_classes)
        # NOTE(review): the loss is expected to match the label encoding
        # produced by the generators built in utils.py - confirm there.
        model.compile(optimizer='adam',
                      loss='categorical_crossentropy', metrics=['accuracy'])
        # The generators already yield (images, labels) batches of
        # `batch_size`, so `batch_size` is deliberately NOT passed to fit():
        # Keras documents that it must not be specified for generator input.
        model.fit(train_generator, epochs=epochs,
                  validation_data=val_generator,
                  callbacks=[checkpoint_saver, early_stop])

    if TEST:
        # Always evaluate the checkpointed model from disk rather than the
        # in-memory one, so the reported numbers belong to the saved model.
        model = tf.keras.models.load_model(path_to_save_model)
        model.summary()
        print("evaluating the model on the validation set: ")
        model.evaluate(val_generator)
        print("evaluating the model on the test set: ")
        model.evaluate(test_generator)


if __name__ == "__main__":
    main()