"""
This is the script for training the model.
You should read keras documentation before using this script.
Change the values of epochs and epoch_len according to your needs,
training is done for 100 epochs, with training epoch len being 500 and
testing being 200
"""
import matplotlib
# Select the non-interactive PDF backend so figures can be saved without a display
matplotlib.use('pdf')
import numpy as np
import matplotlib.pyplot as plt
import keras
import sys
import pickle
from keras.callbacks import ModelCheckpoint
from unet import unet_builder
from unet import losses
from utils import data_utils as du
from utils.unet_utils import ModelMGPU


def train_unet():
    # Set the hyperparameters; they are stored in the generator parameter dicts below
    batch_size = 8     # how many images to process before updating the gradient
    n_classes = 13     # number of classes
    n_chans = 14       # number of channels
    epochs = 5         # number of epochs to train the model
    rotation = 16      # how many different rotations to use in augmentation
    reflection = True  # whether random flips are included in the augmentations
    # Next generate lists containing paths to training and validation images and masks
    train_images = ['training_data/train_image_' + str(i) + '.npy' for i in range(9)]
    train_masks = ['training_data/train_masks_' + str(i) + '.npy' for i in range(9)]
    val_images = ['training_data/val_image_' + str(i) + '.npy' for i in range(2)]
    val_masks = ['training_data/val_masks_' + str(i) + '.npy' for i in range(2)]
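    # Optional sanity check (a sketch, not part of the original workflow): make sure
    # every listed .npy file exists before starting a long training run.
    # import os
    # missing = [f for f in train_images + train_masks + val_images + val_masks
    #            if not os.path.isfile(f)]
    # assert not missing, 'Missing data files: {}'.format(missing)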
    # Set parameters for train and test datagenerators
    train_params = {'dim': [64, 128, 256, 384],
                    'image_files': train_images,
                    'mask_files': train_masks,
                    'batch_size': batch_size,
                    'n_classes': n_classes,
                    'n_chans': n_chans,
                    'bands': list(range(n_chans)),
                    'single_label': None,
                    'rotation': rotation,
                    'reflection': reflection,
                    'epoch_len': 10}  # number of batches in one epoch
    test_params = {'dim': [128, 256],
                   'image_files': val_images,
                   'mask_files': val_masks,
                   'batch_size': batch_size,
                   'n_classes': n_classes,
                   'n_chans': n_chans,
                   'bands': list(range(n_chans)),
                   'single_label': None,
                   'rotation': rotation,
                   'reflection': reflection,
                   'epoch_len': 4}
    # Data generators for the model
    train_gen = du.DataGeneratorImage(**train_params)
    val_gen = du.DataGeneratorImage(**test_params)
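    # Quick sanity check (optional sketch): draw one batch and confirm the shapes look
    # sensible before committing to a long run. This assumes DataGeneratorImage behaves
    # like a keras.utils.Sequence and yields (images, masks) when indexed.
    # sample_x, sample_y = train_gen[0]
    # print(sample_x.shape, sample_y.shape)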
    # Save weights from each epoch
    cb = ModelCheckpoint('models/test-{epoch:02d}-{val_loss:.2f}.h5', monitor='val_loss',
                         verbose=1, save_best_only=False)
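    # If disk space is a concern, one option (not used here) is to keep only the best
    # checkpoint instead of one per epoch:
    # cb = ModelCheckpoint('models/test-best.h5', monitor='val_loss', verbose=1,
    #                      save_best_only=True)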
    # Use Nadam as the optimizer. Other possibilities can be found at https://keras.io/optimizers/
    opt = keras.optimizers.Nadam()
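    # Any other Keras optimizer can be swapped in the same way, e.g. plain Adam with an
    # explicit learning rate (illustrative value only):
    # opt = keras.optimizers.Adam(lr=1e-4)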
    # The loss and activation functions depend on whether the task is binary or multiclass.
    # For binary classification, use a sigmoid activation and binary crossentropy:
    # activation = 'sigmoid'
    # loss = keras.losses.binary_crossentropy
    # For multiclass classification, use softmax and categorical crossentropy:
    activation = 'softmax'
    loss = keras.losses.categorical_crossentropy
    # Other losses, such as focal and Lovasz-Softmax, are also available but still need
    # tuning; don't use them unless you know what you are doing.
    # loss = losses.focal()
    # Build and compile the network
    model = unet_builder.build_unet(n_chans, n_classes, activation=activation)
    parallel_model = model
    # If multiple GPUs are available and it is sensible to use them, wrap the model with ModelMGPU:
    # parallel_model = ModelMGPU(model, gpus=2)
    parallel_model.compile(loss=loss, optimizer=opt, metrics=['accuracy'])
    # print(model.summary())
    # Train the network
    model_train = parallel_model.fit_generator(train_gen, epochs=epochs, verbose=1,
                                               validation_data=val_gen, callbacks=[cb],
                                               use_multiprocessing=True, workers=4)
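    # Note: in newer Keras / TF 2.x releases fit_generator is deprecated and model.fit
    # accepts generators directly, so the equivalent call there would look roughly like:
    # model_train = parallel_model.fit(train_gen, epochs=epochs, verbose=1,
    #                                  validation_data=val_gen, callbacks=[cb],
    #                                  use_multiprocessing=True, workers=4)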
    # Save the model history to disk, in case something goes wrong when plotting the graphs
    with open('histories/test_history.obj', 'wb') as history:
        pickle.dump(model_train.history, history, protocol=pickle.HIGHEST_PROTOCOL)
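    # The saved history can later be reloaded with, for example:
    # with open('histories/test_history.obj', 'rb') as f:
    #     history_dict = pickle.load(f)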
    # Plot and save the training metrics.
    # Change the savefig locations accordingly.
    # These plots may not work on taito-gpu, but they are worth trying.
    plt.figure()
    acc = model_train.history['acc']
    val_acc = model_train.history['val_acc']
    loss = model_train.history['loss']
    val_loss = model_train.history['val_loss']
    epoc = range(len(acc))
    plt.plot(epoc, acc, 'bo', label='Training accuracy')
    plt.plot(epoc, val_acc, 'b', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.savefig('data/graphs/test_training_acc.pdf', bbox_inches='tight')
    plt.figure()
    plt.plot(epoc, loss, 'bo', label='Training loss')
    plt.plot(epoc, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    plt.savefig('data/graphs/test_training_loss.pdf', bbox_inches='tight')


if __name__ == '__main__':
    train_unet()