├── README.md
├── benchmark_model.py
├── data_utils.py
├── models
├── 28x28
│ ├── cpc.h5
│ ├── encoder.h5
│ └── supervised.h5
└── 64x64
│ ├── cpc.h5
│ ├── encoder.h5
│ └── supervised.h5
├── resources
├── batch_sample_sorted.png
├── figure.png
├── figure.svg
├── lena.jpg
├── t10k-images-idx3-ubyte.gz
├── t10k-labels-idx1-ubyte.gz
├── train-images-idx3-ubyte.gz
└── train-labels-idx1-ubyte.gz
└── train_model.py
/README.md:
--------------------------------------------------------------------------------
1 | ### Representation Learning with Contrastive Predictive Coding
2 |
3 | This repository contains a Keras implementation of the algorithm presented in the paper [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748).
4 |
5 | The goal of unsupervised representation learning is to capture semantic information about the world, recognizing patterns in the data without using annotations. This paper presents a new method called Contrastive Predictive Coding (CPC) that can do so across multiple applications. The main ideas of the paper are:
6 | * Contrastive: it is trained using a contrastive approach, that is, the main model has to discern between *right* and *wrong* data sequences.
7 | * Predictive: the model has to predict future patterns given the current context.
8 | * Coding: the model performs this prediction in a latent space, transforming code vectors into other code vectors (in contrast with predicting high-dimensional data directly).
9 |
10 | CPC has to predict the next item in a sequence using only an embedded representation of the data, provided by an encoder. In order to solve the task, this encoder has to learn a meaningful representation of the data space. After training, this encoder can be used for other downstream tasks like supervised classification.
11 |
12 |
13 |
14 |
15 |
16 | To train the CPC algorithm, I have created a toy dataset. This dataset consists of sequences of modified MNIST digits (64x64 RGB). Positive sequence samples contain *sorted* digits, while negative samples keep a sorted context but continue with *random* digits. Sequences wrap around modulo 10 (e.g. ```[8, 9, 0, 1]```). For example, let's assume that the context sequence length is S=4, and CPC is asked to predict the next P=2 digits. A positive sample could look like ```[2, 3, 4, 5]->[6, 7]```, whereas a negative one could be ```[1, 2, 3, 4]->[0, 8]```. Of course, CPC only sees the digit images, never their labels.
17 |
18 | Disclaimer: this code is provided *as is*. If you encounter a bug, please report it as an issue. Your help is very welcome!
19 |
20 | ### Results
21 |
22 | After 10 training epochs, CPC reaches 99% accuracy on the contrastive task. After training, I froze the encoder and trained an MLP on top of it to perform supervised digit classification on the same MNIST data. It achieved 90% accuracy after 10 epochs, demonstrating the effectiveness of CPC for unsupervised feature extraction.
23 |
24 | ### Usage
25 |
26 | - Execute ```python train_model.py``` to train the CPC model.
27 | - Execute ```python benchmark_model.py``` to train the MLP on top of the CPC encoder.
28 |
29 | ### Requirements
30 |
31 | - [Anaconda Python 3.5.3](https://www.continuum.io/downloads)
32 | - [Keras 2.0.6](https://keras.io/)
33 | - [Tensorflow 1.4.0](https://www.tensorflow.org/)
34 | - GPU for fast training.
35 |
36 | ### References
37 |
38 | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748)
39 |
--------------------------------------------------------------------------------
/benchmark_model.py:
--------------------------------------------------------------------------------
1 | ''' This module evaluates the performance of a trained CPC encoder '''
2 |
3 | from data_utils import MnistGenerator
4 | from os.path import join, basename, dirname, exists
5 | import keras
6 |
7 |
def build_model(encoder_path, image_shape, learning_rate):

    ''' Builds a digit classifier on top of a frozen, pre-trained CPC encoder.

    Args:
        encoder_path: path of an encoder saved with ``model.save``.
        image_shape: shape of one input image, e.g. (64, 64, 3).
        learning_rate: learning rate for the Adam optimizer.

    Returns:
        A compiled Keras model mapping images to 10 softmax class scores.
    '''

    # Load the encoder and freeze every weight so only the new head is trained
    frozen_encoder = keras.models.load_model(encoder_path)
    frozen_encoder.trainable = False
    for encoder_layer in frozen_encoder.layers:
        encoder_layer.trainable = False

    # Classification head: Dense -> BatchNorm -> LeakyReLU -> softmax over 10 digits
    images = keras.layers.Input(image_shape)
    features = frozen_encoder(images)
    hidden = keras.layers.Dense(units=128, activation='linear')(features)
    hidden = keras.layers.BatchNormalization()(hidden)
    hidden = keras.layers.LeakyReLU()(hidden)
    scores = keras.layers.Dense(units=10, activation='softmax')(hidden)

    classifier = keras.models.Model(inputs=images, outputs=scores)
    classifier.compile(
        optimizer=keras.optimizers.Adam(lr=learning_rate),
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'],
    )
    classifier.summary()

    return classifier
38 |
39 |
def benchmark_model(encoder_path, epochs, batch_size, output_dir, lr=1e-4, image_size=28, color=False, min_lr=1e-4):

    ''' Trains an MLP on top of a frozen CPC encoder to classify MNIST digits.

    Args:
        encoder_path: path of the saved encoder (e.g. 'models/64x64/encoder.h5').
        epochs: number of training epochs.
        batch_size: number of images per batch.
        output_dir: directory where 'supervised.h5' is written; created if missing.
        lr: initial Adam learning rate.
        image_size: side length of the square RGB input images.
        color: whether digits are drawn over colored backgrounds.
        min_lr: lower bound used by ReduceLROnPlateau. The default keeps the
            previously hard-coded value; note that when lr <= min_lr the
            learning rate is never reduced.
    '''
    import os

    # Prepare data
    train_data = MnistGenerator(batch_size, subset='train', image_size=image_size, color=color, rescale=True)
    validation_data = MnistGenerator(batch_size, subset='valid', image_size=image_size, color=color, rescale=True)

    # Prepare the model
    model = build_model(encoder_path, image_shape=(image_size, image_size, 3), learning_rate=lr)

    # Create the output directory up front so a long run does not fail at save time
    if output_dir and not exists(output_dir):
        os.makedirs(output_dir)

    # 1.0 / 3 rather than 1/3: the latter is 0 under Python 2 integer division
    callbacks = [keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=1.0 / 3, patience=2, min_lr=min_lr)]

    # Train the model
    model.fit_generator(
        generator=train_data,
        steps_per_epoch=len(train_data),
        validation_data=validation_data,
        validation_steps=len(validation_data),
        epochs=epochs,
        verbose=1,
        callbacks=callbacks
    )

    # Save the model
    model.save(join(output_dir, 'supervised.h5'))
66 |
67 |
if __name__ == "__main__":

    # Benchmark the 64x64 color encoder produced by train_model.py
    settings = dict(
        encoder_path='models/64x64/encoder.h5',
        epochs=15,
        batch_size=64,
        output_dir='models/64x64',
        lr=1e-3,
        image_size=64,
        color=True,
    )
    benchmark_model(**settings)
79 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | ''' This module contains code to handle data '''
2 |
3 | import os
4 | import numpy as np
5 | import scipy.ndimage
6 | from PIL import Image
7 | import scipy
8 | import sys
9 | from matplotlib import pyplot as plt
10 |
11 |
class MnistHandler(object):

    ''' Provides a convenient interface to manipulate MNIST data.

    On construction the MNIST archives under 'resources/' are loaded (and
    downloaded first if missing), split into train / valid / test subsets, and
    the Lena image used as a background for colored digits is opened.
    '''

    def __init__(self):

        # Download data if needed
        self.X_train, self.y_train, self.X_val, self.y_val, self.X_test, self.y_test = self.load_dataset()

        # Open the Lena image (background for colored digits)
        self.lena = Image.open('resources/lena.jpg')

    def load_dataset(self):

        ''' Loads MNIST from 'resources/', downloading any missing archive.

        Returns:
            X_train, y_train, X_val, y_val, X_test, y_test. Images have shape
            (n, 1, 28, 28) with values in [0, 255/256]; labels are uint8 vectors.
            The last 10000 training examples are held out as the validation subset.
        '''
        # Credit for this function: https://github.com/Lasagne/Lasagne/blob/master/examples/mnist.py
        import gzip

        # We first define a download function, supporting both Python 2 and 3.
        if sys.version_info[0] == 2:
            from urllib import urlretrieve
        else:
            from urllib.request import urlretrieve

        def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
            # `filename` is a local path such as 'resources/train-images-idx3-ubyte.gz',
            # but the archives live at the root of `source`: only the base name
            # belongs in the URL. The local directory is created if needed.
            print("Downloading %s" % filename)
            directory = os.path.dirname(filename)
            if directory and not os.path.exists(directory):
                os.makedirs(directory)
            urlretrieve(source + os.path.basename(filename), filename)

        def load_mnist_images(filename):
            if not os.path.exists(filename):
                download(filename)
            # Read the inputs in Yann LeCun's binary format (16-byte header).
            with gzip.open(filename, 'rb') as f:
                data = np.frombuffer(f.read(), np.uint8, offset=16)
            # Reshape to monochrome 2D images: (examples, channels, rows, columns)
            data = data.reshape(-1, 1, 28, 28)
            # Convert bytes to floats in [0, 255/256], for compatibility with the
            # version provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.
            return data / np.float32(256)

        def load_mnist_labels(filename):
            if not os.path.exists(filename):
                download(filename)
            # Labels follow an 8-byte header, one uint8 per example.
            with gzip.open(filename, 'rb') as f:
                return np.frombuffer(f.read(), np.uint8, offset=8)

        X_train = load_mnist_images('resources/train-images-idx3-ubyte.gz')
        y_train = load_mnist_labels('resources/train-labels-idx1-ubyte.gz')
        X_test = load_mnist_images('resources/t10k-images-idx3-ubyte.gz')
        y_test = load_mnist_labels('resources/t10k-labels-idx1-ubyte.gz')

        # Reserve the last 10000 training examples for validation.
        X_train, X_val = X_train[:-10000], X_train[-10000:]
        y_train, y_val = y_train[:-10000], y_train[-10000:]

        return X_train, y_train, X_val, y_val, X_test, y_test

    def _select_subset(self, subset):

        ''' Returns the (images, labels) arrays of 'train', 'valid' or 'test'.

        Raises:
            ValueError: if `subset` is not one of the three known names.
        '''
        if subset == 'train':
            return self.X_train, self.y_train
        if subset == 'valid':
            return self.X_val, self.y_val
        if subset == 'test':
            return self.X_test, self.y_test
        raise ValueError("Unknown subset %r; expected 'train', 'valid' or 'test'" % (subset,))

    def process_batch(self, batch, batch_size, image_size=28, color=False, rescale=True):

        ''' Turns a batch of 28x28 grayscale digits into channel-last RGB images.

        Args:
            batch: array of shape (batch_size, 28, 28), values in [0, 1].
            batch_size: number of images in `batch`.
            image_size: output side length; digits are resized (order-1 spline)
                whenever it differs from 28.
            color: if True, digits are binarized and drawn, by color inversion,
                over a random color-shifted crop of the Lena image.
            rescale: if True, values are mapped from [0, 1] to [-1, +1].

        Returns:
            Array of shape (batch_size, image_size, image_size, 3).
        '''

        # Resize from 28x28 to image_size x image_size. scipy rounds the output
        # shape, so the factor image_size / 28 yields exactly `image_size` pixels.
        if image_size != 28:
            factor = image_size / 28.0
            batch = np.stack([scipy.ndimage.zoom(batch[i, :, :], factor, order=1) for i in range(batch.shape[0])])

        # Convert to RGB by repeating the gray channel
        batch = batch.reshape((batch_size, 1, image_size, image_size))
        batch = np.concatenate([batch, batch, batch], axis=1)

        if color:

            # Binarize images
            batch[batch >= 0.5] = 1
            batch[batch < 0.5] = 0

            for i in range(batch_size):

                # Take a random crop of the Lena image (background)
                x_c = np.random.randint(0, self.lena.size[0] - image_size)
                y_c = np.random.randint(0, self.lena.size[1] - image_size)
                image = self.lena.crop((x_c, y_c, x_c + image_size, y_c + image_size))
                image = np.asarray(image).transpose((2, 0, 1)) / 255.0

                # Randomly alter the color distribution of the crop
                for j in range(3):
                    image[j, :, :] = (image[j, :, :] + np.random.uniform(0, 1)) / 2.0

                # Invert the color of pixels where there is a digit
                digit_mask = batch[i, :, :, :] == 1
                image[digit_mask] = 1 - image[digit_mask]
                batch[i, :, :, :] = image

        # Rescale to range [-1, +1]
        if rescale:
            batch = batch * 2 - 1

        # Channel last
        return batch.transpose((0, 2, 3, 1))

    def get_batch(self, subset, batch_size, image_size=28, color=False, rescale=True):

        ''' Samples `batch_size` random images (with replacement) from `subset`.

        Returns:
            (images float32 of shape (batch_size, image_size, image_size, 3),
             labels int32 of shape (batch_size,)).
        '''
        X, y = self._select_subset(subset)

        # Random choice of samples (with replacement)
        idx = np.random.choice(X.shape[0], batch_size)
        batch = X[idx, 0, :].reshape((batch_size, 28, 28))

        batch = self.process_batch(batch, batch_size, image_size, color, rescale)
        labels = y[idx]

        return batch.astype('float32'), labels.astype('int32')

    def get_batch_by_labels(self, subset, labels, image_size=28, color=False, rescale=True):

        ''' Returns one randomly chosen image of `subset` for each requested label.

        Args:
            labels: sequence of digit labels (may be floats, e.g. 3.0).

        Returns:
            (images float32 of shape (len(labels), image_size, image_size, 3),
             the labels converted to int32).
        '''
        X, y = self._select_subset(subset)
        labels = np.asarray(labels)

        # Index pool per distinct label, computed once instead of once per label
        pools = {label: np.where(y == label)[0] for label in np.unique(labels)}
        idxs = [np.random.choice(pools[label], 1)[0] for label in labels]

        batch = X[np.array(idxs), 0, :].reshape((len(labels), 28, 28))
        batch = self.process_batch(batch, len(labels), image_size, color, rescale)

        return batch.astype('float32'), labels.astype('int32')

    def get_n_samples(self, subset):

        ''' Returns the number of examples in `subset` (ValueError if unknown). '''
        _, y = self._select_subset(subset)
        return y.shape[0]
189 |
190 |
class MnistGenerator(object):

    ''' Endless iterator of (images, one-hot labels) batches of MNIST digits.

    ``len()`` reports how many batches one nominal pass over the subset takes,
    which is what ``fit_generator`` uses as the number of steps. Each batch is
    sampled at random by the underlying MnistHandler.
    '''

    def __init__(self, batch_size, subset, image_size=28, color=False, rescale=True):

        self.batch_size, self.subset = batch_size, subset
        self.image_size, self.color, self.rescale = image_size, color, rescale

        # The handler loads (and if needed downloads) the MNIST arrays
        self.mnist_handler = MnistHandler()
        self.n_samples = self.mnist_handler.get_n_samples(subset)
        self.n_batches = self.n_samples // batch_size

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self.n_batches

    def next(self):

        ''' Returns one batch: the images and their labels as one-hot rows of width 10. '''
        images, digits = self.mnist_handler.get_batch(
            self.subset, self.batch_size, self.image_size, self.color, self.rescale)
        one_hot = np.eye(10)[digits]
        return images, one_hot
227 |
228 |
class SortedNumberGenerator(object):

    ''' Endless iterator of digit-sequence batches for the CPC task.

    Every sample is a run of ``terms + predict_terms`` consecutive digits,
    wrapping modulo 10 (e.g. 8, 9, 0, 1, ...). The first ``positive_samples``
    samples of a batch keep their sorted continuation (label 1). In the others,
    each predicted digit is replaced by a different random digit (label 0).
    The batch is shuffled and returned as
    ``[context_images, target_images], labels``.
    '''

    def __init__(self, batch_size, subset, terms, positive_samples=1, predict_terms=1, image_size=28, color=False, rescale=True):

        self.positive_samples = positive_samples
        self.predict_terms = predict_terms
        self.batch_size = batch_size
        self.subset = subset
        self.terms = terms
        self.image_size = image_size
        self.color = color
        self.rescale = rescale

        # The handler loads (and if needed downloads) the MNIST arrays
        self.mnist_handler = MnistHandler()
        self.n_samples = self.mnist_handler.get_n_samples(subset) // terms
        self.n_batches = self.n_samples // batch_size

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self.n_batches

    def next(self):

        ''' Builds, renders and shuffles one batch of sequences. '''
        sequence_length = self.terms + self.predict_terms
        image_labels = np.zeros((self.batch_size, sequence_length))
        sentence_labels = np.ones((self.batch_size, 1)).astype('int32')
        digits = np.arange(0, 10)

        for row in range(self.batch_size):

            # Sorted run starting at a random digit
            start = np.random.randint(0, 10)
            sentence = np.mod(np.arange(start, start + sequence_length), 10)

            # Samples past the first `positive_samples` become negatives
            if row >= self.positive_samples:
                tail = sentence[-self.predict_terms:]  # view: writes land in `sentence`
                for pos in range(len(tail)):
                    tail[pos] = np.random.choice(digits[digits != tail[pos]], 1)[0]
                sentence_labels[row, 0] = 0

            image_labels[row, :] = sentence

        # Render every label as an image
        flat_images, _ = self.mnist_handler.get_batch_by_labels(
            self.subset, image_labels.flatten(), self.image_size, self.color, self.rescale)
        images = flat_images.reshape((self.batch_size, sequence_length) + flat_images.shape[1:])
        context_images = images[:, :-self.predict_terms, ...]
        target_images = images[:, -self.predict_terms:, ...]

        # Shuffle so positives and negatives are interleaved
        order = np.random.choice(sentence_labels.shape[0], sentence_labels.shape[0], replace=False)

        return [context_images[order, ...], target_images[order, ...]], sentence_labels[order, ...]
299 |
300 |
class SameNumberGenerator(object):

    ''' Endless iterator of sequences that repeat a single digit.

    Positive samples (the first ``positive_samples`` of each batch, label 1)
    repeat one random digit ``terms + predict_terms`` times. In negative samples
    (label 0), every predicted digit is shifted by a random offset in [1, 9]
    modulo 10, so it differs from the context digit. The batch is shuffled and
    returned as ``[context_images, target_images], labels``.
    '''

    def __init__(self, batch_size, subset, terms, positive_samples=1, predict_terms=1, image_size=28, color=False, rescale=True):

        self.positive_samples = positive_samples
        self.predict_terms = predict_terms
        self.batch_size = batch_size
        self.subset = subset
        self.terms = terms
        self.image_size = image_size
        self.color = color
        self.rescale = rescale

        # The handler loads (and if needed downloads) the MNIST arrays
        self.mnist_handler = MnistHandler()
        self.n_samples = self.mnist_handler.get_n_samples(subset) // terms
        self.n_batches = self.n_samples // batch_size

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self.n_batches

    def next(self):

        ''' Builds, renders and shuffles one batch of repeated-digit sequences. '''
        sequence_length = self.terms + self.predict_terms
        image_labels = np.zeros((self.batch_size, sequence_length))
        sentence_labels = np.ones((self.batch_size, 1)).astype('int32')

        for row in range(self.batch_size):

            digit = np.random.randint(0, 10)
            sentence = np.full(sequence_length, digit, dtype=float)

            # Samples past the first `positive_samples` become negatives
            if row >= self.positive_samples:
                offsets = np.random.randint(1, 10, self.predict_terms)
                sentence[-self.predict_terms:] = np.mod(sentence[-self.predict_terms:] + offsets, 10)
                sentence_labels[row, 0] = 0

            image_labels[row, :] = sentence

        # Render every label as an image
        flat_images, _ = self.mnist_handler.get_batch_by_labels(
            self.subset, image_labels.flatten(), self.image_size, self.color, self.rescale)
        images = flat_images.reshape((self.batch_size, sequence_length) + flat_images.shape[1:])
        context_images = images[:, :-self.predict_terms, ...]
        target_images = images[:, -self.predict_terms:, ...]

        # Shuffle so positives and negatives are interleaved
        order = np.random.choice(sentence_labels.shape[0], sentence_labels.shape[0], replace=False)

        return [context_images[order, ...], target_images[order, ...]], sentence_labels[order, ...]
366 |
367 |
def plot_sequences(x, y, labels=None, output_path=None):

    ''' Draws a grid with one row per sequence: context images, then target images.

    Args:
        x: context images indexed as (sequence, term, height, width, channels).
        y: target images with the same layout; appended after `x` along the term axis.
        labels: optional array read as labels[row, 0]; used as the title of the
            last image of each row.
        output_path: if given, the figure is saved there at 600 dpi and then
            closed; otherwise it is shown interactively.

    Float RGB data is clipped by imshow to [0, 1], so images should be generated
    with rescale=False.
    '''

    images = np.concatenate([x, y], axis=1)
    n_batches, n_terms = images.shape[0], images.shape[1]

    # Draw on a fresh figure so repeated calls do not stack subplots onto
    # whatever figure happened to be current.
    fig = plt.figure()

    counter = 1
    for n_b in range(n_batches):
        for n_t in range(n_terms):
            plt.subplot(n_batches, n_terms, counter)
            plt.imshow(images[n_b, n_t, :, :, :])
            plt.axis('off')
            counter += 1
        if labels is not None:
            plt.title(labels[n_b, 0])

    if output_path is not None:
        fig.savefig(output_path, dpi=600)
        # pyplot keeps figures alive until they are explicitly closed
        plt.close(fig)
    else:
        plt.show()
389 |
390 |
if __name__ == "__main__":

    # Smoke test: render one batch of sorted-digit sequences to an image file
    generator = SortedNumberGenerator(batch_size=8, subset='train', terms=4, positive_samples=4,
                                      predict_terms=4, image_size=64, color=True, rescale=False)
    (context, targets), sequence_labels = next(iter(generator))
    plot_sequences(context, targets, sequence_labels, output_path=r'resources/batch_sample_sorted.png')
398 |
399 |
--------------------------------------------------------------------------------
/models/28x28/cpc.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/28x28/cpc.h5
--------------------------------------------------------------------------------
/models/28x28/encoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/28x28/encoder.h5
--------------------------------------------------------------------------------
/models/28x28/supervised.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/28x28/supervised.h5
--------------------------------------------------------------------------------
/models/64x64/cpc.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/64x64/cpc.h5
--------------------------------------------------------------------------------
/models/64x64/encoder.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/64x64/encoder.h5
--------------------------------------------------------------------------------
/models/64x64/supervised.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/64x64/supervised.h5
--------------------------------------------------------------------------------
/resources/batch_sample_sorted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/batch_sample_sorted.png
--------------------------------------------------------------------------------
/resources/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/figure.png
--------------------------------------------------------------------------------
/resources/lena.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/lena.jpg
--------------------------------------------------------------------------------
/resources/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/resources/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/resources/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/resources/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | This module describes the contrastive predictive coding model from DeepMind:
3 |
4 | Oord, Aaron van den, Yazhe Li, and Oriol Vinyals.
5 | "Representation Learning with Contrastive Predictive Coding."
6 | arXiv preprint arXiv:1807.03748 (2018).
7 | '''
8 | from data_utils import SortedNumberGenerator
9 | from os.path import join, basename, dirname, exists
10 | import keras
11 | from keras import backend as K
12 |
13 |
def network_encoder(x, code_size):

    ''' Maps images to embeddings of size `code_size`.

    Four 3x3 convolutions with 64 filters and stride 2, each followed by
    BatchNormalization and LeakyReLU. Then a 256-unit dense layer (also with
    BatchNormalization and LeakyReLU) and a final linear projection named
    'encoder_embedding'.
    '''

    # Convolutional trunk: each stride-2 block roughly halves the spatial size
    for _ in range(4):
        x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.LeakyReLU()(x)

    # Dense head ending in the embedding projection
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dense(units=256, activation='linear')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU()(x)

    return keras.layers.Dense(units=code_size, activation='linear', name='encoder_embedding')(x)
37 |
38 |
def network_autoregressive(x):

    ''' Summarizes a sequence of embeddings into one context vector.

    A 256-unit GRU named 'ar_context' reads the sequence; only its final output
    is returned (return_sequences=False).
    '''

    gru = keras.layers.GRU(units=256, return_sequences=False, name='ar_context')
    return gru(x)
48 |
49 |
def network_prediction(context, code_size, predict_terms):

    ''' Predicts `predict_terms` future embeddings from the context vector.

    Each future step gets its own linear Dense head, named 'z_t_<i>'. The
    predictions are stacked along axis 1, giving (batch, step, code).
    '''

    predictions = [
        keras.layers.Dense(units=code_size, activation="linear", name='z_t_{i}'.format(i=i))(context)
        for i in range(predict_terms)
    ]

    # A single prediction only needs a new step axis; several are stacked on it
    if len(predictions) == 1:
        return keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(predictions[0])
    return keras.layers.Lambda(lambda x: K.stack(x, axis=1))(predictions)
64 |
65 |
class CPCLayer(keras.layers.Layer):

    ''' Scores the agreement between predicted and actual target embeddings.

    Takes ``[preds, y_encoded]`` (in network_cpc both are indexed as
    (batch, step, code)) and returns one probability per sample, shape (batch, 1).
    '''

    def __init__(self, **kwargs):
        super(CPCLayer, self).__init__(**kwargs)

    def call(self, inputs):

        preds, y_encoded = inputs

        # Mean (not sum) of the element-wise product over the code axis...
        similarity = K.mean(y_encoded * preds, axis=-1)
        # ...then averaged over the predicted time steps
        similarity = K.mean(similarity, axis=-1, keepdims=True)

        # binary_crossentropy expects probabilities, so squash the score
        return K.sigmoid(similarity)

    def compute_output_shape(self, input_shape):
        batch_dim = input_shape[0][0]
        return (batch_dim, 1)
87 |
88 |
def network_cpc(image_shape, terms, predict_terms, code_size, learning_rate):

    ''' Builds and compiles the CPC model.

    Inputs are ``[x, y]``: `terms` context images and `predict_terms` target
    images per sample. A shared encoder embeds every image, and a GRU
    summarizes the context embeddings. Linear heads predict the target
    embeddings, and CPCLayer turns their agreement with the actual target
    embeddings into a probability, trained with binary cross-entropy.
    '''

    # Set learning phase (https://stackoverflow.com/questions/42969779/keras-error-you-must-feed-a-value-for-placeholder-tensor-bidirectional-1-keras)
    # NOTE(review): this forces training mode globally, presumably also during
    # validation (e.g. BatchNormalization statistics) - confirm this is intended.
    K.set_learning_phase(1)

    # Encoder model, shared by the context and target branches
    image_input = keras.layers.Input(image_shape)
    embedding = network_encoder(image_input, code_size)
    encoder_model = keras.models.Model(image_input, embedding, name='encoder')
    encoder_model.summary()

    # Context branch: encode each image, summarize, predict future embeddings
    context_input = keras.layers.Input((terms,) + tuple(image_shape[:3]))
    context_codes = keras.layers.TimeDistributed(encoder_model)(context_input)
    context = network_autoregressive(context_codes)
    preds = network_prediction(context, code_size, predict_terms)

    # Target branch: encode the images that should be predicted
    target_input = keras.layers.Input((predict_terms,) + tuple(image_shape[:3]))
    target_codes = keras.layers.TimeDistributed(encoder_model)(target_input)

    # Agreement probability between predictions and targets
    probs = CPCLayer()([preds, target_codes])

    cpc_model = keras.models.Model(inputs=[context_input, target_input], outputs=probs)
    cpc_model.compile(
        optimizer=keras.optimizers.Adam(lr=learning_rate),
        loss='binary_crossentropy',
        metrics=['binary_accuracy'],
    )
    cpc_model.summary()

    return cpc_model
126 |
127 |
def train_model(epochs, batch_size, output_dir, code_size, lr=1e-4, terms=4, predict_terms=4, image_size=28, color=False, min_lr=1e-4):

    ''' Trains the CPC model on sorted-digit sequences and saves it and its encoder.

    Args:
        epochs: number of training epochs.
        batch_size: sequences per batch; half of each batch is positive.
        output_dir: directory for 'cpc.h5' and 'encoder.h5'; created if missing.
        code_size: dimensionality of the encoder embedding.
        lr: initial Adam learning rate.
        terms: number of context images per sequence.
        predict_terms: number of future images to predict.
        image_size: side length of the square RGB images.
        color: whether digits are drawn over colored backgrounds.
        min_lr: lower bound used by ReduceLROnPlateau. The default keeps the
            previously hard-coded value; when lr <= min_lr the learning rate is
            never reduced.

    Raises:
        ValueError: if no TimeDistributed-wrapped encoder is found in the model.
    '''
    import os

    # Prepare data
    train_data = SortedNumberGenerator(batch_size=batch_size, subset='train', terms=terms,
                                       positive_samples=batch_size // 2, predict_terms=predict_terms,
                                       image_size=image_size, color=color, rescale=True)

    validation_data = SortedNumberGenerator(batch_size=batch_size, subset='valid', terms=terms,
                                            positive_samples=batch_size // 2, predict_terms=predict_terms,
                                            image_size=image_size, color=color, rescale=True)

    # Prepare the model
    model = network_cpc(image_shape=(image_size, image_size, 3), terms=terms, predict_terms=predict_terms,
                        code_size=code_size, learning_rate=lr)

    # Create the output directory up front so a long run does not fail at save time
    if output_dir and not exists(output_dir):
        os.makedirs(output_dir)

    # 1.0 / 3 rather than 1/3: the latter is 0 under Python 2 integer division
    callbacks = [keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=1.0 / 3, patience=2, min_lr=min_lr)]

    # Train the model
    model.fit_generator(
        generator=train_data,
        steps_per_epoch=len(train_data),
        validation_data=validation_data,
        validation_steps=len(validation_data),
        epochs=epochs,
        verbose=1,
        callbacks=callbacks
    )

    # Save the model
    # Remember to add custom_objects={'CPCLayer': CPCLayer} to load_model when loading from disk
    model.save(join(output_dir, 'cpc.h5'))

    # Save the encoder alone. Look it up by wrapper type rather than by position
    # in model.layers; both TimeDistributed wrappers share the same encoder.
    encoder = None
    for layer in model.layers:
        if isinstance(layer, keras.layers.TimeDistributed):
            encoder = layer.layer
            break
    if encoder is None:
        raise ValueError('Could not find the TimeDistributed encoder in the CPC model')
    encoder.save(join(output_dir, 'encoder.h5'))
164 |
165 |
if __name__ == "__main__":

    # Train CPC on 64x64 colored digit sequences (4 context + 4 predicted terms)
    settings = dict(
        epochs=10,
        batch_size=32,
        output_dir='models/64x64',
        code_size=128,
        lr=1e-3,
        terms=4,
        predict_terms=4,
        image_size=64,
        color=True,
    )
    train_model(**settings)
179 |
180 |
--------------------------------------------------------------------------------