├── .gitignore
├── README.md
├── data_utils.py
├── evaluate_mobilenet.py
├── images
│   ├── page1.png
│   ├── page2.png
│   ├── page3.png
│   └── skip_model.png
├── model_utils.py
├── requirements.txt
├── train_mobilenet.py
├── train_utils.py
└── weights
    ├── mobilenet_model.h5
    └── mobilenet_model_v2.h5
/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | logs/*
3 | logs-old/*
4 |
5 | .idea/*
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | .static_storage/
62 | .media/
63 | local_settings.py
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # pyenv
82 | .python-version
83 |
84 | # celery beat schedule file
85 | celerybeat-schedule
86 |
87 | # SageMath parsed files
88 | *.sage.py
89 |
90 | # Environments
91 | .env
92 | .venv
93 | env/
94 | venv/
95 | ENV/
96 | env.bak/
97 | venv.bak/
98 |
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Keras Mobile Colorizer

Utilizes a U-Net-inspired model, conditioned on MobileNet class features, to learn a mapping from grayscale to color images.
Based on the work at https://github.com/baldassarreFe/deep-koalarization

Uses MobileNet instead of Inception-ResNet-v2 for memory efficiency, so that training can be done on a single GPU (with at least 4 GB of memory).

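For illustration, here is a minimal sketch of the fusion step as implemented in `model_utils.py`: the 1000-dimensional MobileNet feature vector is tiled over every spatial position of the encoder output, concatenated channel-wise, and mixed by a 1x1 convolution. The shapes assume the default 128x128 training size, and the variable names are illustrative.

```python
from keras.layers import Input, RepeatVector, Reshape, concatenate, Conv2D

encoder_out = Input(shape=(16, 16, 256))  # encoder output for a 128x128 grayscale input
class_feats = Input(shape=(1000,))        # pre-softmax MobileNet activations

fusion = RepeatVector(16 * 16)(class_feats)   # -> (batch, 256, 1000)
fusion = Reshape((16, 16, 1000))(fusion)      # tile the vector over the spatial grid
fusion = concatenate([encoder_out, fusion])   # -> (batch, 16, 16, 1256)
fusion = Conv2D(256, (1, 1), padding='same', activation='relu')(fusion)  # mix channels
```
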
# Installation
Open the `data_utils.py` script and edit `TRAIN_IMAGE_PATH` and `VALIDATION_IMAGE_PATH` to point to directories of images. Each of those paths must contain at least one subfolder of images (files are globbed as `*/*.jpg`).

Then run `data_utils.py` to construct the required folders and the TFRecords which will store the training data.

This step drastically improves training speed: the MobileNet features of every training image are extracted once, before training, since feature extraction from MobileNet at runtime is the major bottleneck.

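The records can also be generated from Python; this mirrors the script's `__main__` block and its default batch sizes:

```python
from data_utils import generate_train_records, generate_validation_records

generate_train_records(batch_size=200)       # writes data/images.tfrecord
generate_validation_records(batch_size=100)  # writes data/val_images.tfrecord
```
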
# Training & Evaluation

- To train the model: use the `train_mobilenet.py` script. Verify the batch size and the number of images in the TFRecord before beginning training.

- To evaluate the model: use the `evaluate_mobilenet.py` script. Make sure that the path to the validation images is set in `data_utils.py`. A condensed inference sketch follows below.

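The inference flow, condensed from `evaluate_mobilenet.py` (`image_folder` is a placeholder; point it at any directory of RGB images):

```python
import os
import numpy as np

from keras.preprocessing.image import img_to_array, load_img
from data_utils import prepare_input_image_batch, postprocess_output, resize
from model_utils import generate_mobilenet_model

image_size = 256  # the model can be rebuilt at a different size than the 128x128 used for training

model = generate_mobilenet_model(img_size=image_size)
model.load_weights('weights/mobilenet_model_v2.h5')

image_folder = 'path/to/images'  # placeholder path
X = []
for filename in os.listdir(image_folder):
    img = img_to_array(load_img(os.path.join(image_folder, filename))) / 255.  # resize expects [0, 1]
    X.append(resize(img, (image_size, image_size, 3)) * 255.)  # MobileNet preprocessing expects [0, 255]
X = np.array(X, dtype='float32')

X_lab, X_features = prepare_input_image_batch(X, batchsize=10)  # L channel + MobileNet features
predictions = model.predict([X_lab, X_features], batch_size=10, verbose=1)  # predicted AB channels
postprocess_output(X_lab, predictions, image_size=image_size)  # saves PNGs to results/
```
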
# Results
There are a lot of splotchy reddish-brown patches. This is probably because training was done on only 60k images from the MS-COCO dataset rather than the full ImageNet dataset.

Sample results can be found in the `images` directory (`page1.png` to `page3.png`).

# Requirements
- Keras 2.0.8+
- Numpy
- Scikit-image
- Tensorflow (GPU is a must for training, CPU is fine for inference)

Install via `pip install -r requirements.txt`

--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 |
5 | from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
6 | from skimage.transform import resize
7 | from skimage.io import imsave, imread
8 |
9 | import tensorflow as tf
10 | from tensorflow import data as tfdata
11 | sess = tf.Session()  # create a session up front so Keras and the tf.data pipeline share one graph
12 |
13 | from keras import backend as K
14 | K.set_session(sess)  # register this session as the Keras backend session
15 |
16 | sess = None  # drop the local reference; the Keras backend keeps the session alive
17 |
18 | from keras.applications.mobilenet import MobileNet, preprocess_input
19 | from keras.models import Model
20 |
21 | # change these to your local paths
22 | TRAIN_IMAGE_PATH = r""
23 | VALIDATION_IMAGE_PATH = r"D:\Yue\Documents\Datasets\MSCOCO\val\\"
24 |
25 | IMAGE_SIZE = 128 # Global constant image size
26 | EMBEDDING_IMAGE_SIZE = 224 # Global constant embedding size
27 |
28 | TRAIN_RECORDS_PATH = "data/images.tfrecord" # local path to tf record directory
29 | VAL_RECORDS_PATH = "data/val_images.tfrecord" # local path to tf record directory
30 |
31 |
32 | if not os.path.exists('weights/'):
33 | os.makedirs('weights/')
34 |
35 | if not os.path.exists('results/'):
36 | os.makedirs('results/')
37 |
38 | if not os.path.exists('data/'):
39 | os.makedirs('data/')
40 |
41 | feature_extraction_model = None
42 | mobilenet_activations = None
43 |
44 | def _load_mobilenet():
45 | global feature_extraction_model, mobilenet_activations
46 |
47 | # Feature extraction module
48 | feature_extraction_model = MobileNet(input_shape=(EMBEDDING_IMAGE_SIZE, EMBEDDING_IMAGE_SIZE, 3),
49 | alpha=1.0,
50 | depth_multiplier=1,
51 | include_top=True,
52 | weights='imagenet')
53 |
54 | # Set it up so that we can do inference on MobileNet without training it by mistake
55 | feature_extraction_model.graph = tf.get_default_graph()
56 | feature_extraction_model.trainable = False
57 |
58 | # Get the pre-softmax activations from MobileNet
59 | mobilenet_activations = Model(feature_extraction_model.input, feature_extraction_model.layers[-3].output)
60 | mobilenet_activations.trainable = False
61 |
62 |
63 | def _get_pre_activations(grayscale_image, batchsize=100):
64 | # batchwise retrieve feature map from last layer - pre softmax
65 | activations = mobilenet_activations.predict(grayscale_image, batch_size=batchsize)
66 | return activations
67 |
68 |
69 | def _extract_features(grayscaled_rgb, batchsize=100):
70 | # Load up MobileNet only when necessary, not during training
71 | if feature_extraction_model is None:
72 | _load_mobilenet()
73 |
74 | grayscaled_rgb_resized = []
75 |
76 | for i in grayscaled_rgb:
77 | # Resize to size of MobileNet Input
78 | i = resize(i, (EMBEDDING_IMAGE_SIZE, EMBEDDING_IMAGE_SIZE, 3), mode='constant')
79 | grayscaled_rgb_resized.append(i)
80 |
81 | grayscaled_rgb_resized = np.array(grayscaled_rgb_resized) * 255. # scale to 0-255 range for MobileNet preprocess_input
82 | grayscaled_rgb_resized = preprocess_input(grayscaled_rgb_resized)
83 |
84 | with feature_extraction_model.graph.as_default(): # using the shared graph of Colorization model and MobileNet
85 | features = _get_pre_activations(grayscaled_rgb_resized, batchsize) # batchwise get the feature maps
86 | features = features.reshape((-1, 1000))
87 |
88 | return features
89 |
90 |
91 | def _float32_feature_list(floats):
92 | return tf.train.Feature(float_list=tf.train.FloatList(value=floats))
93 |
94 |
95 | def _generate_records(images_path, tf_record_name, batch_size=100):
96 | '''
97 | Creates a TF Record containing the pre-processed image consisting of
98 | 1) L channel input
99 | 2) ab channels output
100 | 3) features extracted from MobileNet
101 |
102 | This step is crucial for speed during training, as the major bottleneck
103 | is the extraction of feature maps from MobileNet. It is slow, and inefficient.
104 | '''
105 |     if os.path.exists(tf_record_name):
106 | print("**** Delete old TF Records first! ****")
107 | exit(0)
108 |
109 | files = glob.glob(images_path + "*/*.jpg")
110 | files = sorted(files)
111 | nb_files = len(files)
112 |
113 | # Use ZLIB compression to save space and create a TFRecordWriter
114 | options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
115 | writer = tf.python_io.TFRecordWriter(tf_record_name, options)
116 |
117 | size = max(EMBEDDING_IMAGE_SIZE, IMAGE_SIZE) # keep larger size until stored in TF Record
118 |
119 | X_buffer = []
120 | for i, fn in enumerate(files):
121 |         try:  # skip images that fail to load (e.g. corrupted files)
122 |             X = imread(fn)
123 |             X = resize(X, (size, size, 3), mode='constant')  # resize to the larger size for now
124 |         except Exception:
125 |             continue
126 |
127 | X_buffer.append(X)
128 |
129 | if len(X_buffer) >= batch_size:
130 | X_buffer = np.array(X_buffer)
131 | _serialize_batch(X_buffer, writer, batch_size) # serialize the image into the TF Record
132 |
133 | del X_buffer # delete buffered images from memory
134 | X_buffer = [] # reset to new list
135 |
136 | print("Processed %d / %d images" % (i + 1, nb_files))
137 |
138 | if len(X_buffer) != 0:
139 | X_buffer = np.array(X_buffer)
140 | _serialize_batch(X_buffer, writer) # serialize the remaining images in buffer
141 |
142 | del X_buffer # delete buffer
143 |
144 | print("Processed %d / %d images" % (nb_files, nb_files))
145 | print("Finished creating TF Record")
146 |
147 | writer.close()
148 |
149 |
150 | def _serialize_batch(X, writer, batch_size=100):
151 | '''
152 | Processes a batch of images, and then serializes into the TFRecord
153 |
154 | Args:
155 | X: original image with no preprocessing
156 | writer: TFRecordWriter
157 | batch_size: batch size
158 | '''
159 | [X_batch, features], Y_batch = _process_batch(X, batch_size) # preprocess batch
160 |
161 | for j, (img_l, embed, y) in enumerate(zip(X_batch, features, Y_batch)):
162 | # resize the images to their smaller size to reduce space wastage in the record
163 | img_l = resize(img_l, (IMAGE_SIZE, IMAGE_SIZE, 1), mode='constant')
164 | y = resize(y, (IMAGE_SIZE, IMAGE_SIZE, 2), mode='constant')
165 |
166 | example_dict = {
167 | 'image_l': _float32_feature_list(img_l.flatten()),
168 | 'image_ab': _float32_feature_list(y.flatten()),
169 | 'image_features': _float32_feature_list(embed.flatten())
170 | }
171 | example_feature = tf.train.Features(feature=example_dict)
172 | example = tf.train.Example(features=example_feature)
173 | writer.write(example.SerializeToString())
174 |
175 |
176 | def _construct_dataset(record_path, batch_size, sess):
177 | def parse_record(serialized_example):
178 | # parse a single record
179 | features = tf.parse_single_example(
180 | serialized_example,
181 | features={
182 | 'image_l': tf.FixedLenFeature([IMAGE_SIZE, IMAGE_SIZE, 1], tf.float32),
183 | 'image_ab': tf.FixedLenFeature([IMAGE_SIZE, IMAGE_SIZE, 2], tf.float32),
184 | 'image_features': tf.FixedLenFeature([1000, ], tf.float32)
185 | })
186 |
187 | l, ab, embed = features['image_l'], features['image_ab'], features['image_features']
188 | return l, ab, embed
189 |
190 | dataset = tfdata.TFRecordDataset([record_path], 'ZLIB') # create a Dataset to wrap the TFRecord
191 | dataset = dataset.map(parse_record, num_parallel_calls=2) # parse the record
192 | dataset = dataset.repeat() # repeat forever
193 | dataset = dataset.batch(batch_size) # batch into the required batchsize
194 | dataset = dataset.shuffle(buffer_size=5) # shuffle the batches
195 | iterator = dataset.make_initializable_iterator() # get an iterator over the dataset
196 |
197 | sess.run(iterator.initializer) # initialize the iterator
198 | next_batch = iterator.get_next() # get the iterator Tensor
199 |
200 | return dataset, next_batch
201 |
202 |
203 | def _process_batch(X, batchsize=100):
204 | '''
205 | Process a batch of images for training
206 |
207 | Args:
208 |         X: a batch of RGB images
209 | '''
210 | grayscaled_rgb = gray2rgb(rgb2gray(X)) # convert to 3 channeled grayscale image
211 | lab_batch = rgb2lab(X) # convert to LAB colorspace
212 | X_batch = lab_batch[:, :, :, 0] # extract L from LAB
213 | X_batch = X_batch.reshape(X_batch.shape + (1,)) # reshape into (batch, IMAGE_SIZE, IMAGE_SIZE, 1)
214 | X_batch = 2 * X_batch / 100 - 1. # normalize the batch
215 | Y_batch = lab_batch[:, :, :, 1:] / 127 # extract AB from LAB
216 | features = _extract_features(grayscaled_rgb, batchsize) # extract features from the grayscale image
217 |
218 | return ([X_batch, features], Y_batch)
219 |
220 |
221 | def generate_train_records(batch_size=100):
222 | _generate_records(TRAIN_IMAGE_PATH, TRAIN_RECORDS_PATH, batch_size)
223 |
224 |
225 | def generate_validation_records(batch_size=100):
226 | _generate_records(VALIDATION_IMAGE_PATH, VAL_RECORDS_PATH, batch_size)
227 |
228 |
229 | def train_generator(batch_size):
230 | '''
231 | Generator which wraps a tf.data.Dataset object to read in the
232 | TFRecord more conveniently.
233 | '''
234 | if not os.path.exists(TRAIN_RECORDS_PATH):
235 | print("\n\n", '*' * 50, "\n")
236 |         print("Please create the TFRecord of this dataset by running the `data_utils.py` script")
237 | exit(0)
238 |
239 | with tf.Session() as train_gen_session:
240 | dataset, next_batch = _construct_dataset(TRAIN_RECORDS_PATH, batch_size, train_gen_session)
241 |
242 | while True:
243 | try:
244 | l, ab, features = train_gen_session.run(next_batch) # retrieve a batch of records
245 | yield ([l, features], ab)
246 |             except Exception:
247 |                 # re-create and re-initialize the iterator if the run fails
248 | iterator = dataset.make_initializable_iterator()
249 | train_gen_session.run(iterator.initializer)
250 | next_batch = iterator.get_next()
251 |
252 | l, ab, features = train_gen_session.run(next_batch)
253 | yield ([l, features], ab)
254 |
255 |
256 | def val_batch_generator(batch_size):
257 | '''
258 | Generator which wraps a tf.data.Dataset object to read in the
259 | TFRecord more conveniently.
260 | '''
261 | if not os.path.exists(VAL_RECORDS_PATH):
262 | print("\n\n", '*' * 50, "\n")
263 |         print("Please create the TFRecord of this dataset by running the `data_utils.py` script with validation data")
264 | exit(0)
265 |
266 | with tf.Session() as val_generator_session:
267 | dataset, next_batch = _construct_dataset(VAL_RECORDS_PATH, batch_size, val_generator_session)
268 |
269 | while True:
270 | try:
271 | l, ab, features = val_generator_session.run(next_batch) # retrieve a batch of records
272 | yield ([l, features], ab)
273 |             except Exception:
274 |                 # re-create and re-initialize the iterator if the run fails
275 | iterator = dataset.make_initializable_iterator()
276 | val_generator_session.run(iterator.initializer)
277 | next_batch = iterator.get_next()
278 |
279 | l, ab, features = val_generator_session.run(next_batch)
280 | yield ([l, features], ab)
281 |
282 |
283 | def prepare_input_image_batch(X, batchsize=100):
284 | '''
285 |     This is a helper function which does the same as _process_batch,
286 | but it is meant to be used with images during testing, not training.
287 |
288 | Args:
289 |         X: a batch of RGB images in the [0, 255] range
290 | '''
291 | X_processed = X / 255. # normalize grayscale image
292 | X_grayscale = gray2rgb(rgb2gray(X_processed))
293 | X_features = _extract_features(X_grayscale, batchsize)
294 | X_lab = rgb2lab(X_grayscale)[:, :, :, 0]
295 | X_lab = X_lab.reshape(X_lab.shape + (1,))
296 | X_lab = 2 * X_lab / 100 - 1.
297 |
298 | return X_lab, X_features
299 |
300 |
301 | def postprocess_output(X_lab, y, image_size=None):
302 | '''
303 |     This is a helper function for test time to convert and save
304 |     the processed images into the 'results' directory.
305 |
306 | Args:
307 | X_lab: L channel extracted from the grayscale image
308 | y: AB channels predicted by the colorizer network
309 | image_size: output image size
310 | '''
311 | y *= 127. # scale the predictions to [-127, 127]
312 | X_lab = (X_lab + 1) * 50. # scale the L channel to [0, 100]
313 |
314 | image_size = IMAGE_SIZE if image_size is None else image_size # set a default image size if needed
315 |
316 | for i in range(len(y)):
317 | cur = np.zeros((image_size, image_size, 3))
318 | cur[:, :, 0] = X_lab[i, :, :, 0]
319 | cur[:, :, 1:] = y[i]
320 | imsave("results/img_%d.png" % (i + 1), lab2rgb(cur))
321 |
322 |         if i % max(1, len(y) // 20) == 0:  # guard against len(y) < 20
323 | print("Finished processing %0.2f percentage of images" % (i / float(len(y)) * 100))
324 |
325 |
326 | if __name__ == '__main__':
327 | # generate the train tf record file
328 | generate_train_records(batch_size=200)
329 |
330 | # generate the validation tf record file
331 | generate_validation_records(batch_size=100)
--------------------------------------------------------------------------------
/evaluate_mobilenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from keras.preprocessing.image import img_to_array, load_img
5 | from data_utils import prepare_input_image_batch, postprocess_output, resize
6 | from model_utils import generate_mobilenet_model
7 |
8 |
9 | IMAGE_FOLDER_PATH = r"D:\Yue\Documents\Datasets\MSCOCO\val\valset\\"
10 | batch_size = 10
11 | image_size = 256
12 |
13 | model = generate_mobilenet_model(img_size=image_size)
14 | model.load_weights('weights/mobilenet_model_v2.h5')
15 |
16 | X = []
17 | files = os.listdir(IMAGE_FOLDER_PATH)
18 |
19 | files = files[:100]
20 | for i, filename in enumerate(files):
21 | img = img_to_array(load_img(os.path.join(IMAGE_FOLDER_PATH, filename))) / 255.
22 | img = resize(img, (image_size, image_size, 3)) * 255. # resize needs floats to be in 0-1 range, preprocess needs in 0-255 range
23 | X.append(img)
24 |
25 |     if i % max(1, len(files) // 20) == 0:  # guard against fewer than 20 files
26 | print("Loaded %0.2f percentage of images from directory" % (i / float(len(files)) * 100))
27 |
28 | X = np.array(X, dtype='float32')
29 | print("Images loaded. Shape = ", X.shape)
30 |
31 | X_lab, X_features = prepare_input_image_batch(X, batchsize=batch_size)
32 | predictions = model.predict([X_lab, X_features], batch_size, verbose=1)
33 |
34 | postprocess_output(X_lab, predictions, image_size=image_size)
35 |
36 |
--------------------------------------------------------------------------------
/images/page1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/page1.png
--------------------------------------------------------------------------------
/images/page2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/page2.png
--------------------------------------------------------------------------------
/images/page3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/page3.png
--------------------------------------------------------------------------------
/images/skip_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/images/skip_model.png
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, Input, Reshape, RepeatVector, concatenate, UpSampling2D, Flatten, Conv2DTranspose
2 | from keras.models import Model
3 |
4 | from keras import backend as K
5 | from keras.losses import mean_squared_error
6 | from keras.optimizers import Adam
7 |
8 | mse_weight = 1.0 #1e-3
9 |
10 | # set any of these weights to zero to disable the corresponding loss term
11 | perceptual_weight = 1. / (2. * 128. * 128.) # scaling factor
12 | attention_weight = 1.0 # 1.0
13 |
14 |
15 | # shows the minimum value of the true AB channels
16 | def y_true_min(yt, yp):
17 | return K.min(yt)
18 |
19 |
20 | # shows the maximum value of the true AB channels
21 | def y_true_max(yt, yp):
22 | return K.max(yt)
23 |
24 |
25 | # shows the minimum value of the predicted AB channels
26 | def y_pred_min(yt, yp):
27 | return K.min(yp)
28 |
29 |
30 | # shows the maximum value of the predicted AB channels
31 | def y_pred_max(yt, yp):
32 | return K.max(yp)
33 |
34 |
35 | def gram_matrix(x):
36 | assert K.ndim(x) == 4
37 |
38 | with K.name_scope('gram_matrix'):
39 | if K.image_data_format() == "channels_first":
40 | batch, channels, width, height = K.int_shape(x)
41 | features = K.batch_flatten(x)
42 | else:
43 | batch, width, height, channels = K.int_shape(x)
44 | features = K.batch_flatten(K.permute_dimensions(x, (0, 3, 1, 2)))
45 |
46 | gram = K.dot(features, K.transpose(features)) # / (channels * width * height)
47 | return gram
48 |
49 |
50 | def l2_norm(x):
51 | return K.sqrt(K.sum(K.square(x)))
52 |
53 |
54 | def attention_vector(x):
55 | if K.image_data_format() == "channels_first":
56 | batch, channels, width, height = K.int_shape(x)
57 | filters = K.batch_flatten(K.permute_dimensions(x, (1, 0, 2, 3))) # (channels, batch*width*height)
58 | else:
59 | batch, width, height, channels = K.int_shape(x)
60 | filters = K.batch_flatten(K.permute_dimensions(x, (3, 0, 1, 2))) # (channels, batch*width*height)
61 |
62 | filters = K.mean(K.square(filters), axis=0) # (batch*width*height,)
63 | filters = filters / l2_norm(filters) # (batch*width*height,)
64 | return filters
65 |
66 |
67 | def total_loss(y_true, y_pred):
68 | mse_loss = mse_weight * mean_squared_error(y_true, y_pred)
69 | perceptual_loss = perceptual_weight * K.sum(K.square(gram_matrix(y_true) - gram_matrix(y_pred)))
70 | attention_loss = attention_weight * l2_norm(attention_vector(y_true) - attention_vector(y_pred))
71 |
72 | return mse_loss + perceptual_loss + attention_loss
73 |
74 |
75 | def generate_mobilenet_model(lr=1e-3, img_size=128):
76 | '''
77 | Creates a Colorizer model. Note the difference from the report
78 | - https://github.com/baldassarreFe/deep-koalarization/blob/master/report.pdf
79 |
80 | I use a long skip connection network to speed up convergence and
81 | boost the output quality.
82 | '''
83 | # encoder model
84 | encoder_ip = Input(shape=(img_size, img_size, 1))
85 | encoder1 = Conv2D(64, (3, 3), padding='same', activation='relu', strides=(2, 2))(encoder_ip)
86 | encoder = Conv2D(128, (3, 3), padding='same', activation='relu')(encoder1)
87 | encoder2 = Conv2D(128, (3, 3), padding='same', activation='relu', strides=(2, 2))(encoder)
88 | encoder = Conv2D(256, (3, 3), padding='same', activation='relu')(encoder2)
89 | encoder = Conv2D(256, (3, 3), padding='same', activation='relu', strides=(2, 2))(encoder)
90 | encoder = Conv2D(512, (3, 3), padding='same', activation='relu')(encoder)
91 | encoder = Conv2D(512, (3, 3), padding='same', activation='relu')(encoder)
92 | encoder = Conv2D(256, (3, 3), padding='same', activation='relu')(encoder)
93 |
94 | # input fusion
95 |     # Read the encoder's spatial shape at model-build time, so the model
96 |     # can be rebuilt for any image size even though training uses 128x128
97 |     batch, height, width, channels = K.int_shape(encoder)
98 |
99 | mobilenet_features_ip = Input(shape=(1000,))
100 | fusion = RepeatVector(height * width)(mobilenet_features_ip)
101 | fusion = Reshape((height, width, 1000))(fusion)
102 | fusion = concatenate([encoder, fusion], axis=-1)
103 | fusion = Conv2D(256, (1, 1), padding='same', activation='relu')(fusion)
104 |
105 | # decoder model
106 | decoder = Conv2D(128, (3, 3), padding='same', activation='relu')(fusion)
107 | decoder = UpSampling2D()(decoder)
108 | #decoder = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation='relu')(decoder)
109 | decoder = concatenate([decoder, encoder2], axis=-1)
110 | decoder = Conv2D(64, (3, 3), padding='same', activation='relu')(decoder)
111 | decoder = Conv2D(64, (3, 3), padding='same', activation='relu')(decoder)
112 | decoder = UpSampling2D()(decoder)
113 | #decoder = Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', activation='relu')(decoder)
114 | decoder = concatenate([decoder, encoder1], axis=-1)
115 | decoder = Conv2D(32, (3, 3), padding='same', activation='relu')(decoder)
116 | decoder = Conv2DTranspose(2, (4, 4), strides=(2, 2), padding='same', activation='tanh')(decoder)
117 | # decoder = Conv2D(2, (3, 3), padding='same', activation='tanh')(decoder)
118 | # decoder = UpSampling2D((2, 2))(decoder)
119 |
120 | model = Model([encoder_ip, mobilenet_features_ip], decoder, name='Colorizer')
121 | model.compile(optimizer=Adam(lr), loss=total_loss, metrics=[y_true_max,
122 | y_true_min,
123 | y_pred_max,
124 | y_pred_min])
125 |
126 | print("Colorization model built and compiled")
127 | return model
128 |
129 |
130 | if __name__ == '__main__':
131 | model = generate_mobilenet_model()
132 | model.summary()
133 |
134 | from keras.utils.vis_utils import plot_model
135 |
136 | plot_model(model, to_file='skip_model.png', show_shapes=True)
137 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | keras
2 | numpy
3 | scikit-image
4 | tensorflow
5 |
--------------------------------------------------------------------------------
/train_mobilenet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data_utils import train_generator, val_batch_generator
4 | from model_utils import generate_mobilenet_model
5 | from train_utils import TensorBoardBatch
6 |
7 | from keras.callbacks import ModelCheckpoint
8 |
9 | nb_train_images = 60000 # there are 82783 images in MS-COCO, set this to how many samples you want to train on.
10 | batch_size = 125
11 |
12 | model = generate_mobilenet_model(lr=1e-3)
13 | model.summary()
14 |
15 | # continue training if weights are available
16 | #if os.path.exists('weights/mobilenet_model.h5'):
17 | # model.load_weights('weights/mobilenet_model.h5')
18 |
19 | # use Batchwise TensorBoard callback
20 | tensorboard = TensorBoardBatch(batch_size=batch_size)
21 | checkpoint = ModelCheckpoint('weights/mobilenet_model_v2.h5', monitor='loss', verbose=1,
22 | save_best_only=True, save_weights_only=True)
23 | callbacks = [checkpoint, tensorboard]
24 |
25 |
26 | model.fit_generator(generator=train_generator(batch_size),
27 | steps_per_epoch=nb_train_images // batch_size,
28 | epochs=100,
29 | verbose=1,
30 | callbacks=callbacks,
31 | validation_data=val_batch_generator(batch_size),
32 | validation_steps=1
33 | )
34 |
--------------------------------------------------------------------------------
/train_utils.py:
--------------------------------------------------------------------------------
1 | from keras.callbacks import TensorBoard
2 | import tensorflow as tf
3 |
4 | '''
5 | Below is a modification to the TensorBoard callback to perform
6 | batchwise writing to the tensorboard, instead of only at the end
7 | of the batch.
8 | '''
9 | class TensorBoardBatch(TensorBoard):
10 | def __init__(self, *args, **kwargs):
11 |         super(TensorBoardBatch, self).__init__(*args, **kwargs)  # forward kwargs (e.g. batch_size) to TensorBoard
12 |
13 | def on_batch_end(self, batch, logs=None):
14 | logs = logs or {}
15 |
16 | for name, value in logs.items():
17 | if name in ['batch', 'size']:
18 | continue
19 | summary = tf.Summary()
20 | summary_value = summary.value.add()
21 | summary_value.simple_value = value.item()
22 | summary_value.tag = name
23 | self.writer.add_summary(summary, batch)
24 |
25 | self.writer.flush()
26 |
27 | def on_epoch_end(self, epoch, logs=None):
28 | logs = logs or {}
29 |
30 | for name, value in logs.items():
31 | if name in ['batch', 'size']:
32 | continue
33 | summary = tf.Summary()
34 | summary_value = summary.value.add()
35 | summary_value.simple_value = value.item()
36 | summary_value.tag = name
37 | self.writer.add_summary(summary, epoch * self.batch_size)
38 |
39 | self.writer.flush()
40 |
--------------------------------------------------------------------------------
/weights/mobilenet_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/weights/mobilenet_model.h5
--------------------------------------------------------------------------------
/weights/mobilenet_model_v2.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-mobile-colorizer/0b013ba5e56fdfa4d16b9d0a7a847016084482d9/weights/mobilenet_model_v2.h5
--------------------------------------------------------------------------------