├── requirements.txt ├── LICENSE ├── convert_data.py ├── generate_cam_gifs ├── README.md ├── visualize.py ├── cam_animation.py └── train.py /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.3 2 | cycler==0.10.0 3 | decorator==4.0.11 4 | h5py==2.7.1 5 | Keras==2.0.3 6 | matplotlib==2.0.0 7 | networkx==1.11 8 | numpy==1.12.1 9 | olefile==0.44 10 | packaging==16.8 11 | Pillow==4.1.1 12 | protobuf==3.2.0 13 | pydevd==1.1.1 14 | pydot==1.2.3 15 | pyparsing==2.2.0 16 | python-dateutil==2.6.0 17 | pytz==2017.2 18 | PyWavelets==0.5.2 19 | PyYAML==3.12 20 | scikit-image==0.13.0 21 | scipy==0.19.0 22 | six==1.10.0 23 | tensorflow==1.1.0 24 | Theano==0.9.0 25 | tqdm==4.11.2 26 | Werkzeug==0.12.1 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Yujan Shrestha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /convert_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reads image files, decodes jpeg, resizes them, and stores them into an npz so 3 | these operations do not have to be performed more than once. 4 | """ 5 | import argparse 6 | import glob 7 | import os 8 | import sys 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | import scipy.ndimage 13 | import scipy.misc 14 | import tqdm 15 | 16 | IMG_SIZE = (256, 256) 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description="Preprocess data for Simpsons classifier") 20 | parser.add_argument('--data-dir', required=True, help="Directory of input data") 21 | args = parser.parse_args(sys.argv[1:]) 22 | 23 | character_name_to_pixels = defaultdict(list) 24 | 25 | input_data = list(glob.glob(os.path.join(args.data_dir, '**/*.jpg'))) 26 | 27 | for image_path in tqdm.tqdm(input_data): 28 | image_pixels = scipy.ndimage.imread(image_path) 29 | resized_image_pixels = scipy.misc.imresize(image_pixels, IMG_SIZE) 30 | image_basepath, _ = os.path.splitext(image_path) 31 | np.savez(image_basepath+'.npz', pixels=resized_image_pixels, compressed=True) 32 | -------------------------------------------------------------------------------- /generate_cam_gifs: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # create_cam_gif 4 | function create_cam_gif() { 5 | final_gif_name="../data_dir/cam_output/$1/$2/$3/$2_$3.gif" 6 | mkdir -p ../data_dir/cam_output/$1/$2/$3 7 | CUDA_VISIBLE_DEVICES='' DISPLAY=:2 python cam_animation.py 
--weight-directory ../data_dir/weights/$1 --data-directory ../data_dir/simpsons_dataset --image-path ../data_dir/simpsons_dataset/$2/$3 --cam-path ../data_dir/cam_output/$1/$2/$3 --weight-limit 100 8 | convert -delay 30 -size 256x256 ../data_dir/cam_output/$1/$2/$3/*.png -loop 0 $final_gif_name 9 | } 10 | 11 | function create_cam_frame() { 12 | model=$1 13 | character=$2 14 | weight_file=$3 15 | shift; shift; shift; 16 | files=$@ 17 | mkdir -p ../data_dir/cam_output/$model/$character 18 | CUDA_VISIBLE_DEVICES='' DISPLAY=:2 python cam_animation.py --weight-file $weight_file --data-directory ../data_dir/simpsons_dataset --cam-path ../data_dir/cam_output/$model/$character --images $files 19 | } 20 | 21 | max_jobs=8 22 | max_convert_jobs=20 23 | 24 | if [[ $# == 3 ]]; then 25 | create_cam_gif $@ 26 | elif [[ $# == 2 ]]; then 27 | model=$1 28 | character=$2 29 | for weight in ../data_dir/weights/$model/*.h5 30 | do 31 | if [[ -f $weight ]]; then 32 | while [ $(jobs | wc -l) -ge $max_jobs ]; do sleep 1; done 33 | create_cam_frame $model $character $weight ../data_dir/simpsons_dataset/$character/pic_00*.npz & 34 | fi 35 | done 36 | wait 37 | for cam_output in ../data_dir/cam_output/$model/$character/* 38 | do 39 | if [ -d $cam_output ]; then 40 | filename=$(basename $cam_output) 41 | final_gif_name="$cam_output/$character_$filename.gif" 42 | while [ $(jobs | wc -l) -ge $max_convert_jobs ]; do sleep 1; done 43 | convert -delay 30 -size 256x256 $cam_output/*.png -loop 0 $final_gif_name & 44 | fi 45 | done 46 | wait 47 | echo "CAM generation complete." 48 | else 49 | echo "./generate_cam_gifs " 50 | fi 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Classifier Using Pre-Trained Models in Keras 2 | 3 | This repository contains the example code for our [article on pre-trained deep 4 | learning models with Keras][article]. 
5 | 6 | [article]: https://innolitics.com/articles/pretrained-models-with-keras/ 7 | 8 | Train, predict, visualize, and produce class-activation map animations for deep 9 | learning models in Keras using pre-trained models as their basis. 10 | 11 | ## Dependencies 12 | 13 | - Python 3.5+ 14 | - ImageMagick 7+ (the scripts call its `convert` command to assemble GIFs) 15 | 16 | ## Running the Example 17 | 18 | ### 1. Download the [example dataset][simpsons-kaggle] 19 | 20 | [simpsons-kaggle]: https://www.kaggle.com/alexattia/the-simpsons-characters-dataset 21 | 22 | ### 2. Preprocess the data 23 | 24 | ```bash 25 | python convert_data.py --data-dir {path-to-data} 26 | ``` 27 | 28 | ### 3. Train the model 29 | 30 | ```bash 31 | python train.py --pretrained_model {model} \ 32 | --data-dir {path-to-data} \ 33 | --weight-directory {path-to-weight-directory} \ 34 | --tensorboard-directory {path-to-tensorboard-logdir} \ 35 | --epochs {max_epochs} 36 | ``` 37 | 38 | ### 4. Visualize model predictions 39 | 40 | ```bash 41 | python visualize.py --weight-file {path-to-weight-file} \ 42 | --data-directory {path-to-data} \ 43 | --output-directory {path-to-output-directory} \ 44 | --image-path {path-to-image-to-visualize} 45 | ``` 46 | 47 | ### 5. Generate a CAM plot 48 | 49 | ```bash 50 | # Run once per weight file; each run writes one frame per image to {output-path-for-cam-images}/{npz-file-name}/cam_{epoch}.png 51 | python cam_animation.py --weight-file {path-to-weight-file} \ 52 | --data-directory {path-to-data-directory} \ 53 | --cam-path {output-path-for-cam-images} \ 54 | --images {path-to-npz-image} [{path-to-npz-image} ...] 55 | 56 | convert -delay 30 -size 256x256 {output-path-for-cam-images}/{npz-file-name}/*.png -loop 0 {final-gif-name} 57 | ``` 58 | 59 | To make the generation of CAM plots easier, you can use the 60 | `./generate_cam_gifs` script.
This assumes: 61 | 62 | - Data directory is `../data_dir/simpsons_dataset` 63 | - Weight directory is `../data_dir/weights` 64 | - CAM output path is `../data_dir/cam_output/{model}/{character}` 65 | - All names passed into the script are basenames 66 | 67 | ```bash 68 | # Generate a single CAM plot 69 | ./generate_cam_gifs {model} {character} {npz-file} 70 | # Generate CAM plots for the first 100 images of a character 71 | ./generate_cam_gifs {model} {character} 72 | ``` 73 | 74 | ## About Innolitics 75 | 76 | Innolitics is a team of talented software developers with medical and 77 | engineering backgrounds. Our mission is to accelerate progress in medical 78 | imaging by sharing knowledge, creating tools, and providing quality services to 79 | our clients, with the ultimate purpose of improving patient health. If you are 80 | working on a project that requires image processing or deep learning expertise, 81 | let us know! We offer [consulting and development services][company-site]. 82 | 83 | [company-site]: https://innolitics.com/ 84 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import argparse 4 | import os 5 | import glob 6 | 7 | import tqdm 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import matplotlib.gridspec as gridspec 11 | from keras.models import load_model 12 | 13 | from train import DataGenerator 14 | 15 | num_columns = 6 16 | num_rows = 3 17 | 18 | def plot_row_item(image_ax, labels_ax, pixels, top_character_names, top_character_probabilities): 19 | image_ax.imshow(pixels, interpolation='nearest', aspect='auto') 20 | y_pos = np.arange(len(top_character_names))*0.11 21 | labels_ax.barh(y_pos, top_character_probabilities, height=0.1, align='center', 22 | color='cyan', ecolor='black') 23 | labels_ax.set_xlim([0,1]) 24 | labels_ax.set_yticks(y_pos) 25 | 
labels_ax.set_yticklabels(top_character_names, position=(1,0)) 26 | labels_ax.invert_yaxis() 27 | labels_ax.tick_params( 28 | axis='both', 29 | which='both', 30 | bottom='off', 31 | top='off', 32 | labelbottom='off') 33 | image_ax.axis('off') 34 | 35 | def plot_prediction(pixels, model, data_encoder): 36 | fig = plt.figure() 37 | inner = gridspec.GridSpec(2, 1, wspace=0.05, hspace=0, height_ratios=[5, 1.2]) 38 | image_ax = plt.Subplot(fig, inner[0]) 39 | labels_ax = plt.Subplot(fig, inner[1]) 40 | 41 | predicted_labels = model.predict(np.array([pixels]), batch_size=1) 42 | character_name_to_probability = data_encoder.one_hot_decode(predicted_labels[0].astype(np.float64)) 43 | top_character_probability = sorted(character_name_to_probability.items(), 44 | key=lambda item_tup: item_tup[1], 45 | reverse=True)[:3] 46 | top_character_names, top_character_probabilities = zip(*top_character_probability) 47 | character_idx = data_encoder.one_hot_index(top_character_names[0]) 48 | 49 | plot_row_item(image_ax, labels_ax, pixels, top_character_names, top_character_probabilities) 50 | 51 | fig.add_subplot(image_ax) 52 | fig.add_subplot(labels_ax) 53 | return fig 54 | 55 | 56 | if __name__ =='__main__': 57 | parser = argparse.ArgumentParser(description="Visualize predictions for an *.npz file given a model weight file.") 58 | parser.add_argument('--weight-file', required=True, help="File containing the weights for the model") 59 | parser.add_argument('--data-directory', required=True, help="Directory containing all input images") 60 | parser.add_argument('--output-directory', required=True, help="Output directory for generated plots.") 61 | parser.add_argument('--image-path', required=True, nargs="+", help="*.npz file to generate predictions for. Can be a glob.") 62 | args = parser.parse_args(sys.argv[1:]) 63 | 64 | model = load_model(args.weight_file) 65 | data_encoder = DataGenerator(args.data_directory).encoder 66 | 67 | print("{} input image(s) found. 
Beginning prediction plotting.".format(len(args.image_path))) 68 | 69 | for image_path in tqdm.tqdm(args.image_path, unit='image'): 70 | pixels = np.load(image_path)['pixels'] 71 | fig = plot_prediction(pixels, model, data_generator) 72 | plt.savefig(os.path.join(args.output_directory, os.path.basename(image_path) + 'predictions.png')) 73 | plt.close(fig) 74 | -------------------------------------------------------------------------------- /cam_animation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | import tqdm 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import matplotlib.gridspec as gridspec 9 | from keras.models import load_model 10 | from vis.visualization.saliency import visualize_cam 11 | 12 | from train import DataGenerator 13 | from visualize import plot_row_item 14 | 15 | def get_model_predictions_for_npz(model, data_generator, character_name, npz_name): 16 | npz_file_path = os.path.join(data_generator.data_path, character_name, npz_name) 17 | pixels = np.load(npz_file_path)['pixels'] 18 | predicted_labels = model.predict(np.array([pixels]), batch_size=1) 19 | return data_generator.encoder.one_hot_decode(predicted_labels[0].astype(np.float64)) 20 | 21 | def cam_weighted_image(model, image_path, character_idx): 22 | pixels = np.load(image_path)['pixels'] 23 | cam = visualize_cam(model, layer_idx=-1, 24 | filter_indices=[character_idx], 25 | seed_input=pixels) 26 | return np.uint8(pixels*np.dstack([cam]*3)) 27 | 28 | def make_cam_plot(model, weight, image_path, cam_path, data_generator): 29 | path_head, npz_name = os.path.split(image_path) 30 | _, character_name = os.path.split(path_head) 31 | 32 | model_name = os.path.basename(os.path.dirname(weight)) 33 | 34 | character_idx = data_generator.encoder.one_hot_index(character_name) 35 | cam = cam_weighted_image(model, image_path, character_idx) 36 | 37 | fig = plt.figure() 38 | inner = 
gridspec.GridSpec(2, 1, wspace=0.05, hspace=0, height_ratios=[5, 1.2]) 39 | image_ax = plt.Subplot(fig, inner[0]) 40 | labels_ax = plt.Subplot(fig, inner[1]) 41 | character_name_to_probability = get_model_predictions_for_npz(model, 42 | data_generator, 43 | character_name, 44 | npz_name) 45 | top_character_probability = sorted(character_name_to_probability.items(), 46 | key=lambda item_tup: item_tup[1], 47 | reverse=True)[:3] 48 | top_character_names, top_character_probabilities = zip(*top_character_probability) 49 | 50 | plot_row_item(image_ax, labels_ax, cam, top_character_names, top_character_probabilities) 51 | weight_idx = os.path.basename(weight).split('.')[1] 52 | labels_ax.set_xlabel(npz_name) 53 | image_ax.set_title(model_name + ', epoch ' + weight_idx) 54 | 55 | fig.add_subplot(image_ax) 56 | fig.add_subplot(labels_ax) 57 | 58 | plt.savefig(os.path.join(cam_path, 'cam_{}.png'.format(weight_idx))) 59 | plt.close(fig) 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser(description="Generate an animation of class-activation maps") 63 | parser.add_argument('--weight-file', required=True, 64 | help="Model weight file") 65 | parser.add_argument('--data-directory', required=True, 66 | help="Directory containing the input *.npz images") 67 | parser.add_argument('--cam-path', required=True, 68 | help="Directory for storing CAM plots.") 69 | parser.add_argument('--images', required=True, nargs="+", 70 | help="Images to plot CAM for.") 71 | args = parser.parse_args(sys.argv[1:]) 72 | 73 | data_generator = DataGenerator(args.data_directory) 74 | 75 | 76 | model = load_model(args.weight_file) 77 | for image in tqdm.tqdm(args.images, unit="image"): 78 | try: 79 | image_cam_path = os.path.join(args.cam_path, os.path.basename(image)) 80 | os.makedirs(image_cam_path) 81 | except OSError as err: 82 | if err.errno != os.errno.EEXIST: 83 | raise err 84 | 85 | make_cam_plot(model, args.weight_file, image, image_cam_path, data_generator) 86 | 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | '''Builds a model, organizes and loads data, and runs model training.''' 2 | import argparse 3 | from collections import defaultdict 4 | import os 5 | import glob 6 | import random 7 | 8 | import keras 9 | import numpy as np 10 | 11 | from keras.layers import Input, Average 12 | from keras.layers.core import Dense, Flatten, Dropout 13 | from keras.layers.merge import Concatenate 14 | from keras.layers.normalization import BatchNormalization 15 | from keras.layers.pooling import GlobalAveragePooling2D, GlobalMaxPooling2D 16 | from keras.preprocessing.image import ImageDataGenerator 17 | from keras.models import Model 18 | 19 | def get_model(pretrained_model, all_character_names): 20 | if pretrained_model == 'inception': 21 | model_base = keras.applications.inception_v3.InceptionV3(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet') 22 | output = Flatten()(model_base.output) 23 | elif pretrained_model == 'xception': 24 | model_base = keras.applications.xception.Xception(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet') 25 | output = Flatten()(model_base.output) 26 | elif pretrained_model == 'resnet50': 27 | model_base = keras.applications.resnet50.ResNet50(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet') 28 | output = Flatten()(model_base.output) 29 | elif pretrained_model == 'vgg19': 30 | model_base = keras.applications.vgg19.VGG19(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet') 31 | output = Flatten()(model_base.output) 32 | elif pretrained_model == 'all': 33 | input = Input(shape=(*IMG_SIZE, 3)) 34 | inception_model = keras.applications.inception_v3.InceptionV3(include_top=False, input_tensor=input, weights='imagenet') 35 | xception_model = keras.applications.xception.Xception(include_top=False, input_tensor=input, 
weights='imagenet') 36 | resnet_model = keras.applications.resnet50.ResNet50(include_top=False, input_tensor=input, weights='imagenet') 37 | 38 | flattened_outputs = [Flatten()(inception_model.output), 39 | Flatten()(xception_model.output), 40 | Flatten()(resnet_model.output)] 41 | output = Concatenate()(flattened_outputs) 42 | model_base = Model(input, output) 43 | 44 | output = BatchNormalization()(output) 45 | output = Dropout(0.5)(output) 46 | output = Dense(128, activation='relu')(output) 47 | output = BatchNormalization()(output) 48 | output = Dropout(0.5)(output) 49 | output = Dense(len(all_character_names), activation='softmax')(output) 50 | model = Model(model_base.input, output) 51 | for layer in model_base.layers: 52 | layer.trainable = False 53 | model.summary(line_length=200) 54 | 55 | # Generate a plot of a model 56 | import pydot 57 | pydot.find_graphviz = lambda: True 58 | from keras.utils import plot_model 59 | plot_model(model, show_shapes=True, to_file='../model_pdfs/{}.pdf'.format(pretrained_model)) 60 | 61 | model.compile(optimizer='adam', 62 | loss='categorical_crossentropy', 63 | metrics=['accuracy']) 64 | return model 65 | 66 | BATCH_SIZE = 64 67 | IMG_SIZE = (256, 256) 68 | 69 | image_datagen = ImageDataGenerator( 70 | rotation_range=15, 71 | width_shift_range=.15, 72 | height_shift_range=.15, 73 | shear_range=0.15, 74 | zoom_range=0.15, 75 | channel_shift_range=1, 76 | horizontal_flip=True, 77 | vertical_flip=False,) 78 | 79 | class DataEncoder(): 80 | def __init__(self, all_character_names): 81 | self.all_character_names = all_character_names 82 | 83 | def one_hot_index(self, character_name): 84 | return self.all_character_names.index(character_name) 85 | 86 | def one_hot_decode(self, predicted_labels): 87 | return dict(zip(self.all_character_names, predicted_labels)) 88 | 89 | def one_hot_encode(self, character_name): 90 | one_hot_encoded_vector = np.zeros(len(self.all_character_names)) 91 | idx = self.one_hot_index(character_name) 92 | 
one_hot_encoded_vector[idx] = 1 93 | return one_hot_encoded_vector 94 | 95 | 96 | class DataGenerator(): 97 | def __init__(self, data_path): 98 | self.data_path = data_path 99 | self.partition_to_character_name_to_npz_paths = { 100 | 'train': defaultdict(list), 101 | 'validation': defaultdict(list), 102 | 'test': defaultdict(list), 103 | } 104 | self.all_character_names = set() 105 | npz_file_listing = list(glob.glob(os.path.join(data_path, '**/*.npz'))) 106 | for npz_path in npz_file_listing: 107 | character_name = os.path.basename(os.path.dirname(npz_path)) 108 | self.all_character_names.add(character_name) 109 | if hash(npz_path) % 10 < 7: 110 | partition = 'train' 111 | elif 7 <= hash(npz_path) % 10 < 9: 112 | partition = 'validation' 113 | elif 9 == hash(npz_path) % 10: 114 | partition = 'test' 115 | else: 116 | raise Exception("partition not assigned") 117 | self.partition_to_character_name_to_npz_paths[partition][character_name].append(npz_path) 118 | self.encoder = DataEncoder(sorted(list(self.all_character_names))) 119 | 120 | 121 | def _pair_generator(self, partition, augmented=True): 122 | while True: 123 | for character_name, npz_paths in self.partition_to_character_name_to_npz_paths[partition].items(): 124 | npz_path = random.choice(npz_paths) 125 | pixels = np.load(npz_path)['pixels'] 126 | one_hot_encoded_labels = self.encoder.one_hot_encode(character_name) 127 | if augmented: 128 | augmented_pixels = next(image_datagen.flow(np.array([pixels])))[0].astype(np.uint8) 129 | yield augmented_pixels, one_hot_encoded_labels 130 | else: 131 | yield pixels, one_hot_encoded_labels 132 | 133 | 134 | def batch_generator(self, partition, batch_size, augmented=True): 135 | while True: 136 | data_gen = self._pair_generator(partition, augmented) 137 | pixels_batch, one_hot_encoded_character_name_batch = zip(*[next(data_gen) for _ in range(batch_size)]) 138 | pixels_batch = np.array(pixels_batch) 139 | one_hot_encoded_character_name_batch = 
np.array(one_hot_encoded_character_name_batch) 140 | yield pixels_batch, one_hot_encoded_character_name_batch 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('--pretrained_model', choices={'inception', 'xception', 'resnet50', 'all', 'vgg19'}) 146 | parser.add_argument('--data-dir', required=True) 147 | parser.add_argument('--weight-directory', required=True, 148 | help="Directory containing the model weight files") 149 | parser.add_argument('--tensorboard-directory', required=True, 150 | help="Directory containing the Tensorboard log files") 151 | parser.add_argument('--epochs', required=True, type=int, 152 | help="Number of epochs to train over.") 153 | args = parser.parse_args() 154 | 155 | tensorboard_callback = keras.callbacks.TensorBoard(log_dir=args.tensorboard_directory, 156 | histogram_freq=0, 157 | write_graph=True, 158 | write_images=False) 159 | save_model_callback = keras.callbacks.ModelCheckpoint(os.path.join(args.weight_directory, 'weights.{epoch:02d}.h5'), 160 | verbose=3, 161 | save_best_only=False, 162 | save_weights_only=False, 163 | mode='auto', 164 | period=1) 165 | 166 | data_generator = DataGenerator(args.data_dir) 167 | model = get_model(args.pretrained_model, data_generator.encoder.all_character_names) 168 | 169 | model.fit_generator( 170 | data_generator.batch_generator('train', batch_size=BATCH_SIZE), 171 | steps_per_epoch=200, 172 | epochs=args.epochs, 173 | validation_data=data_generator.batch_generator('validation', batch_size=BATCH_SIZE, augmented=False), 174 | validation_steps=10, 175 | callbacks=[save_model_callback, tensorboard_callback], 176 | workers=4, 177 | pickle_safe=True, 178 | ) 179 | --------------------------------------------------------------------------------