├── README.md ├── benchmark_model.py ├── data_utils.py ├── models ├── 28x28 │ ├── cpc.h5 │ ├── encoder.h5 │ └── supervised.h5 └── 64x64 │ ├── cpc.h5 │ ├── encoder.h5 │ └── supervised.h5 ├── resources ├── batch_sample_sorted.png ├── figure.png ├── figure.svg ├── lena.jpg ├── t10k-images-idx3-ubyte.gz ├── t10k-labels-idx1-ubyte.gz ├── train-images-idx3-ubyte.gz └── train-labels-idx1-ubyte.gz └── train_model.py /README.md: -------------------------------------------------------------------------------- 1 | ### Representation Learning with Contrastive Predictive Coding 2 | 3 | This repository contains a Keras implementation of the algorithm presented in the paper [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748). 4 | 5 | The goal of unsupervised representation learning is to capture semantic information about the world, recognizing patterns in the data without using annotations. This paper presents a new method called Contrastive Predictive Coding (CPC) that does so across several domains: speech, images, text and reinforcement learning. The main ideas of the paper are: 6 | * Contrastive: it is trained using a contrastive approach, that is, the model has to distinguish between *right* and *wrong* data sequences. 7 | * Predictive: the model has to predict future patterns given the current context. 8 | * Coding: the model performs this prediction in a latent space, transforming code vectors into other code vectors (in contrast with predicting high-dimensional data directly). 9 | 10 | CPC has to predict the next item in a sequence using only an embedded representation of the data, provided by an encoder. In order to solve the task, this encoder has to learn a meaningful representation of the data space. After training, this encoder can be used for other downstream tasks like supervised classification. 11 | 12 |

13 | CPC algorithm 14 |

15 | 16 | To train the CPC algorithm, I have created a toy dataset. This dataset consists of sequences of modified MNIST digits (64x64 RGB). Positive sequence samples contain *sorted* digits (wrapping around from 9 to 0), whereas negative ones start with the same kind of sorted context but continue with *random* digits. For example, let's assume that the context sequence length is S=4 (`terms` in the code), and CPC is asked to predict the next P=2 digits (`predict_terms`). A positive sample could look like ```[2, 3, 4, 5]->[6, 7]```, whereas a negative one could be ```[1, 2, 3, 4]->[0, 8]```. Of course, CPC only sees the images, never the digit labels. 17 | 18 | Disclaimer: this code is provided *as is*. If you encounter a bug, please report it as an issue. Your help is very welcome! 19 | 20 | ### Results 21 | 22 | After 10 training epochs, CPC reaches 99% accuracy on the contrastive task. After training, I froze the encoder and trained an MLP on top of it to perform supervised digit classification on the same MNIST data. It achieved 90% accuracy after 10 epochs, which suggests that the encoder learned useful features without using any labels. 23 | 24 | ### Usage 25 | 26 | - Execute ```python train_model.py``` to train the CPC model. 27 | - Execute ```python benchmark_model.py``` to train the MLP on top of the frozen CPC encoder. 28 | 29 | ### Requirements 30 | 31 | - [Anaconda Python 3.5.3](https://www.continuum.io/downloads) 32 | - [Keras 2.0.6](https://keras.io/) 33 | - [TensorFlow 1.4.0](https://www.tensorflow.org/) 34 | - A GPU for fast training.
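The labeling scheme of the toy dataset can be illustrated with a minimal NumPy sketch. It assumes that half of each batch is positive, as `train_model.py` does by passing `positive_samples=batch_size // 2`, and it only produces digit labels; `SortedNumberGenerator` additionally fetches an MNIST image for every label. The helpers `sample_sentence` and `sample_batch` are hypothetical and not part of this repository, and they use NumPy's `default_rng` API rather than the global `np.random` functions used in `data_utils.py`.

```python
import numpy as np


def sample_sentence(rng, terms, predict_terms, positive):
    """Digit labels for one sequence: `terms` context digits followed by
    `predict_terms` digits to predict (hypothetical helper, labels only)."""
    # Sorted digits starting at a random seed, wrapping around from 9 to 0
    seed = rng.integers(0, 10)
    sentence = np.mod(np.arange(seed, seed + terms + predict_terms), 10)
    if not positive:
        # Negative sample: keep the sorted context, but replace each predicted
        # digit with a random digit different from the correct one
        candidates = np.arange(10)
        for i in range(terms, terms + predict_terms):
            sentence[i] = rng.choice(candidates[candidates != sentence[i]])
    return sentence


def sample_batch(batch_size, terms, predict_terms, seed=0):
    """Half positive, half negative sequences, shuffled together."""
    rng = np.random.default_rng(seed)
    n_pos = batch_size // 2
    labels = np.array([1] * n_pos + [0] * (batch_size - n_pos))
    sentences = np.stack(
        [sample_sentence(rng, terms, predict_terms, bool(lab)) for lab in labels]
    )
    # Shuffle so positive and negative samples are interleaved
    order = rng.permutation(batch_size)
    return sentences[order, :terms], sentences[order, terms:], labels[order]
```

For example, `x, y, labels = sample_batch(8, terms=4, predict_terms=2)` returns context labels of shape `(8, 4)`, target labels of shape `(8, 2)`, and one binary label per sequence, mirroring the `[x_images, y_images], sentence_labels` structure returned by the generator.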
35 | 36 | ### References 37 | 38 | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748) 39 | -------------------------------------------------------------------------------- /benchmark_model.py: -------------------------------------------------------------------------------- 1 | ''' This module evaluates the performance of a trained CPC encoder ''' 2 | 3 | from data_utils import MnistGenerator 4 | from os.path import join, basename, dirname, exists 5 | import keras 6 | 7 | 8 | def build_model(encoder_path, image_shape, learning_rate): 9 | 10 | # Read the encoder 11 | encoder = keras.models.load_model(encoder_path) 12 | 13 | # Freeze weights 14 | encoder.trainable = False 15 | for layer in encoder.layers: 16 | layer.trainable = False 17 | 18 | # Define the classifier 19 | x_input = keras.layers.Input(image_shape) 20 | x = encoder(x_input) 21 | x = keras.layers.Dense(units=128, activation='linear')(x) 22 | x = keras.layers.BatchNormalization()(x) 23 | x = keras.layers.LeakyReLU()(x) 24 | x = keras.layers.Dense(units=10, activation='softmax')(x) 25 | 26 | # Model 27 | model = keras.models.Model(inputs=x_input, outputs=x) 28 | 29 | # Compile model 30 | model.compile( 31 | optimizer=keras.optimizers.Adam(lr=learning_rate), 32 | loss='categorical_crossentropy', 33 | metrics=['categorical_accuracy'] 34 | ) 35 | model.summary() 36 | 37 | return model 38 | 39 | 40 | def benchmark_model(encoder_path, epochs, batch_size, output_dir, lr=1e-4, image_size=28, color=False): 41 | 42 | # Prepare data 43 | train_data = MnistGenerator(batch_size, subset='train', image_size=image_size, color=color, rescale=True) 44 | 45 | validation_data = MnistGenerator(batch_size, subset='valid', image_size=image_size, color=color, rescale=True) 46 | 47 | # Prepares the model 48 | model = build_model(encoder_path, image_shape=(image_size, image_size, 3), learning_rate=lr) 49 | 50 | # Callbacks 51 | callbacks = 
[keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=1/3, patience=2, min_lr=1e-4)] 52 | 53 | # Trains the model 54 | model.fit_generator( 55 | generator=train_data, 56 | steps_per_epoch=len(train_data), 57 | validation_data=validation_data, 58 | validation_steps=len(validation_data), 59 | epochs=epochs, 60 | verbose=1, 61 | callbacks=callbacks 62 | ) 63 | 64 | # Saves the model 65 | model.save(join(output_dir, 'supervised.h5')) 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | benchmark_model( 71 | encoder_path='models/64x64/encoder.h5', 72 | epochs=15, 73 | batch_size=64, 74 | output_dir='models/64x64', 75 | lr=1e-3, 76 | image_size=64, 77 | color=True 78 | ) 79 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | ''' This module contains code to handle data ''' 2 | 3 | import os 4 | import numpy as np 5 | import scipy.ndimage 6 | from PIL import Image 7 | import scipy 8 | import sys 9 | from matplotlib import pyplot as plt 10 | 11 | 12 | class MnistHandler(object): 13 | 14 | ''' Provides a convenient interface to manipulate MNIST data ''' 15 | 16 | def __init__(self): 17 | 18 | # Download data if needed 19 | self.X_train, self.y_train, self.X_val, self.y_val, self.X_test, self.y_test = self.load_dataset() 20 | 21 | # Load Lena image to memory 22 | self.lena = Image.open('resources/lena.jpg') 23 | 24 | def load_dataset(self): 25 | # Credit for this function: https://github.com/Lasagne/Lasagne/blob/master/examples/mnist.py 26 | 27 | # We first define a download function, supporting both Python 2 and 3. 
28 | if sys.version_info[0] == 2: 29 | from urllib import urlretrieve 30 | else: 31 | from urllib.request import urlretrieve 32 | 33 | def download(filename, source='http://yann.lecun.com/exdb/mnist/'): 34 | print("Downloading %s" % filename) 35 | urlretrieve(source + os.path.basename(filename), filename) # the server expects the bare file name, not the local resources/ path 36 | 37 | # We then define functions for loading MNIST images and labels. 38 | # For convenience, they also download the requested files if needed. 39 | import gzip 40 | 41 | def load_mnist_images(filename): 42 | if not os.path.exists(filename): 43 | download(filename) 44 | # Read the inputs in Yann LeCun's binary format. 45 | with gzip.open(filename, 'rb') as f: 46 | data = np.frombuffer(f.read(), np.uint8, offset=16) 47 | # The inputs are vectors now, we reshape them to monochrome 2D images, 48 | # following the shape convention: (examples, channels, rows, columns) 49 | data = data.reshape(-1, 1, 28, 28) 50 | # The inputs come as bytes, we convert them to float32 in range [0,1]. 51 | # (Actually to range [0, 255/256], for compatibility with the version 52 | # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.) 53 | return data / np.float32(256) 54 | 55 | def load_mnist_labels(filename): 56 | if not os.path.exists(filename): 57 | download(filename) 58 | # Read the labels in Yann LeCun's binary format. 59 | with gzip.open(filename, 'rb') as f: 60 | data = np.frombuffer(f.read(), np.uint8, offset=8) 61 | # The labels are vectors of integers now, that's exactly what we want. 62 | return data 63 | 64 | # We can now download and read the training and test set images and labels. 65 | X_train = load_mnist_images('resources/train-images-idx3-ubyte.gz') 66 | y_train = load_mnist_labels('resources/train-labels-idx1-ubyte.gz') 67 | X_test = load_mnist_images('resources/t10k-images-idx3-ubyte.gz') 68 | y_test = load_mnist_labels('resources/t10k-labels-idx1-ubyte.gz') 69 | 70 | # We reserve the last 10000 training examples for validation.
71 | X_train, X_val = X_train[:-10000], X_train[-10000:] 72 | y_train, y_val = y_train[:-10000], y_train[-10000:] 73 | 74 | # We just return all the arrays in order, as expected in main(). 75 | # (It doesn't matter how we do this as long as we can read them again.) 76 | return X_train, y_train, X_val, y_val, X_test, y_test 77 | 78 | def process_batch(self, batch, batch_size, image_size=28, color=False, rescale=True): 79 | 80 | # Resize from 28x28 to 64x64 81 | if image_size == 64: 82 | batch_resized = [] 83 | for i in range(batch.shape[0]): 84 | # resize to 64x64 pixels 85 | batch_resized.append(scipy.ndimage.zoom(batch[i, :, :], 2.3, order=1)) 86 | batch = np.stack(batch_resized) 87 | 88 | # Convert to RGB 89 | batch = batch.reshape((batch_size, 1, image_size, image_size)) 90 | batch = np.concatenate([batch, batch, batch], axis=1) 91 | 92 | # Modify images if color distribution requested 93 | if color: 94 | 95 | # Binarize images 96 | batch[batch >= 0.5] = 1 97 | batch[batch < 0.5] = 0 98 | 99 | # For each image in the mini batch 100 | for i in range(batch_size): 101 | 102 | # Take a random crop of the Lena image (background) 103 | x_c = np.random.randint(0, self.lena.size[0] - image_size) 104 | y_c = np.random.randint(0, self.lena.size[1] - image_size) 105 | image = self.lena.crop((x_c, y_c, x_c + image_size, y_c + image_size)) 106 | image = np.asarray(image).transpose((2, 0, 1)) / 255.0 107 | 108 | # Randomly alter the color distribution of the crop 109 | for j in range(3): 110 | image[j, :, :] = (image[j, :, :] + np.random.uniform(0, 1)) / 2.0 111 | 112 | # Invert the color of pixels where there is a number 113 | image[batch[i, :, :, :] == 1] = 1 - image[batch[i, :, :, :] == 1] 114 | batch[i, :, :, :] = image 115 | 116 | # Rescale to range [-1, +1] 117 | if rescale: 118 | batch = batch * 2 - 1 119 | 120 | # Channel last 121 | batch = batch.transpose((0, 2, 3, 1)) 122 | 123 | return batch 124 | 125 | def get_batch(self, subset, batch_size, image_size=28, 
color=False, rescale=True): 126 | 127 | # Select a subset 128 | if subset == 'train': 129 | X = self.X_train 130 | y = self.y_train 131 | elif subset == 'valid': 132 | X = self.X_val 133 | y = self.y_val 134 | elif subset == 'test': 135 | X = self.X_test 136 | y = self.y_test 137 | 138 | # Random choice of samples 139 | idx = np.random.choice(X.shape[0], batch_size) 140 | batch = X[idx, 0, :].reshape((batch_size, 28, 28)) 141 | 142 | # Process batch 143 | batch = self.process_batch(batch, batch_size, image_size, color, rescale) 144 | 145 | # Image label 146 | labels = y[idx] 147 | 148 | return batch.astype('float32'), labels.astype('int32') 149 | 150 | def get_batch_by_labels(self, subset, labels, image_size=28, color=False, rescale=True): 151 | 152 | # Select a subset 153 | if subset == 'train': 154 | X = self.X_train 155 | y = self.y_train 156 | elif subset == 'valid': 157 | X = self.X_val 158 | y = self.y_val 159 | elif subset == 'test': 160 | X = self.X_test 161 | y = self.y_test 162 | 163 | # Find samples matching labels 164 | idxs = [] 165 | for i, label in enumerate(labels): 166 | 167 | idx = np.where(y == label)[0] 168 | idx_sel = np.random.choice(idx, 1)[0] 169 | idxs.append(idx_sel) 170 | 171 | # Retrieve images 172 | batch = X[np.array(idxs), 0, :].reshape((len(labels), 28, 28)) 173 | 174 | # Process batch 175 | batch = self.process_batch(batch, len(labels), image_size, color, rescale) 176 | 177 | return batch.astype('float32'), labels.astype('int32') 178 | 179 | def get_n_samples(self, subset): 180 | 181 | if subset == 'train': 182 | y_len = self.y_train.shape[0] 183 | elif subset == 'valid': 184 | y_len = self.y_val.shape[0] 185 | elif subset == 'test': 186 | y_len = self.y_test.shape[0] 187 | 188 | return y_len 189 | 190 | 191 | class MnistGenerator(object): 192 | 193 | ''' Data generator providing MNIST data ''' 194 | 195 | def __init__(self, batch_size, subset, image_size=28, color=False, rescale=True): 196 | 197 | # Set params 198 | self.batch_size 
= batch_size 199 | self.subset = subset 200 | self.image_size = image_size 201 | self.color = color 202 | self.rescale = rescale 203 | 204 | # Initialize MNIST dataset 205 | self.mnist_handler = MnistHandler() 206 | self.n_samples = self.mnist_handler.get_n_samples(subset) 207 | self.n_batches = self.n_samples // batch_size 208 | 209 | def __iter__(self): 210 | return self 211 | 212 | def __next__(self): 213 | return self.next() 214 | 215 | def __len__(self): 216 | return self.n_batches 217 | 218 | def next(self): 219 | 220 | # Get data 221 | x, y = self.mnist_handler.get_batch(self.subset, self.batch_size, self.image_size, self.color, self.rescale) 222 | 223 | # Convert y to one-hot 224 | y_h = np.eye(10)[y] 225 | 226 | return x, y_h 227 | 228 | 229 | class SortedNumberGenerator(object): 230 | 231 | ''' Data generator providing lists of sorted numbers ''' 232 | 233 | def __init__(self, batch_size, subset, terms, positive_samples=1, predict_terms=1, image_size=28, color=False, rescale=True): 234 | 235 | # Set params 236 | self.positive_samples = positive_samples 237 | self.predict_terms = predict_terms 238 | self.batch_size = batch_size 239 | self.subset = subset 240 | self.terms = terms 241 | self.image_size = image_size 242 | self.color = color 243 | self.rescale = rescale 244 | 245 | # Initialize MNIST dataset 246 | self.mnist_handler = MnistHandler() 247 | self.n_samples = self.mnist_handler.get_n_samples(subset) // terms 248 | self.n_batches = self.n_samples // batch_size 249 | 250 | def __iter__(self): 251 | return self 252 | 253 | def __next__(self): 254 | return self.next() 255 | 256 | def __len__(self): 257 | return self.n_batches 258 | 259 | def next(self): 260 | 261 | # Build sentences 262 | image_labels = np.zeros((self.batch_size, self.terms + self.predict_terms)) 263 | sentence_labels = np.ones((self.batch_size, 1)).astype('int32') 264 | positive_samples_n = self.positive_samples 265 | for b in range(self.batch_size): 266 | 267 | # Set ordered 
predictions for positive samples 268 | seed = np.random.randint(0, 10) 269 | sentence = np.mod(np.arange(seed, seed + self.terms + self.predict_terms), 10) 270 | 271 | if positive_samples_n <= 0: 272 | 273 | # Set random predictions for negative samples 274 | # Each predicted term is replaced by a random digit different from the correct one 275 | numbers = np.arange(0, 10) 276 | predicted_terms = sentence[-self.predict_terms:] 277 | for i, p in enumerate(predicted_terms): 278 | predicted_terms[i] = np.random.choice(numbers[numbers != p]) 279 | sentence[-self.predict_terms:] = np.mod(predicted_terms, 10) 280 | sentence_labels[b, :] = 0 281 | 282 | # Save sentence 283 | image_labels[b, :] = sentence 284 | 285 | positive_samples_n -= 1 286 | 287 | # Retrieve actual images 288 | images, _ = self.mnist_handler.get_batch_by_labels(self.subset, image_labels.flatten(), self.image_size, self.color, self.rescale) 289 | 290 | # Assemble batch 291 | images = images.reshape((self.batch_size, self.terms + self.predict_terms, images.shape[1], images.shape[2], images.shape[3])) 292 | x_images = images[:, :-self.predict_terms, ...] 293 | y_images = images[:, -self.predict_terms:, ...] 294 | 295 | # Shuffle so positive and negative samples are interleaved 296 | idxs = np.random.choice(sentence_labels.shape[0], sentence_labels.shape[0], replace=False) 297 | 298 | return [x_images[idxs, ...], y_images[idxs, ...]], sentence_labels[idxs, ...]
299 | 300 | 301 | class SameNumberGenerator(object): 302 | 303 | ''' Data generator providing lists of similar numbers ''' 304 | 305 | def __init__(self, batch_size, subset, terms, positive_samples=1, predict_terms=1, image_size=28, color=False, rescale=True): 306 | 307 | # Set params 308 | self.positive_samples = positive_samples 309 | self.predict_terms = predict_terms 310 | self.batch_size = batch_size 311 | self.subset = subset 312 | self.terms = terms 313 | self.image_size = image_size 314 | self.color = color 315 | self.rescale = rescale 316 | 317 | # Initialize MNIST dataset 318 | self.mnist_handler = MnistHandler() 319 | self.n_samples = self.mnist_handler.get_n_samples(subset) // terms 320 | self.n_batches = self.n_samples // batch_size 321 | 322 | def __iter__(self): 323 | return self 324 | 325 | def __next__(self): 326 | return self.next() 327 | 328 | def __len__(self): 329 | return self.n_batches 330 | 331 | def next(self): 332 | 333 | # Build sentences 334 | image_labels = np.zeros((self.batch_size, self.terms + self.predict_terms)) 335 | sentence_labels = np.ones((self.batch_size, 1)).astype('int32') 336 | positive_samples_n = self.positive_samples 337 | for b in range(self.batch_size): 338 | 339 | # Set positive samples 340 | seed = np.random.randint(0, 10) 341 | sentence = seed * np.ones(self.terms + self.predict_terms) 342 | 343 | if positive_samples_n <= 0: 344 | 345 | # Set random predictions for negative samples 346 | sentence[-self.predict_terms:] = np.mod(sentence[-self.predict_terms:] + np.random.randint(1, 10, self.predict_terms), 10) 347 | sentence_labels[b, :] = 0 348 | 349 | # Save sentence 350 | image_labels[b, :] = sentence 351 | 352 | positive_samples_n -= 1 353 | 354 | # Retrieve actual images 355 | images, _ = self.mnist_handler.get_batch_by_labels(self.subset, image_labels.flatten(), self.image_size, self.color, self.rescale) 356 | 357 | # Assemble batch 358 | images = images.reshape((self.batch_size, self.terms + 
self.predict_terms, images.shape[1], images.shape[2], images.shape[3])) 359 | x_images = images[:, :-self.predict_terms, ...] 360 | y_images = images[:, -self.predict_terms:, ...] 361 | 362 | # Randomize 363 | idxs = np.random.choice(sentence_labels.shape[0], sentence_labels.shape[0], replace=False) 364 | 365 | return [x_images[idxs, ...], y_images[idxs, ...]], sentence_labels[idxs, ...] 366 | 367 | 368 | def plot_sequences(x, y, labels=None, output_path=None): 369 | 370 | ''' Draws a plot where sequences of numbers can be studied conveniently ''' 371 | 372 | images = np.concatenate([x, y], axis=1) 373 | n_batches = images.shape[0] 374 | n_terms = images.shape[1] 375 | counter = 1 376 | for n_b in range(n_batches): 377 | for n_t in range(n_terms): 378 | plt.subplot(n_batches, n_terms, counter) 379 | plt.imshow(images[n_b, n_t, :, :, :]) 380 | plt.axis('off') 381 | counter += 1 382 | if labels is not None: 383 | plt.title(labels[n_b, 0]) 384 | 385 | if output_path is not None: 386 | plt.savefig(output_path, dpi=600) 387 | else: 388 | plt.show() 389 | 390 | 391 | if __name__ == "__main__": 392 | 393 | # Test SortedNumberGenerator 394 | ag = SortedNumberGenerator(batch_size=8, subset='train', terms=4, positive_samples=4, predict_terms=4, image_size=64, color=True, rescale=False) 395 | for (x, y), labels in ag: 396 | plot_sequences(x, y, labels, output_path=r'resources/batch_sample_sorted.png') 397 | break 398 | 399 | -------------------------------------------------------------------------------- /models/28x28/cpc.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/28x28/cpc.h5 -------------------------------------------------------------------------------- /models/28x28/encoder.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/28x28/encoder.h5 -------------------------------------------------------------------------------- /models/28x28/supervised.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/28x28/supervised.h5 -------------------------------------------------------------------------------- /models/64x64/cpc.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/64x64/cpc.h5 -------------------------------------------------------------------------------- /models/64x64/encoder.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/64x64/encoder.h5 -------------------------------------------------------------------------------- /models/64x64/supervised.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/models/64x64/supervised.h5 -------------------------------------------------------------------------------- /resources/batch_sample_sorted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/batch_sample_sorted.png -------------------------------------------------------------------------------- /resources/figure.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/figure.png -------------------------------------------------------------------------------- /resources/lena.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/lena.jpg -------------------------------------------------------------------------------- /resources/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /resources/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /resources/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /resources/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtellez/contrastive-predictive-coding/67657aac82785e97835f2b27450883b4e65b9b3f/resources/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /train_model.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | This module describes the contrastive predictive coding model from DeepMind: 3 | 4 | Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. 5 | "Representation Learning with Contrastive Predictive Coding." 6 | arXiv preprint arXiv:1807.03748 (2018). 7 | ''' 8 | from data_utils import SortedNumberGenerator 9 | from os.path import join, basename, dirname, exists 10 | import keras 11 | from keras import backend as K 12 | 13 | 14 | def network_encoder(x, code_size): 15 | 16 | ''' Define the network mapping images to embeddings ''' 17 | 18 | x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(x) 19 | x = keras.layers.BatchNormalization()(x) 20 | x = keras.layers.LeakyReLU()(x) 21 | x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(x) 22 | x = keras.layers.BatchNormalization()(x) 23 | x = keras.layers.LeakyReLU()(x) 24 | x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(x) 25 | x = keras.layers.BatchNormalization()(x) 26 | x = keras.layers.LeakyReLU()(x) 27 | x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(x) 28 | x = keras.layers.BatchNormalization()(x) 29 | x = keras.layers.LeakyReLU()(x) 30 | x = keras.layers.Flatten()(x) 31 | x = keras.layers.Dense(units=256, activation='linear')(x) 32 | x = keras.layers.BatchNormalization()(x) 33 | x = keras.layers.LeakyReLU()(x) 34 | x = keras.layers.Dense(units=code_size, activation='linear', name='encoder_embedding')(x) 35 | 36 | return x 37 | 38 | 39 | def network_autoregressive(x): 40 | 41 | ''' Define the network that integrates information along the sequence ''' 42 | 43 | # x = keras.layers.GRU(units=256, return_sequences=True)(x) 44 | # x = keras.layers.BatchNormalization()(x) 45 | x = keras.layers.GRU(units=256, return_sequences=False, name='ar_context')(x) 46 | 47 | return x 48 | 49 | 50 | def 
network_prediction(context, code_size, predict_terms): 51 | 52 | ''' Define the network mapping context to multiple embeddings ''' 53 | 54 | outputs = [] 55 | for i in range(predict_terms): 56 | outputs.append(keras.layers.Dense(units=code_size, activation="linear", name='z_t_{i}'.format(i=i))(context)) 57 | 58 | if len(outputs) == 1: 59 | output = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(outputs[0]) 60 | else: 61 | output = keras.layers.Lambda(lambda x: K.stack(x, axis=1))(outputs) 62 | 63 | return output 64 | 65 | 66 | class CPCLayer(keras.layers.Layer): 67 | 68 | ''' Computes dot product between true and predicted embedding vectors ''' 69 | 70 | def __init__(self, **kwargs): 71 | super(CPCLayer, self).__init__(**kwargs) 72 | 73 | def call(self, inputs): 74 | 75 | # Compute dot product among vectors 76 | preds, y_encoded = inputs 77 | dot_product = K.mean(y_encoded * preds, axis=-1) 78 | dot_product = K.mean(dot_product, axis=-1, keepdims=True) # average along the temporal dimension 79 | 80 | # Keras loss functions take probabilities 81 | dot_product_probs = K.sigmoid(dot_product) 82 | 83 | return dot_product_probs 84 | 85 | def compute_output_shape(self, input_shape): 86 | return (input_shape[0][0], 1) 87 | 88 | 89 | def network_cpc(image_shape, terms, predict_terms, code_size, learning_rate): 90 | 91 | ''' Define the CPC network combining encoder and autoregressive model ''' 92 | 93 | # Set learning phase (https://stackoverflow.com/questions/42969779/keras-error-you-must-feed-a-value-for-placeholder-tensor-bidirectional-1-keras) 94 | K.set_learning_phase(1) 95 | 96 | # Define encoder model 97 | encoder_input = keras.layers.Input(image_shape) 98 | encoder_output = network_encoder(encoder_input, code_size) 99 | encoder_model = keras.models.Model(encoder_input, encoder_output, name='encoder') 100 | encoder_model.summary() 101 | 102 | # Define rest of model 103 | x_input = keras.layers.Input((terms, image_shape[0], image_shape[1], image_shape[2])) 
104 | x_encoded = keras.layers.TimeDistributed(encoder_model)(x_input) 105 | context = network_autoregressive(x_encoded) 106 | preds = network_prediction(context, code_size, predict_terms) 107 | 108 | y_input = keras.layers.Input((predict_terms, image_shape[0], image_shape[1], image_shape[2])) 109 | y_encoded = keras.layers.TimeDistributed(encoder_model)(y_input) 110 | 111 | # Loss 112 | dot_product_probs = CPCLayer()([preds, y_encoded]) 113 | 114 | # Model 115 | cpc_model = keras.models.Model(inputs=[x_input, y_input], outputs=dot_product_probs) 116 | 117 | # Compile model 118 | cpc_model.compile( 119 | optimizer=keras.optimizers.Adam(lr=learning_rate), 120 | loss='binary_crossentropy', 121 | metrics=['binary_accuracy'] 122 | ) 123 | cpc_model.summary() 124 | 125 | return cpc_model 126 | 127 | 128 | def train_model(epochs, batch_size, output_dir, code_size, lr=1e-4, terms=4, predict_terms=4, image_size=28, color=False): 129 | 130 | # Prepare data 131 | train_data = SortedNumberGenerator(batch_size=batch_size, subset='train', terms=terms, 132 | positive_samples=batch_size // 2, predict_terms=predict_terms, 133 | image_size=image_size, color=color, rescale=True) 134 | 135 | validation_data = SortedNumberGenerator(batch_size=batch_size, subset='valid', terms=terms, 136 | positive_samples=batch_size // 2, predict_terms=predict_terms, 137 | image_size=image_size, color=color, rescale=True) 138 | 139 | # Prepares the model 140 | model = network_cpc(image_shape=(image_size, image_size, 3), terms=terms, predict_terms=predict_terms, 141 | code_size=code_size, learning_rate=lr) 142 | 143 | # Callbacks 144 | callbacks = [keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=1/3, patience=2, min_lr=1e-4)] 145 | 146 | # Trains the model 147 | model.fit_generator( 148 | generator=train_data, 149 | steps_per_epoch=len(train_data), 150 | validation_data=validation_data, 151 | validation_steps=len(validation_data), 152 | epochs=epochs, 153 | verbose=1, 154 | 
callbacks=callbacks 155 | ) 156 | 157 | # Saves the model 158 | # Remember to add custom_objects={'CPCLayer': CPCLayer} to load_model when loading from disk 159 | model.save(join(output_dir, 'cpc.h5')) 160 | 161 | # Saves the encoder alone 162 | encoder = model.layers[1].layer 163 | encoder.save(join(output_dir, 'encoder.h5')) 164 | 165 | 166 | if __name__ == "__main__": 167 | 168 | train_model( 169 | epochs=10, 170 | batch_size=32, 171 | output_dir='models/64x64', 172 | code_size=128, 173 | lr=1e-3, 174 | terms=4, 175 | predict_terms=4, 176 | image_size=64, 177 | color=True 178 | ) 179 | 180 | --------------------------------------------------------------------------------
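`CPCLayer` in `train_model.py` scores a sequence by averaging the elementwise product of the predicted and encoded vectors over the code dimension (a dot product scaled by `1 / code_size`), then averaging over the predicted time steps and applying a sigmoid. The model is then trained with binary cross-entropy against the positive/negative labels. The following NumPy mirror of that forward computation can help when checking shapes. The function names are hypothetical, it is not part of this repository, and the epsilon clipping in the loss is an assumption rather than a claim about Keras internals.

```python
import numpy as np


def cpc_scores(preds, y_encoded):
    """NumPy mirror of CPCLayer.call (hypothetical helper).

    preds, y_encoded: arrays of shape (batch, predict_terms, code_size).
    Returns probabilities of shape (batch, 1).
    """
    # Mean over the code dimension: a dot product scaled by 1 / code_size
    dot = np.mean(y_encoded * preds, axis=-1)
    # Average over the predicted time steps, keeping a trailing axis of size 1
    dot = np.mean(dot, axis=-1, keepdims=True)
    # Sigmoid turns the score into a probability for binary cross-entropy
    return 1.0 / (1.0 + np.exp(-dot))


def binary_crossentropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped for numerical safety."""
    probs = np.clip(probs, eps, 1 - eps)
    return float(np.mean(-(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))))
```

With this scoring, a target sequence whose encodings point in the same direction as the predictions gets a probability above 0.5, while one pointing in the opposite direction gets a probability below 0.5, which is what the binary cross-entropy loss pushes the encoder and predictor towards.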