-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_stack_size.py
More file actions
114 lines (95 loc) · 2.94 KB
/
train_stack_size.py
File metadata and controls
114 lines (95 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import tensorflow as tf
import os
import datetime
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.models import Sequential
from shared import test_model, AccuracyThresholdCallback
# --- Training configuration -------------------------------------------------
# Name of this run: selects the image folders under ./images/ and the output
# files under ./training/.
training_name = "stack_size"
# Multiplier applied to every pixel value by the ImageDataGenerators.
scale = 1 / 255
# Target (height, width) that flow_from_directory resizes every image to.
image_size = (46, 46)
batch_size = 16
epochs = 140
steps_per_epoch = 3
# When True and a previously saved model file exists, its weights are loaded
# before training starts.
load_existing_model = False
# Passed to shared.AccuracyThresholdCallback; presumably training stops once
# these are reached — TODO confirm in shared.py.
accuracy_threshold = 1.0
loss_threshold = 0.009
# Forwarded unchanged to shared.test_model.
test_empty_image = True

folder = "./images/{}/".format(training_name)
# Build each sibling sub-folder explicitly rather than str.replace()-ing the
# word "training" in a path, which would also rewrite any other occurrence of
# that word (e.g. inside training_name).
training_folder = "{}training/".format(folder)
validation_folder = "{}validation/".format(folder)
original_folder = "{}original/".format(folder)
test_folder = validation_folder

# flow_from_directory() derives class indices from the *sorted* list of
# sub-directories, whereas os.listdir() returns entries in arbitrary order and
# also includes stray files (e.g. .DS_Store). Mirror Keras's rule here so the
# class names saved to disk and handed to test_model line up with the model's
# softmax output indices, and num_classes matches the one-hot label width.
classes = sorted(
    entry
    for entry in os.listdir(training_folder)
    if os.path.isdir(os.path.join(training_folder, entry))
)
num_classes = len(classes)

# One timestamped TensorBoard log directory per run.
log_dir = "logs/fit/{}/{}".format(
    training_name, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
)
output_path = "./training/{}/".format(training_name)
model_output_path = "{}{}.h5".format(output_path, training_name)
classes_output_path = "{}{}-classes.txt".format(output_path, training_name)
# Both generators only rescale pixel values; no augmentation is configured.
train = ImageDataGenerator(rescale=scale)
validation = ImageDataGenerator(rescale=scale)

# Each sub-directory of a folder is one class. Labels are one-hot
# ("categorical"), matching the softmax output and the
# categorical_crossentropy loss used at compile time.
train_dataset = train.flow_from_directory(
    training_folder,
    target_size=image_size,
    batch_size=batch_size,
    class_mode="categorical",
)
# Build from the `validation` generator, not `train`, so that any
# training-only augmentation added to `train` later never leaks into
# validation.
validation_dataset = validation.flow_from_directory(
    validation_folder,
    target_size=image_size,
    batch_size=batch_size,
    class_mode="categorical",
)
print("Class indices: ", validation_dataset.class_indices)
# Small CNN classifier: three conv + 2x2 max-pool stages extract features,
# then a dense head maps them to a softmax over num_classes.
input_shape = (image_size[0], image_size[1], 3)  # 3 input channels
feature_extractor = [
    Conv2D(16, (3, 3), activation="relu", input_shape=input_shape),
    MaxPool2D(2, 2),
    Conv2D(32, (3, 3), activation="relu"),
    MaxPool2D(2, 2),
    Conv2D(64, (3, 3), activation="relu"),
    MaxPool2D(2, 2),
]
classifier_head = [
    Flatten(),
    Dense(512, activation="relu"),
    Dense(num_classes, activation="softmax"),
]
model = Sequential(feature_extractor + classifier_head)

model.compile(
    optimizer=RMSprop(learning_rate=0.001),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()
# Optionally resume from the weights stored in a previously saved model file.
resume_training = load_existing_model and os.path.exists(model_output_path)
if resume_training:
    model.load_weights(model_output_path)

# TensorBoard writes logs (weight histograms every epoch) under log_dir; the
# threshold callback comes from the project's shared module.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
accuracy_threshold_callback = AccuracyThresholdCallback(
    accuracy_threshold=accuracy_threshold,
    loss_threshold=loss_threshold,
)
training_callbacks = [tensorboard_callback, accuracy_threshold_callback]

model.fit(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_dataset,
    callbacks=training_callbacks,
)
# Persist the trained model and its class-name list side by side.
# open() does not create missing parent directories, so make sure the output
# directory exists before anything is written into it.
os.makedirs(output_path, exist_ok=True)
model.save(model_output_path)

# Class names are written comma-terminated in list order ("a,b,c,"); the
# trailing comma is kept so existing readers of this file keep working.
# The context manager closes the file even if the write raises.
with open(classes_output_path, "w", encoding="utf-8") as classes_file:
    classes_file.write("".join(c + "," for c in classes))
# Evaluate the freshly trained model on test_folder using the project's shared
# helper (its exact checks live in shared.py).
test_model(
    model=model,
    classes=classes,
    image_size=image_size,
    test_folder=test_folder,
    test_empty_image=test_empty_image,
)