├── requirements.txt
├── download_cifar10.sh
├── README.md
├── submit.sh
├── utils
│   ├── callbacks.py
│   └── loader.py
├── .gitignore
├── train.py
└── models
    └── CNN.py

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorboard==1.14.0
tensorflow==1.14.0
tensorflow-base==1.14.0
tensorflow-estimator==1.14.0
tensorflow-gpu==1.14.0

--------------------------------------------------------------------------------
/download_cifar10.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e  # abort if the download or extraction fails
wget https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz -P data/
cd data/
tar -xf cifar10.tgz
rm cifar10.tgz

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Pegasus

A minimal TensorFlow 1.x pipeline for training a small CNN on CIFAR-10.

##### Sample Workflow: Getting Started
- Step 1: `bash download_cifar10.sh`
- Step 2: `pip install -r requirements.txt`
- Step 3: `bash submit.sh`

--------------------------------------------------------------------------------
/submit.sh:
--------------------------------------------------------------------------------
python train.py --checkpoint_dir '/mnt/project/grp_202/khan74/work/Pegasus' --run_id '0004' --data_dir 'data/cifar10/' \
    --learning_rate 0.001 --batch_size 16 --epochs 2

--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
import tensorflow


def CSVlogger(RUN_FOLDER):
    # Appends one row of metrics per epoch to a CSV log file.
    return tensorflow.keras.callbacks.CSVLogger(RUN_FOLDER + '/logs/epoch_history.log')


def Regular_ModelCheckpoint(RUN_FOLDER):
    # Saves the full model after every epoch, tagged with epoch and val_loss.
    return tensorflow.keras.callbacks.ModelCheckpoint(
        RUN_FOLDER + '/weights/model_{epoch:02d}-{val_loss:.2f}.h5',
        monitor='val_loss', verbose=0, save_weights_only=False, period=1)


def Early_Stopping(patience, min_delta=0, verbose=1):
    # Stops training once val_loss has not improved for `patience` epochs.
    return tensorflow.keras.callbacks.EarlyStopping(
        monitor='val_loss', min_delta=min_delta, patience=patience,
        verbose=verbose, mode='min')


def ReduceLROnPlateau(factor, patience, min_lr, verbose=1):
    # Multiplies the learning rate by `factor` whenever val_loss plateaus.
    return tensorflow.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=factor, patience=patience,
        min_lr=min_lr, verbose=verbose, mode='min')

--------------------------------------------------------------------------------
/utils/loader.py:
--------------------------------------------------------------------------------
from tensorflow.keras.preprocessing.image import ImageDataGenerator


def data_generators(train_dir, test_dir, batch_size=32, augment=True):
    # Training images are always rescaled to [0, 1]; augmentation adds random
    # shifts, zooms, and rotations on top of that.
    if augment:
        train_datagen = ImageDataGenerator(rescale=1./255,
                                           fill_mode='nearest',
                                           zoom_range=0.2,
                                           width_shift_range=0.2,
                                           height_shift_range=0.2,
                                           rotation_range=30)
    else:
        train_datagen = ImageDataGenerator(rescale=1./255)

    # Validation images are only rescaled, never augmented.
    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        directory=train_dir,
        target_size=(32, 32),
        batch_size=batch_size,
        class_mode='categorical')

    validation_generator = test_datagen.flow_from_directory(
        directory=test_dir,
        target_size=(32, 32),
        batch_size=batch_size,
        class_mode='categorical')

    return train_generator, validation_generator
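
# Sketch only: quick smoke test of the generators, assuming the
# data/cifar10/{train,test}/<class>/ layout produced by download_cifar10.sh.
if __name__ == '__main__':
    train_gen, val_gen = data_generators(train_dir='data/cifar10/train/',
                                         test_dir='data/cifar10/test/',
                                         batch_size=32,
                                         augment=False)
    x_batch, y_batch = next(train_gen)
    print(x_batch.shape, y_batch.shape)  # expected: (32, 32, 32, 3) (32, 10)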
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/
runs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.pyc

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# vim
*.swp

# Autogenerated svgs for formulas
stem*.svg
stem*.png

# Mac OS nonsense
.DS_Store

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import os
import argparse

from utils.loader import data_generators
from utils.callbacks import CSVlogger, Regular_ModelCheckpoint, Early_Stopping, ReduceLROnPlateau
from models.CNN import Basic_CNN


### Set Params
def str2bool(v):
    # argparse's type=bool treats any non-empty string (including 'False') as
    # True, so boolean flags are parsed explicitly instead.
    return str(v).lower() in ('true', '1', 'yes')


parser = argparse.ArgumentParser(description="sample script")
parser.add_argument('--checkpoint_dir', help='root directory for run outputs', default='/home/khan74/checkpoint/')
parser.add_argument('--run_id', type=str, help='run ID', default='0001')
parser.add_argument('--data_dir', help='root directory of the dataset', default='/home/khan74/data/')
parser.add_argument('--learning_rate', type=float, help='learning rate', default=0.001)
parser.add_argument('--batch_size', type=int, help='batch size', default=64)
parser.add_argument('--data_aug', type=str2bool, help='data augmentation', default=True)
parser.add_argument('--epochs', type=int, help='num of epochs to train', default=1)
parser.add_argument('--shuffle', type=str2bool, help='whether to shuffle data', default=True)
parser.add_argument('--verbose', type=int, help='verbosity', default=1)
parser.add_argument('--mode', type=str, help='mode: build or load', default='build')

args = parser.parse_args()

checkpoint_dir = args.checkpoint_dir
RUN_ID = args.run_id
DATA_DIR = args.data_dir
LEARNING_RATE = args.learning_rate
BATCH_SIZE = args.batch_size
DATA_AUGMENTATION = args.data_aug
EPOCHS = args.epochs
SHUFFLE = args.shuffle
VERBOSE = args.verbose
MODE = args.mode

RUN_FOLDER = os.path.join(checkpoint_dir, 'runs/{}/'.format(RUN_ID))

if not os.path.exists(RUN_FOLDER):
    os.makedirs(RUN_FOLDER)
    os.makedirs(RUN_FOLDER + '/weights/')
    os.makedirs(RUN_FOLDER + '/logs/')
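
# Each run writes its artifacts under <checkpoint_dir>/runs/<RUN_ID>/:
#   weights/model_{epoch:02d}-{val_loss:.2f}.h5   per-epoch checkpoints
#   logs/epoch_history.log                        per-epoch metrics (CSV)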

### Load data
train_generator, validation_generator = data_generators(train_dir=DATA_DIR + '/train/',
                                                        test_dir=DATA_DIR + '/test/',
                                                        batch_size=BATCH_SIZE,
                                                        augment=DATA_AUGMENTATION)

### Define Model
if MODE == 'build':
    CNN = Basic_CNN(input_dim=(32, 32, 3),
                    conv_filters=[32, 64, 64],
                    conv_kernel_size=[3, 3, 3],
                    fc_layer_size=[64])
else:
    # 'load' mode is not implemented yet; fail loudly here instead of hitting
    # a NameError on the undefined CNN below.
    raise NotImplementedError("mode '{}' is not supported; use 'build'".format(MODE))


### Train the Model
CNN.compile(LEARNING_RATE)

callbacks = [
    CSVlogger(RUN_FOLDER),
    Regular_ModelCheckpoint(RUN_FOLDER),
    Early_Stopping(patience=7, verbose=1),
    ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-10, verbose=1),
]

CNN.train(train_generator, validation_generator, epochs=EPOCHS, shuffle=SHUFFLE,
          verbose=VERBOSE, callbacks=callbacks)

--------------------------------------------------------------------------------
/models/CNN.py:
--------------------------------------------------------------------------------
import os

import tensorflow
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, MaxPooling2D
from tensorflow.keras.optimizers import Adam


class Basic_CNN:
    def __init__(self,
                 input_dim,
                 conv_filters,
                 conv_kernel_size,
                 fc_layer_size):

        self.name = 'Basic_CNN'
        self.input_dim = input_dim
        self.conv_filters = conv_filters
        self.conv_kernel_size = conv_kernel_size
        self.fc_layer_size = fc_layer_size

        self.n_layers_conv = len(conv_filters)
        self.n_layers_dense = len(fc_layer_size)

        self._build()

    def _build(self):
        inp = Input(shape=self.input_dim, name='input')

        # Conv -> 2x2 max-pool blocks; 'same' padding preserves spatial size,
        # so each block halves the feature map via pooling only.
        x = inp
        for i in range(self.n_layers_conv):
            x = Conv2D(filters=self.conv_filters[i],
                       kernel_size=self.conv_kernel_size[i],
                       padding='same',
                       activation='relu',
                       name='conv_' + str(i))(x)
            x = MaxPooling2D((2, 2))(x)

        x = Flatten()(x)
        for i in range(self.n_layers_dense):
            x = Dense(self.fc_layer_size[i], activation='relu', name='dense_' + str(i))(x)

        # 10-way softmax head for the CIFAR-10 classes.
        out = Dense(10, activation='softmax')(x)

        self.model = Model(inp, out)

    def compile(self, learning_rate):
        self.learning_rate = learning_rate

        optimizer = Adam(lr=learning_rate)

        self.model.compile(optimizer=optimizer,
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])

    def save_model(self, folder, checkpoint_name):
        if not os.path.exists(folder):
            os.makedirs(folder)
        # Append '.h5' to the checkpoint name rather than joining it as a
        # separate path component.
        self.model.save(os.path.join(folder, checkpoint_name + '.h5'))

    def load_model(self, folder, checkpoint_name):
        self.model = tensorflow.keras.models.load_model(os.path.join(folder, checkpoint_name + '.h5'))
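
    # Training wrapper around Keras fit_generator (the TF 1.x generator API;
    # TF 2.x folds this into model.fit). Returns the History object, whose
    # .history dict holds the per-epoch loss/accuracy curves.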
    def train(self, train_generator, validation_generator, epochs=1, shuffle=True,
              verbose=1, callbacks=None):
        history = self.model.fit_generator(generator=train_generator,
                                           validation_data=validation_generator,
                                           epochs=epochs,
                                           callbacks=callbacks,
                                           shuffle=shuffle,
                                           max_queue_size=1000,
                                           workers=32,
                                           use_multiprocessing=False,
                                           verbose=verbose)
        return history
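
# Sketch only: rebuild the architecture train.py uses and print a layer
# summary as a standalone sanity check (not part of the training pipeline).
if __name__ == '__main__':
    cnn = Basic_CNN(input_dim=(32, 32, 3),
                    conv_filters=[32, 64, 64],
                    conv_kernel_size=[3, 3, 3],
                    fc_layer_size=[64])
    cnn.model.summary()
--------------------------------------------------------------------------------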