├── .gitignore
├── README.md
├── data_utils.py
├── evaluate_mobilenet.py
├── images
│   ├── page1.png
│   ├── page2.png
│   ├── page3.png
│   └── skip_model.png
├── model_utils.py
├── requirements.txt
├── train_mobilenet.py
├── train_utils.py
└── weights
    ├── mobilenet_model.h5
    └── mobilenet_model_v2.h5

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/*
logs/*
logs-old/*

.idea/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Keras Mobile Colorizer

Utilizes a U-Net-inspired model, conditioned on MobileNet class features, to learn a mapping from grayscale to color images.
Based on the work at https://github.com/baldassarreFe/deep-koalarization

Uses MobileNet instead of Inception-ResNet-v2 for memory efficiency, so that training can be done on a single GPU (with at least 4 GB of memory).

# Installation
Open the `data_utils.py` script and edit `TRAIN_IMAGE_PATH` and `VALIDATION_IMAGE_PATH` to point to directories of images. Each of those paths must contain at least one subfolder holding the images.

Then run `data_utils.py` to construct the required folders and the TFRecords which will store the training data.

This step is necessary to drastically improve the speed of training: all MobileNet features are extracted from each training image once, before training begins. Otherwise, the major bottleneck during training is the extraction of image features from MobileNet at runtime.
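Record generation can also be invoked from Python directly. A minimal sketch using the two helpers defined in `data_utils.py` (it assumes the image paths above have already been set):

```python
# Minimal sketch: build the train/validation TFRecords programmatically.
from data_utils import generate_train_records, generate_validation_records

generate_train_records(batch_size=200)       # writes data/images.tfrecord
generate_validation_records(batch_size=100)  # writes data/val_images.tfrecord
```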
# Training & Evaluation

- To train the model: Use the `train_mobilenet.py` script. Make sure to verify the batch size and how many images are in the TF record before beginning training.

- To evaluate the model: Use the `evaluate_mobilenet.py` script. Make sure that the path to the validation images is provided in `data_utils.py`.

# Results
There are a lot of splotchy reddish-brown patches. This is probably because training was done using only 60k images from the MS-COCO dataset, rather than the full ImageNet dataset.

![](images/page1.png)

![](images/page2.png)

![](images/page3.png)

# Requirements
- Keras 2.0.8+
- Numpy
- Scikit-image
- Tensorflow (a GPU is a must for training; a CPU is fine for inference)

Install via `pip install -r requirements.txt`

--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import glob

from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave, imread

import tensorflow as tf
from tensorflow import data as tfdata

# Register a TF session with Keras; Keras keeps its own reference,
# so the local name can be dropped afterwards.
sess = tf.Session()

from keras import backend as K
K.set_session(sess)

sess = None

from keras.applications.mobilenet import MobileNet, preprocess_input
from keras.models import Model

# change these to your local paths
TRAIN_IMAGE_PATH = r""
VALIDATION_IMAGE_PATH = r"D:\Yue\Documents\Datasets\MSCOCO\val\\"

IMAGE_SIZE = 128  # global constant: model input/output image size
EMBEDDING_IMAGE_SIZE = 224  # global constant: MobileNet input size

TRAIN_RECORDS_PATH = "data/images.tfrecord"  # local path to the train TF record
VAL_RECORDS_PATH = "data/val_images.tfrecord"  # local path to the validation TF record


if not os.path.exists('weights/'):
    os.makedirs('weights/')

if not os.path.exists('results/'):
    os.makedirs('results/')

if not os.path.exists('data/'):
    os.makedirs('data/')

feature_extraction_model = None
mobilenet_activations = None


def _load_mobilenet():
    global feature_extraction_model, mobilenet_activations

    # Feature extraction module
    feature_extraction_model = MobileNet(input_shape=(EMBEDDING_IMAGE_SIZE, EMBEDDING_IMAGE_SIZE, 3),
                                         alpha=1.0,
                                         depth_multiplier=1,
                                         include_top=True,
                                         weights='imagenet')

    # Set it up so that we can do inference on MobileNet without training it by mistake
    feature_extraction_model.graph = tf.get_default_graph()
    feature_extraction_model.trainable = False

    # Get the pre-softmax activations from MobileNet
    mobilenet_activations = Model(feature_extraction_model.input, feature_extraction_model.layers[-3].output)
    mobilenet_activations.trainable = False


def _get_pre_activations(grayscale_image, batchsize=100):
    # batchwise retrieve the feature maps from the last layer - pre-softmax
    activations = mobilenet_activations.predict(grayscale_image, batch_size=batchsize)
    return activations


def _extract_features(grayscaled_rgb, batchsize=100):
    # Load up MobileNet only when necessary, not during training
    if feature_extraction_model is None:
        _load_mobilenet()

    grayscaled_rgb_resized = []

    for i in grayscaled_rgb:
        # Resize to the MobileNet input size
        i = resize(i, (EMBEDDING_IMAGE_SIZE, EMBEDDING_IMAGE_SIZE, 3), mode='constant')
        grayscaled_rgb_resized.append(i)

    # scale to the 0-255 range expected by MobileNet's preprocess_input
    grayscaled_rgb_resized = np.array(grayscaled_rgb_resized) * 255.
    grayscaled_rgb_resized = preprocess_input(grayscaled_rgb_resized)

    with feature_extraction_model.graph.as_default():  # use the graph shared by the colorization model and MobileNet
        features = _get_pre_activations(grayscaled_rgb_resized, batchsize)  # batchwise get the feature maps
        features = features.reshape((-1, 1000))  # flatten to (batch, 1000)

    return features
def _float32_feature_list(floats):
    return tf.train.Feature(float_list=tf.train.FloatList(value=floats))


def _generate_records(images_path, tf_record_name, batch_size=100):
    '''
    Creates a TF Record containing the pre-processed images, consisting of:
    1) L channel input
    2) ab channels output
    3) features extracted from MobileNet

    This step is crucial for speed during training, as the major bottleneck
    is the extraction of feature maps from MobileNet. It is slow and inefficient
    to do at runtime.
    '''
    if os.path.exists(tf_record_name):
        print("**** Delete the old TF Record first! ****")
        exit(0)

    files = glob.glob(images_path + "*/*.jpg")
    files = sorted(files)
    nb_files = len(files)

    # Use ZLIB compression to save space and create a TFRecordWriter
    options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    writer = tf.python_io.TFRecordWriter(tf_record_name, options)

    size = max(EMBEDDING_IMAGE_SIZE, IMAGE_SIZE)  # keep the larger size until stored in the TF Record

    X_buffer = []
    for i, fn in enumerate(files):
        try:  # prevent crashes due to corrupted images
            X = imread(fn)
            X = resize(X, (size, size, 3), mode='constant')  # resize to the larger size for now
        except Exception:
            continue

        X_buffer.append(X)

        if len(X_buffer) >= batch_size:
            X_buffer = np.array(X_buffer)
            _serialize_batch(X_buffer, writer, batch_size)  # serialize the images into the TF Record

            del X_buffer  # delete buffered images from memory
            X_buffer = []  # reset to a new list

            print("Processed %d / %d images" % (i + 1, nb_files))

    if len(X_buffer) != 0:
        X_buffer = np.array(X_buffer)
        _serialize_batch(X_buffer, writer)  # serialize the remaining images in the buffer

        del X_buffer  # delete the buffer

    print("Processed %d / %d images" % (nb_files, nb_files))
    print("Finished creating TF Record")

    writer.close()
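
# Illustrative helper (an addition, not used by the pipeline above): count how
# many examples a finished record holds, e.g. to verify `nb_train_images` in
# train_mobilenet.py. It reuses the same ZLIB options as the writer above.
def _count_records(tf_record_name):
    options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    return sum(1 for _ in tf.python_io.tf_record_iterator(tf_record_name, options))
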
def _serialize_batch(X, writer, batch_size=100):
    '''
    Processes a batch of images, then serializes them into the TFRecord.

    Args:
        X: original images with no preprocessing
        writer: a TFRecordWriter
        batch_size: batch size
    '''
    [X_batch, features], Y_batch = _process_batch(X, batch_size)  # preprocess the batch

    for j, (img_l, embed, y) in enumerate(zip(X_batch, features, Y_batch)):
        # resize the images to their smaller size to reduce wasted space in the record
        img_l = resize(img_l, (IMAGE_SIZE, IMAGE_SIZE, 1), mode='constant')
        y = resize(y, (IMAGE_SIZE, IMAGE_SIZE, 2), mode='constant')

        example_dict = {
            'image_l': _float32_feature_list(img_l.flatten()),
            'image_ab': _float32_feature_list(y.flatten()),
            'image_features': _float32_feature_list(embed.flatten())
        }
        example_feature = tf.train.Features(feature=example_dict)
        example = tf.train.Example(features=example_feature)

        writer.write(example.SerializeToString())


def _construct_dataset(record_path, batch_size, sess):
    def parse_record(serialized_example):
        # parse a single record
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_l': tf.FixedLenFeature([IMAGE_SIZE, IMAGE_SIZE, 1], tf.float32),
                'image_ab': tf.FixedLenFeature([IMAGE_SIZE, IMAGE_SIZE, 2], tf.float32),
                'image_features': tf.FixedLenFeature([1000, ], tf.float32)
            })

        l, ab, embed = features['image_l'], features['image_ab'], features['image_features']
        return l, ab, embed

    dataset = tfdata.TFRecordDataset([record_path], 'ZLIB')  # create a Dataset to wrap the TFRecord
    dataset = dataset.map(parse_record, num_parallel_calls=2)  # parse the records
    dataset = dataset.repeat()  # repeat forever
    dataset = dataset.batch(batch_size)  # batch into the required batch size
    dataset = dataset.shuffle(buffer_size=5)  # shuffle the batches
    iterator = dataset.make_initializable_iterator()  # get an iterator over the dataset

    sess.run(iterator.initializer)  # initialize the iterator
    next_batch = iterator.get_next()  # get the iterator Tensor

    return dataset, next_batch


def _process_batch(X, batchsize=100):
    '''
    Processes a batch of images for training.

    Args:
        X: a batch of RGB images
    '''
    grayscaled_rgb = gray2rgb(rgb2gray(X))  # convert to 3-channel grayscale images
    lab_batch = rgb2lab(X)  # convert to the LAB colorspace
    X_batch = lab_batch[:, :, :, 0]  # extract L from LAB
    X_batch = X_batch.reshape(X_batch.shape + (1,))  # reshape into (batch, IMAGE_SIZE, IMAGE_SIZE, 1)
    X_batch = 2 * X_batch / 100 - 1.  # normalize L from [0, 100] to [-1, 1]
    Y_batch = lab_batch[:, :, :, 1:] / 127  # extract ab from LAB and scale to roughly [-1, 1]
    features = _extract_features(grayscaled_rgb, batchsize)  # extract features from the grayscale images

    return ([X_batch, features], Y_batch)
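
# Worked example of the scaling above (illustrative values):
#   L = 50.0 (mid gray) -> 2 * 50 / 100 - 1 = 0.0
#   a = 63.5            -> 63.5 / 127      = 0.5
# postprocess_output() below applies the exact inverse: y * 127 and (X_lab + 1) * 50.
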
def generate_train_records(batch_size=100):
    _generate_records(TRAIN_IMAGE_PATH, TRAIN_RECORDS_PATH, batch_size)


def generate_validation_records(batch_size=100):
    _generate_records(VALIDATION_IMAGE_PATH, VAL_RECORDS_PATH, batch_size)


def train_generator(batch_size):
    '''
    Generator which wraps a tf.data.Dataset object to read in the
    TFRecord more conveniently.
    '''
    if not os.path.exists(TRAIN_RECORDS_PATH):
        print("\n\n", '*' * 50, "\n")
        print("Please create the TFRecord of this dataset by running the `data_utils.py` script")
        exit(0)

    with tf.Session() as train_gen_session:
        dataset, next_batch = _construct_dataset(TRAIN_RECORDS_PATH, batch_size, train_gen_session)

        while True:
            try:
                l, ab, features = train_gen_session.run(next_batch)  # retrieve a batch of records
                yield ([l, features], ab)
            except Exception:
                # if the iterator fails for some reason, rebuild and reinitialize it
                iterator = dataset.make_initializable_iterator()
                train_gen_session.run(iterator.initializer)
                next_batch = iterator.get_next()

                l, ab, features = train_gen_session.run(next_batch)
                yield ([l, features], ab)


def val_batch_generator(batch_size):
    '''
    Generator which wraps a tf.data.Dataset object to read in the
    TFRecord more conveniently.
    '''
    if not os.path.exists(VAL_RECORDS_PATH):
        print("\n\n", '*' * 50, "\n")
        print("Please create the TFRecord of this dataset by running the `data_utils.py` script with validation data")
        exit(0)

    with tf.Session() as val_generator_session:
        dataset, next_batch = _construct_dataset(VAL_RECORDS_PATH, batch_size, val_generator_session)

        while True:
            try:
                l, ab, features = val_generator_session.run(next_batch)  # retrieve a batch of records
                yield ([l, features], ab)
            except Exception:
                # if the iterator fails for some reason, rebuild and reinitialize it
                iterator = dataset.make_initializable_iterator()
                val_generator_session.run(iterator.initializer)
                next_batch = iterator.get_next()

                l, ab, features = val_generator_session.run(next_batch)
                yield ([l, features], ab)


def prepare_input_image_batch(X, batchsize=100):
    '''
    Helper function which does the same as _process_batch,
    but is meant to be used with images at test time, not training time.

    Args:
        X: a batch of RGB images in the [0, 255] range
    '''
    X_processed = X / 255.  # normalize the images
    X_grayscale = gray2rgb(rgb2gray(X_processed))
    X_features = _extract_features(X_grayscale, batchsize)
    X_lab = rgb2lab(X_grayscale)[:, :, :, 0]
    X_lab = X_lab.reshape(X_lab.shape + (1,))
    X_lab = 2 * X_lab / 100 - 1.

    return X_lab, X_features
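
# Typical test-time flow built from the helpers above and below
# (this mirrors what evaluate_mobilenet.py does):
#   X_lab, X_features = prepare_input_image_batch(X)    # X: RGB batch in [0, 255]
#   y = model.predict([X_lab, X_features], batch_size)  # predicted ab channels in [-1, 1]
#   postprocess_output(X_lab, y, image_size)            # saves colorized PNGs to results/
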
def postprocess_output(X_lab, y, image_size=None):
    '''
    Helper function for test time which converts and saves
    the processed images into the 'results' directory.

    Args:
        X_lab: L channel extracted from the grayscale images
        y: ab channels predicted by the colorizer network
        image_size: output image size
    '''
    y *= 127.  # scale the predictions to [-127, 127]
    X_lab = (X_lab + 1) * 50.  # scale the L channel back to [0, 100]

    image_size = IMAGE_SIZE if image_size is None else image_size  # set a default image size if needed

    for i in range(len(y)):
        cur = np.zeros((image_size, image_size, 3))
        cur[:, :, 0] = X_lab[i, :, :, 0]
        cur[:, :, 1:] = y[i]
        imsave("results/img_%d.png" % (i + 1), lab2rgb(cur))

        if i % max(1, len(y) // 20) == 0:  # guard against batches smaller than 20
            print("Finished processing %0.2f percent of the images" % (i / float(len(y)) * 100))


if __name__ == '__main__':
    # generate the train tf record file
    generate_train_records(batch_size=200)

    # generate the validation tf record file
    generate_validation_records(batch_size=100)

--------------------------------------------------------------------------------
/evaluate_mobilenet.py:
--------------------------------------------------------------------------------
import os
import numpy as np

from keras.preprocessing.image import img_to_array, load_img
from data_utils import prepare_input_image_batch, postprocess_output, resize
from model_utils import generate_mobilenet_model


IMAGE_FOLDER_PATH = r"D:\Yue\Documents\Datasets\MSCOCO\val\valset\\"
batch_size = 10
image_size = 256

model = generate_mobilenet_model(img_size=image_size)
model.load_weights('weights/mobilenet_model_v2.h5')

X = []
files = os.listdir(IMAGE_FOLDER_PATH)

files = files[:100]
for i, filename in enumerate(files):
    # resize needs floats in the [0, 1] range; preprocessing later needs the [0, 255] range
    img = img_to_array(load_img(os.path.join(IMAGE_FOLDER_PATH, filename))) / 255.
    img = resize(img, (image_size, image_size, 3)) * 255.
    X.append(img)

    if i % max(1, len(files) // 20) == 0:
        print("Loaded %0.2f percent of the images from the directory" % (i / float(len(files)) * 100))

X = np.array(X, dtype='float32')
print("Images loaded. Shape = ", X.shape)

X_lab, X_features = prepare_input_image_batch(X, batchsize=batch_size)
predictions = model.predict([X_lab, X_features], batch_size, verbose=1)

postprocess_output(X_lab, predictions, image_size=image_size)

--------------------------------------------------------------------------------
/images/page1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/page1.png

--------------------------------------------------------------------------------
/images/page2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/page2.png

--------------------------------------------------------------------------------
/images/page3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/page3.png

--------------------------------------------------------------------------------
/images/skip_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/skip_model.png

--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
from keras.layers import Conv2D, Input, Reshape, RepeatVector, concatenate, UpSampling2D, Flatten, Conv2DTranspose
from keras.models import Model

from keras import backend as K
from keras.losses import mean_squared_error
from keras.optimizers import Adam

# loss weights; set a weight to zero to disable that loss term
mse_weight = 1.0  # 1e-3
perceptual_weight = 1. / (2. * 128. * 128.)  # scaling factor
attention_weight = 1.0  # 1.0
# shows the minimum value of the true AB channels
def y_true_min(yt, yp):
    return K.min(yt)


# shows the maximum value of the true AB channels
def y_true_max(yt, yp):
    return K.max(yt)


# shows the minimum value of the predicted AB channels
def y_pred_min(yt, yp):
    return K.min(yp)


# shows the maximum value of the predicted AB channels
def y_pred_max(yt, yp):
    return K.max(yp)


def gram_matrix(x):
    assert K.ndim(x) == 4

    with K.name_scope('gram_matrix'):
        if K.image_data_format() == "channels_first":
            batch, channels, width, height = K.int_shape(x)
            features = K.batch_flatten(x)
        else:
            batch, width, height, channels = K.int_shape(x)
            features = K.batch_flatten(K.permute_dimensions(x, (0, 3, 1, 2)))

        gram = K.dot(features, K.transpose(features))  # / (channels * width * height)
        return gram


def l2_norm(x):
    return K.sqrt(K.sum(K.square(x)))


def attention_vector(x):
    if K.image_data_format() == "channels_first":
        batch, channels, width, height = K.int_shape(x)
        filters = K.batch_flatten(K.permute_dimensions(x, (1, 0, 2, 3)))  # (channels, batch*width*height)
    else:
        batch, width, height, channels = K.int_shape(x)
        filters = K.batch_flatten(K.permute_dimensions(x, (3, 0, 1, 2)))  # (channels, batch*width*height)

    filters = K.mean(K.square(filters), axis=0)  # (batch*width*height,)
    filters = filters / l2_norm(filters)  # (batch*width*height,)
    return filters


def total_loss(y_true, y_pred):
    mse_loss = mse_weight * mean_squared_error(y_true, y_pred)
    perceptual_loss = perceptual_weight * K.sum(K.square(gram_matrix(y_true) - gram_matrix(y_pred)))
    attention_loss = attention_weight * l2_norm(attention_vector(y_true) - attention_vector(y_pred))

    return mse_loss + perceptual_loss + attention_loss
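
# Summary sketch of the composite objective above (notation is illustrative):
#   total = mse_weight        * MSE(y_true, y_pred)
#         + perceptual_weight * || G(y_true) - G(y_pred) ||^2   (Gram-matrix style term)
#         + attention_weight  * || a(y_true) - a(y_pred) ||_2   (normalized activation term)
# where G is gram_matrix() and a is attention_vector().
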
def generate_mobilenet_model(lr=1e-3, img_size=128):
    '''
    Creates a Colorizer model. Note the difference from the report
    - https://github.com/baldassarreFe/deep-koalarization/blob/master/report.pdf

    I use a long skip connection network to speed up convergence and
    boost the output quality.
    '''
    # encoder model
    encoder_ip = Input(shape=(img_size, img_size, 1))
    encoder1 = Conv2D(64, (3, 3), padding='same', activation='relu', strides=(2, 2))(encoder_ip)
    encoder = Conv2D(128, (3, 3), padding='same', activation='relu')(encoder1)
    encoder2 = Conv2D(128, (3, 3), padding='same', activation='relu', strides=(2, 2))(encoder)
    encoder = Conv2D(256, (3, 3), padding='same', activation='relu')(encoder2)
    encoder = Conv2D(256, (3, 3), padding='same', activation='relu', strides=(2, 2))(encoder)
    encoder = Conv2D(512, (3, 3), padding='same', activation='relu')(encoder)
    encoder = Conv2D(512, (3, 3), padding='same', activation='relu')(encoder)
    encoder = Conv2D(256, (3, 3), padding='same', activation='relu')(encoder)

    # input fusion
    # Decide the feature map shape at runtime to allow prediction on
    # any image size, even if training is done on 128x128 images
    batch, height, width, channels = K.int_shape(encoder)

    mobilenet_features_ip = Input(shape=(1000,))
    fusion = RepeatVector(height * width)(mobilenet_features_ip)
    fusion = Reshape((height, width, 1000))(fusion)
    fusion = concatenate([encoder, fusion], axis=-1)
    fusion = Conv2D(256, (1, 1), padding='same', activation='relu')(fusion)

    # decoder model
    decoder = Conv2D(128, (3, 3), padding='same', activation='relu')(fusion)
    decoder = UpSampling2D()(decoder)
    # decoder = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation='relu')(decoder)
    decoder = concatenate([decoder, encoder2], axis=-1)
    decoder = Conv2D(64, (3, 3), padding='same', activation='relu')(decoder)
    decoder = Conv2D(64, (3, 3), padding='same', activation='relu')(decoder)
    decoder = UpSampling2D()(decoder)
    # decoder = Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', activation='relu')(decoder)
    decoder = concatenate([decoder, encoder1], axis=-1)
    decoder = Conv2D(32, (3, 3), padding='same', activation='relu')(decoder)
    decoder = Conv2DTranspose(2, (4, 4), strides=(2, 2), padding='same', activation='tanh')(decoder)
    # decoder = Conv2D(2, (3, 3), padding='same', activation='tanh')(decoder)
    # decoder = UpSampling2D((2, 2))(decoder)

    model = Model([encoder_ip, mobilenet_features_ip], decoder, name='Colorizer')
    model.compile(optimizer=Adam(lr), loss=total_loss, metrics=[y_true_max,
                                                                y_true_min,
                                                                y_pred_max,
                                                                y_pred_min])

    print("Colorization model built and compiled")
    return model


if __name__ == '__main__':
    model = generate_mobilenet_model()
    model.summary()

    from keras.utils.vis_utils import plot_model

    plot_model(model, to_file='skip_model.png', show_shapes=True)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
keras
numpy
scikit-image
tensorflow

--------------------------------------------------------------------------------
/train_mobilenet.py:
--------------------------------------------------------------------------------
import os

from data_utils import train_generator, val_batch_generator
from model_utils import generate_mobilenet_model
from train_utils import TensorBoardBatch

from keras.callbacks import ModelCheckpoint

nb_train_images = 60000  # there are 82783 images in MS-COCO; set this to how many samples you want to train on
batch_size = 125
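
# Sanity check on the settings above: with nb_train_images = 60000 and
# batch_size = 125, steps_per_epoch below works out to 60000 // 125 = 480
# gradient updates per epoch.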
model = generate_mobilenet_model(lr=1e-3)
model.summary()

# continue training if weights are available
# if os.path.exists('weights/mobilenet_model.h5'):
#     model.load_weights('weights/mobilenet_model.h5')

# use the batchwise TensorBoard callback
tensorboard = TensorBoardBatch(batch_size=batch_size)
checkpoint = ModelCheckpoint('weights/mobilenet_model_v2.h5', monitor='loss', verbose=1,
                             save_best_only=True, save_weights_only=True)
callbacks = [checkpoint, tensorboard]


model.fit_generator(generator=train_generator(batch_size),
                    steps_per_epoch=nb_train_images // batch_size,
                    epochs=100,
                    verbose=1,
                    callbacks=callbacks,
                    validation_data=val_batch_generator(batch_size),
                    validation_steps=1)

--------------------------------------------------------------------------------
/train_utils.py:
--------------------------------------------------------------------------------
from keras.callbacks import TensorBoard
import tensorflow as tf

'''
Below is a modification of the TensorBoard callback which writes to
TensorBoard batchwise, instead of only at the end of each epoch.
'''
class TensorBoardBatch(TensorBoard):
    def __init__(self, *args, **kwargs):
        super(TensorBoardBatch, self).__init__(*args, **kwargs)

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.writer.add_summary(summary, batch)

        self.writer.flush()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.writer.add_summary(summary, epoch * self.batch_size)

        self.writer.flush()

--------------------------------------------------------------------------------
/weights/mobilenet_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/weights/mobilenet_model.h5

--------------------------------------------------------------------------------
/weights/mobilenet_model_v2.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/weights/mobilenet_model_v2.h5

--------------------------------------------------------------------------------