├── README.md
├── augmentation
│   ├── autoaugment.py
│   ├── gaussian_filter.py
│   ├── simclr_augment.py
│   └── transforms.py
├── config.yaml
├── data
│   └── create_tf_records.ipynb
├── feature_eval
│   └── contrastive_feature_eval.ipynb
├── models
│   ├── cnn_small.py
│   └── resnet_simclr.py
├── train.py
└── utils
    ├── helpers.py
    └── losses.py

/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
2 | 
3 | ## Still under development!!
4 | 
5 | ### Blog post with full documentation: [Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations](https://sthalles.github.io/simple-self-supervised-learning/)
6 | 
7 | #### For a PyTorch implementation: [PyTorch SimCLR](https://github.com/sthalles/SimCLR)
8 | 
9 | ![Image of SimCLR Arch](https://sthalles.github.io/assets/contrastive-self-supervised/cover.png)
10 | 
11 | ## Dependencies
12 | 
13 | - tensorflow 2.x (plus tensorflow_addons, opencv-python and pyyaml, which the augmentations and the training script import)
14 | 
15 | ## Config file
16 | 
17 | Before running SimCLR, make sure you choose the correct running configurations in the ```config.yaml``` file.
18 | 
19 | ```yaml
20 | batch_size: 256 # A batch size of N produces 2*(N-1) negative samples per positive pair. The original implementation uses a batch size of 8192
21 | out_dim: 64 # Output dimensionality of the embedding vector z. The original implementation uses 2048
22 | s: 1 # Strength of the color distortion augmentation
23 | temperature: 0.5 # Temperature parameter for the contrastive objective
24 | base_convnet: "resnet18" # The ConvNet base model. Choose one of "resnet18" or "resnet50". The original implementation uses resnet50
25 | use_cosine_similarity: True # Similarity metric for the contrastive loss. If False, uses the dot product
26 | epochs: 40 # Number of epochs to train
27 | num_workers: 4 # Number of workers for the data loader
28 | ```
29 | 
30 | ## Feature Evaluation
31 | 
32 | Feature evaluation is done using the linear evaluation protocol. Features are learned on the ```STL10 unsupervised``` set and evaluated on the STL10 train/test splits.
33 | 
34 | Check the ```feature_eval/contrastive_feature_eval.ipynb``` notebook for reproducibility.
35 | 
36 | | Feature Extractor | Method | Architecture | Top 1 |
37 | |:-------------------:|:------------:|:------------:|:-----:|
38 | | Logistic Regression | PCA Features | - | - |
39 | | KNN | PCA Features | - | - |
40 | | Logistic Regression | SimCLR | ResNet-18 | - |
41 | | KNN | SimCLR | ResNet-18 | - |
42 | 
43 | ## Download pre-trained model
44 | 
45 | ---
46 | 
--------------------------------------------------------------------------------
/augmentation/autoaugment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """AutoAugment and RandAugment policies for enhanced image preprocessing.
16 | 17 | AutoAugment Reference: https://arxiv.org/abs/1805.09501 18 | RandAugment Reference: https://arxiv.org/abs/1909.13719 19 | """ 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | import six 24 | import inspect 25 | import math 26 | import tensorflow.compat.v1 as tf 27 | import tensorflow_addons as tfa 28 | 29 | # from tensorflow.contrib import training as contrib_training 30 | 31 | # This signifies the max integer that the controller RNN could predict for the 32 | # augmentation scheme. 33 | _MAX_LEVEL = 10. 34 | 35 | 36 | def policy_v0(): 37 | """Autoaugment policy that was used in AutoAugment Paper.""" 38 | # Each tuple is an augmentation operation of the form 39 | # (operation, probability, magnitude). Each element in policy is a 40 | # sub-policy that will be applied sequentially on the image. 41 | policy = [ 42 | [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], 43 | [('Color', 0.4, 9), ('Equalize', 0.6, 3)], 44 | [('Color', 0.4, 1), ('Rotate', 0.6, 8)], 45 | [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], 46 | [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], 47 | [('Color', 0.2, 0), ('Equalize', 0.8, 8)], 48 | [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], 49 | [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], 50 | [('Color', 0.6, 1), ('Equalize', 1.0, 2)], 51 | [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], 52 | [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], 53 | [('Color', 0.4, 7), ('Equalize', 0.6, 0)], 54 | [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], 55 | [('Solarize', 0.6, 8), ('Color', 0.6, 9)], 56 | [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], 57 | [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)], 58 | [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], 59 | [('ShearY', 0.8, 0), ('Color', 0.6, 4)], 60 | [('Color', 1.0, 0), ('Rotate', 0.6, 2)], 61 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], 62 | [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], 63 | [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], 64 | [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], 65 | [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], 66 | [('Color', 0.8, 6), ('Rotate', 0.4, 5)], 67 | ] 68 | return policy 69 | 70 | 71 | def policy_vtest(): 72 | """Autoaugment test policy for debugging.""" 73 | # Each tuple is an augmentation operation of the form 74 | # (operation, probability, magnitude). Each element in policy is a 75 | # sub-policy that will be applied sequentially on the image. 76 | policy = [ 77 | [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)], 78 | ] 79 | return policy 80 | 81 | 82 | def blend(image1, image2, factor): 83 | """Blend image1 and image2 using 'factor'. 84 | 85 | Factor can be above 0.0. A value of 0.0 means only image1 is used. 86 | A value of 1.0 means only image2 is used. A value between 0.0 and 87 | 1.0 means we linearly interpolate the pixel values between the two 88 | images. A value greater than 1.0 "extrapolates" the difference 89 | between the two pixel values, and we clip the results to values 90 | between 0 and 255. 91 | 92 | Args: 93 | image1: An image Tensor of type uint8. 94 | image2: An image Tensor of type uint8. 95 | factor: A floating point value above 0.0. 96 | 97 | Returns: 98 | A blended image Tensor of type uint8. 
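Example: blend(image1, image2, 0.5) returns the pixel-wise average of the
two images, while blend(image1, image2, 1.5) extrapolates 50% past image2
along the image1 -> image2 direction and clips the result to [0, 255].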
99 | """ 100 | if factor == 0.0: 101 | return tf.convert_to_tensor(image1) 102 | if factor == 1.0: 103 | return tf.convert_to_tensor(image2) 104 | 105 | image1 = tf.cast(image1, tf.float32) 106 | image2 = tf.cast(image2, tf.float32) 107 | 108 | difference = image2 - image1 109 | scaled = factor * difference 110 | 111 | # Do addition in float. 112 | temp = tf.cast(image1, tf.float32) + scaled 113 | 114 | # Interpolate 115 | if factor > 0.0 and factor < 1.0: 116 | # Interpolation means we always stay within 0 and 255. 117 | return tf.cast(temp, tf.uint8) 118 | 119 | # Extrapolate: 120 | # 121 | # We need to clip and then cast. 122 | return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8) 123 | 124 | 125 | def cutout(image, pad_size, replace=0): 126 | """Apply cutout (https://arxiv.org/abs/1708.04552) to image. 127 | 128 | This operation applies a (2*pad_size x 2*pad_size) mask of zeros to 129 | a random location within `img`. The pixel values filled in will be of the 130 | value `replace`. The located where the mask will be applied is randomly 131 | chosen uniformly over the whole image. 132 | 133 | Args: 134 | image: An image Tensor of type uint8. 135 | pad_size: Specifies how big the zero mask that will be generated is that 136 | is applied to the image. The mask will be of size 137 | (2*pad_size x 2*pad_size). 138 | replace: What pixel value to fill in the image in the area that has 139 | the cutout mask applied to it. 140 | 141 | Returns: 142 | An image Tensor that is of type uint8. 143 | """ 144 | image_height = tf.shape(image)[0] 145 | image_width = tf.shape(image)[1] 146 | 147 | # Sample the center location in the image where the zero mask will be applied. 148 | cutout_center_height = tf.random_uniform( 149 | shape=[], minval=0, maxval=image_height, 150 | dtype=tf.int32) 151 | 152 | cutout_center_width = tf.random_uniform( 153 | shape=[], minval=0, maxval=image_width, 154 | dtype=tf.int32) 155 | 156 | lower_pad = tf.maximum(0, cutout_center_height - pad_size) 157 | upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) 158 | left_pad = tf.maximum(0, cutout_center_width - pad_size) 159 | right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) 160 | 161 | cutout_shape = [image_height - (lower_pad + upper_pad), 162 | image_width - (left_pad + right_pad)] 163 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] 164 | mask = tf.pad( 165 | tf.zeros(cutout_shape, dtype=image.dtype), 166 | padding_dims, constant_values=1) 167 | mask = tf.expand_dims(mask, -1) 168 | mask = tf.tile(mask, [1, 1, 3]) 169 | image = tf.where( 170 | tf.equal(mask, 0), 171 | tf.ones_like(image, dtype=image.dtype) * replace, 172 | image) 173 | return image 174 | 175 | 176 | def solarize(image, threshold=128): 177 | # For each pixel in the image, select the pixel 178 | # if the value is less than the threshold. 179 | # Otherwise, subtract 255 from the pixel. 180 | return tf.where(image < threshold, image, 255 - image) 181 | 182 | 183 | def solarize_add(image, addition=0, threshold=128): 184 | # For each pixel in the image less than threshold 185 | # we add 'addition' amount to it and then clip the 186 | # pixel value to be between 0 and 255. The value 187 | # of 'addition' is between -128 and 128. 
188 | added_image = tf.cast(image, tf.int64) + addition 189 | added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) 190 | return tf.where(image < threshold, added_image, image) 191 | 192 | 193 | def color(image, factor): 194 | """Equivalent of PIL Color.""" 195 | degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) 196 | return blend(degenerate, image, factor) 197 | 198 | 199 | def contrast(image, factor): 200 | """Equivalent of PIL Contrast.""" 201 | degenerate = tf.image.rgb_to_grayscale(image) 202 | # Cast before calling tf.histogram. 203 | degenerate = tf.cast(degenerate, tf.int32) 204 | 205 | # Compute the grayscale histogram, then compute the mean pixel value, 206 | # and create a constant image size of that value. Use that as the 207 | # blending degenerate target of the original image. 208 | hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) 209 | mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 210 | degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean 211 | degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) 212 | degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) 213 | return blend(degenerate, image, factor) 214 | 215 | 216 | def brightness(image, factor): 217 | """Equivalent of PIL Brightness.""" 218 | degenerate = tf.zeros_like(image) 219 | return blend(degenerate, image, factor) 220 | 221 | 222 | def posterize(image, bits): 223 | """Equivalent of PIL Posterize.""" 224 | shift = 8 - bits 225 | return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) 226 | 227 | 228 | def rotate(image, degrees, replace): 229 | """Rotates the image by degrees either clockwise or counterclockwise. 230 | 231 | Args: 232 | image: An image Tensor of type uint8. 233 | degrees: Float, a scalar angle in degrees to rotate all images by. If 234 | degrees is positive the image will be rotated clockwise otherwise it will 235 | be rotated counterclockwise. 236 | replace: A one or three value 1D tensor to fill empty pixels caused by 237 | the rotate operation. 238 | 239 | Returns: 240 | The rotated version of image. 241 | """ 242 | # Convert from degrees to radians. 243 | degrees_to_radians = math.pi / 180.0 244 | radians = degrees * degrees_to_radians 245 | 246 | # In practice, we should randomize the rotation degrees by flipping 247 | # it negatively half the time, but that's done on 'degrees' outside 248 | # of the function. 249 | image = tfa.image.rotate(wrap(image), radians) 250 | return unwrap(image, replace) 251 | 252 | 253 | def translate_x(image, pixels, replace): 254 | """Equivalent of PIL Translate in X dimension.""" 255 | image = tfa.image.translate(wrap(image), [-pixels, 0]) 256 | return unwrap(image, replace) 257 | 258 | 259 | def translate_y(image, pixels, replace): 260 | """Equivalent of PIL Translate in Y dimension.""" 261 | image = tfa.image.translate(wrap(image), [0, -pixels]) 262 | return unwrap(image, replace) 263 | 264 | 265 | def shear_x(image, level, replace): 266 | """Equivalent of PIL Shearing in X dimension.""" 267 | # Shear parallel to x axis is a projective transform 268 | # with a matrix form of: 269 | # [1 level 270 | # 0 1]. 
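# tfa.image.transform takes the flattened projective matrix
# [a0, a1, a2, b0, b1, b2, c0, c1]: output pixel (x, y) is sampled from
# input location (a0*x + a1*y + a2, b0*x + b1*y + b2), so setting a1=level
# shifts each row horizontally in proportion to its y coordinate.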
271 | image = tfa.image.transform( 272 | wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) 273 | return unwrap(image, replace) 274 | 275 | 276 | def shear_y(image, level, replace): 277 | """Equivalent of PIL Shearing in Y dimension.""" 278 | # Shear parallel to y axis is a projective transform 279 | # with a matrix form of: 280 | # [1 0 281 | # level 1]. 282 | image = tfa.image.transform( 283 | wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) 284 | return unwrap(image, replace) 285 | 286 | 287 | def autocontrast(image): 288 | """Implements Autocontrast function from PIL using TF ops. 289 | 290 | Args: 291 | image: A 3D uint8 tensor. 292 | 293 | Returns: 294 | The image after it has had autocontrast applied to it and will be of type 295 | uint8. 296 | """ 297 | 298 | def scale_channel(image): 299 | """Scale the 2D image using the autocontrast rule.""" 300 | # A possibly cheaper version can be done using cumsum/unique_with_counts 301 | # over the histogram values, rather than iterating over the entire image. 302 | # to compute mins and maxes. 303 | lo = tf.cast(tf.reduce_min(image), tf.float32) 304 | hi = tf.cast(tf.reduce_max(image), tf.float32) 305 | 306 | # Scale the image, making the lowest value 0 and the highest value 255. 307 | def scale_values(im): 308 | scale = 255.0 / (hi - lo) 309 | offset = -lo * scale 310 | im = tf.cast(im, tf.float32) * scale + offset 311 | im = tf.clip_by_value(im, 0.0, 255.0) 312 | return tf.cast(im, tf.uint8) 313 | 314 | result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) 315 | return result 316 | 317 | # Assumes RGB for now. Scales each channel independently 318 | # and then stacks the result. 319 | s1 = scale_channel(image[:, :, 0]) 320 | s2 = scale_channel(image[:, :, 1]) 321 | s3 = scale_channel(image[:, :, 2]) 322 | image = tf.stack([s1, s2, s3], 2) 323 | return image 324 | 325 | 326 | def sharpness(image, factor): 327 | """Implements Sharpness function from PIL using TF ops.""" 328 | orig_image = image 329 | image = tf.cast(image, tf.float32) 330 | # Make image 4D for conv operation. 331 | image = tf.expand_dims(image, 0) 332 | # SMOOTH PIL Kernel. 333 | kernel = tf.constant( 334 | [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, 335 | shape=[3, 3, 1, 1]) / 13. 336 | # Tile across channel dimension. 337 | kernel = tf.tile(kernel, [1, 1, 3, 1]) 338 | strides = [1, 1, 1, 1] 339 | degenerate = tf.nn.depthwise_conv2d( 340 | image, kernel, strides, padding='VALID', rate=[1, 1]) 341 | degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) 342 | degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) 343 | 344 | # For the borders of the resulting image, fill in the values of the 345 | # original image. 346 | mask = tf.ones_like(degenerate) 347 | padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) 348 | padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) 349 | result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) 350 | 351 | # Blend the final result. 352 | return blend(result, orig_image, factor) 353 | 354 | 355 | def equalize(image): 356 | """Implements Equalize function from PIL using TF ops.""" 357 | 358 | def scale_channel(im, c): 359 | """Scale the data in the channel to implement equalize.""" 360 | im = tf.cast(im[:, :, c], tf.int32) 361 | # Compute the histogram of the image channel. 362 | histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) 363 | 364 | # For the purposes of computing the step, filter out the nonzeros. 
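# (`nonzero` below indexes the non-empty histogram bins; `step` is then the
# total pixel count minus the last non-empty bin count, divided by 255,
# matching PIL's equalize.)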
365 | nonzero = tf.where(tf.not_equal(histo, 0)) 366 | nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) 367 | step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 368 | 369 | def build_lut(histo, step): 370 | # Compute the cumulative sum, shifting by step // 2 371 | # and then normalization by step. 372 | lut = (tf.cumsum(histo) + (step // 2)) // step 373 | # Shift lut, prepending with 0. 374 | lut = tf.concat([[0], lut[:-1]], 0) 375 | # Clip the counts to be in range. This is done 376 | # in the C code for image.point. 377 | return tf.clip_by_value(lut, 0, 255) 378 | 379 | # If step is zero, return the original image. Otherwise, build 380 | # lut from the full histogram and step and then index from it. 381 | result = tf.cond(tf.equal(step, 0), 382 | lambda: im, 383 | lambda: tf.gather(build_lut(histo, step), im)) 384 | 385 | return tf.cast(result, tf.uint8) 386 | 387 | # Assumes RGB for now. Scales each channel independently 388 | # and then stacks the result. 389 | s1 = scale_channel(image, 0) 390 | s2 = scale_channel(image, 1) 391 | s3 = scale_channel(image, 2) 392 | image = tf.stack([s1, s2, s3], 2) 393 | return image 394 | 395 | 396 | def invert(image): 397 | """Inverts the image pixels.""" 398 | image = tf.convert_to_tensor(image) 399 | return 255 - image 400 | 401 | 402 | def wrap(image): 403 | """Returns 'image' with an extra channel set to all 1s.""" 404 | shape = tf.shape(image) 405 | extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) 406 | extended = tf.concat([image, extended_channel], 2) 407 | return extended 408 | 409 | 410 | def unwrap(image, replace): 411 | """Unwraps an image produced by wrap. 412 | 413 | Where there is a 0 in the last channel for every spatial position, 414 | the rest of the three channels in that spatial dimension are grayed 415 | (set to 128). Operations like translate and shear on a wrapped 416 | Tensor will leave 0s in empty locations. Some transformations look 417 | at the intensity of values to do preprocessing, and we want these 418 | empty pixels to assume the 'average' value, rather than pure black. 419 | 420 | 421 | Args: 422 | image: A 3D Image Tensor with 4 channels. 423 | replace: A one or three value 1D tensor to fill empty pixels. 424 | 425 | Returns: 426 | image: A 3D image Tensor with 3 channels. 427 | """ 428 | image_shape = tf.shape(image) 429 | # Flatten the spatial dimensions. 430 | flattened_image = tf.reshape(image, [-1, image_shape[2]]) 431 | 432 | # Find all pixels where the last channel is zero. 433 | alpha_channel = flattened_image[:, 3] 434 | 435 | replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) 436 | 437 | # Where they are zero, fill them in with 'replace'. 
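# (`replace` was extended with a trailing 1 above so it broadcasts against
# the 4-channel flattened image.)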
438 | flattened_image = tf.where( 439 | tf.equal(alpha_channel, 0), 440 | tf.ones_like(flattened_image, dtype=image.dtype) * replace, 441 | flattened_image) 442 | 443 | image = tf.reshape(flattened_image, image_shape) 444 | image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) 445 | return image 446 | 447 | 448 | NAME_TO_FUNC = { 449 | 'AutoContrast': autocontrast, 450 | 'Equalize': equalize, 451 | 'Invert': invert, 452 | 'Rotate': rotate, 453 | 'Posterize': posterize, 454 | 'Solarize': solarize, 455 | 'SolarizeAdd': solarize_add, 456 | 'Color': color, 457 | 'Contrast': contrast, 458 | 'Brightness': brightness, 459 | 'Sharpness': sharpness, 460 | 'ShearX': shear_x, 461 | 'ShearY': shear_y, 462 | 'TranslateX': translate_x, 463 | 'TranslateY': translate_y, 464 | 'Cutout': cutout, 465 | } 466 | 467 | 468 | def _randomly_negate_tensor(tensor): 469 | """With 50% prob turn the tensor negative.""" 470 | should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool) 471 | final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) 472 | return final_tensor 473 | 474 | 475 | def _rotate_level_to_arg(level): 476 | level = (level / _MAX_LEVEL) * 30. 477 | level = _randomly_negate_tensor(level) 478 | return (level,) 479 | 480 | 481 | def _shrink_level_to_arg(level): 482 | """Converts level to ratio by which we shrink the image content.""" 483 | if level == 0: 484 | return (1.0,) # if level is zero, do not shrink the image 485 | # Maximum shrinking ratio is 2.9. 486 | level = 2. / (_MAX_LEVEL / level) + 0.9 487 | return (level,) 488 | 489 | 490 | def _enhance_level_to_arg(level): 491 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 492 | 493 | 494 | def _shear_level_to_arg(level): 495 | level = (level / _MAX_LEVEL) * 0.3 496 | # Flip level to negative with 50% chance. 497 | level = _randomly_negate_tensor(level) 498 | return (level,) 499 | 500 | 501 | def _translate_level_to_arg(level, translate_const): 502 | level = (level / _MAX_LEVEL) * float(translate_const) 503 | # Flip level to negative with 50% chance. 504 | level = _randomly_negate_tensor(level) 505 | return (level,) 506 | 507 | 508 | def level_to_arg(hparams): 509 | return { 510 | 'AutoContrast': lambda level: (), 511 | 'Equalize': lambda level: (), 512 | 'Invert': lambda level: (), 513 | 'Rotate': _rotate_level_to_arg, 514 | 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), 515 | 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), 516 | 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), 517 | 'Color': _enhance_level_to_arg, 518 | 'Contrast': _enhance_level_to_arg, 519 | 'Brightness': _enhance_level_to_arg, 520 | 'Sharpness': _enhance_level_to_arg, 521 | 'ShearX': _shear_level_to_arg, 522 | 'ShearY': _shear_level_to_arg, 523 | 'Cutout': lambda level: (int((level / _MAX_LEVEL) * hparams.cutout_const),), 524 | # pylint:disable=g-long-lambda 525 | 'TranslateX': lambda level: _translate_level_to_arg( 526 | level, hparams.translate_const), 527 | 'TranslateY': lambda level: _translate_level_to_arg( 528 | level, hparams.translate_const), 529 | # pylint:enable=g-long-lambda 530 | } 531 | 532 | 533 | def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): 534 | """Return the function that corresponds to `name` and update `level` param.""" 535 | func = NAME_TO_FUNC[name] 536 | args = level_to_arg(augmentation_hparams)[name](level) 537 | 538 | # Check to see if prob is passed into function. This is used for operations 539 | # where we alter bboxes independently. 
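# NOTE: inspect.getargspec is deprecated in Python 3 (and removed in 3.11);
# inspect.getfullargspec(func).args is the modern equivalent of this lookup.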
540 | # pytype:disable=wrong-arg-types
541 | if 'prob' in inspect.getargspec(func)[0]:
542 | args = tuple([prob] + list(args))
543 | # pytype:enable=wrong-arg-types
544 | 
545 | # Add in replace arg if it is required for the function that is being called.
546 | if 'replace' in inspect.getargspec(func)[0]:
547 | # Make sure replace is the final argument
548 | assert 'replace' == inspect.getargspec(func)[0][-1]
549 | args = tuple(list(args) + [replace_value])
550 | 
551 | return (func, prob, args)
552 | 
553 | 
554 | def _apply_func_with_prob(func, image, args, prob):
555 | """Apply `func` to image w/ `args` as input with probability `prob`."""
556 | assert isinstance(args, tuple)
557 | 
558 | # If prob is a function argument, then this randomness is being handled
559 | # inside the function, so make sure it is always called.
560 | if 'prob' in inspect.getargspec(func)[0]:
561 | prob = 1.0
562 | 
563 | # Apply the function with probability `prob`.
564 | should_apply_op = tf.cast(
565 | tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool)
566 | augmented_image = tf.cond(
567 | should_apply_op,
568 | lambda: func(image, *args),
569 | lambda: image)
570 | return augmented_image
571 | 
572 | 
573 | def select_and_apply_random_policy(policies, image):
574 | """Select a random policy from `policies` and apply it to `image`."""
575 | policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
576 | # Note that using tf.case instead of tf.conds would result in significantly
577 | # larger graphs and would even break export for some larger policies.
578 | for (i, policy) in enumerate(policies):
579 | image = tf.cond(
580 | tf.equal(i, policy_to_select),
581 | lambda selected_policy=policy: selected_policy(image),
582 | lambda: image)
583 | return image
584 | 
585 | 
586 | def build_and_apply_nas_policy(policies, image,
587 | augmentation_hparams):
588 | """Build a policy from the given policies passed in and apply to image.
589 | 
590 | Args:
591 | policies: list of lists of tuples in the form `(func, prob, level)`, `func`
592 | is a string name of the augmentation function, `prob` is the probability
593 | of applying the `func` operation, `level` is the input argument for
594 | `func`.
595 | image: tf.Tensor that the resulting policy will be applied to.
596 | augmentation_hparams: Hparams associated with the NAS learned policy.
597 | 
598 | Returns:
599 | A version of image that now has data augmentation applied to it based on
600 | the `policies` passed into the function.
601 | """
602 | replace_value = [128, 128, 128]
603 | 
604 | # func is the string name of the augmentation function, prob is the
605 | # probability of applying the operation and level is the parameter associated
606 | # with the tf op.
607 | 
608 | # tf_policies are functions that take in an image and return an augmented
609 | # image.
610 | tf_policies = []
611 | for policy in policies:
612 | tf_policy = []
613 | # Link string name to the correct python function and make sure the correct
614 | # argument is passed into that function.
615 | for policy_info in policy:
616 | policy_info = list(policy_info) + [replace_value, augmentation_hparams]
617 | 
618 | tf_policy.append(_parse_policy_info(*policy_info))
619 | 
620 | # Now build the tf policy that will apply the augmentation procedure
621 | # on image.
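# make_final_policy is a factory so that each final_policy closes over its
# own tf_policy_ list; binding through a function argument avoids Python's
# late-binding closure pitfall (the same reason the lambdas in
# select_and_apply_random_policy use the selected_policy=policy default).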
622 | def make_final_policy(tf_policy_):
623 | def final_policy(image_):
624 | for func, prob, args in tf_policy_:
625 | image_ = _apply_func_with_prob(
626 | func, image_, args, prob)
627 | return image_
628 | 
629 | return final_policy
630 | 
631 | tf_policies.append(make_final_policy(tf_policy))
632 | 
633 | augmented_image = select_and_apply_random_policy(
634 | tf_policies, image)
635 | return augmented_image
636 | 
637 | 
638 | def distort_image_with_autoaugment(image, augmentation_name):
639 | """Applies the AutoAugment policy to `image`.
640 | 
641 | AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
642 | 
643 | Args:
644 | image: `Tensor` of shape [height, width, 3] representing an image.
645 | augmentation_name: The name of the AutoAugment policy to use. The available
646 | options are `v0` and `test`. `v0` is the policy used for
647 | all of the results in the paper and was found to achieve the best results
648 | on the COCO dataset. `test` is a single-sub-policy option intended for
649 | debugging. The additional `v1`, `v2` and `v3` policies from the original
650 | AutoAugment release are not included in this implementation, so only `v0`
651 | and `test` can be selected here.
652 | 
653 | Returns:
654 | The augmented version of `image`.
655 | """
656 | available_policies = {'v0': policy_v0,
657 | 'test': policy_vtest}
658 | if augmentation_name not in available_policies:
659 | raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name))
660 | 
661 | policy = available_policies[augmentation_name]()
662 | # Hparams that will be used for AutoAugment.
663 | augmentation_hparams = HParams(cutout_const=100, translate_const=250)
664 | 
665 | return build_and_apply_nas_policy(policy, image, augmentation_hparams)
666 | 
667 | 
668 | class HParams(object):
669 | def __init__(self, **kwargs):
670 | self._hparam_types = {}
671 | for name, value in six.iteritems(kwargs):
672 | self.add_hparam(name, value)
673 | 
674 | def add_hparam(self, name, value):
675 | """Adds {name, value} pair to hyperparameters.
676 | Args:
677 | name: Name of the hyperparameter.
678 | value: Value of the hyperparameter. Can be one of the following types:
679 | int, float, string, int list, float list, or string list.
680 | Raises:
681 | ValueError: if one of the arguments is invalid.
682 | """
683 | # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
684 | # attribute of this object. In that case we refuse to use it as a
685 | # hyperparameter name.
686 | if getattr(self, name, None) is not None:
687 | raise ValueError('Hyperparameter name is reserved: %s' % name)
688 | if isinstance(value, (list, tuple)):
689 | if not value:
690 | raise ValueError(
691 | 'Multi-valued hyperparameters cannot be empty: %s' % name)
692 | self._hparam_types[name] = (type(value[0]), True)
693 | else:
694 | self._hparam_types[name] = (type(value), False)
695 | setattr(self, name, value)
696 | 
697 | 
698 | def distort_image_with_randaugment(image, num_layers, magnitude):
699 | """Applies the RandAugment policy to `image`.
700 | 
701 | RandAugment is from the paper https://arxiv.org/abs/1909.13719.
702 | 
703 | Args:
704 | image: `Tensor` of shape [height, width, 3] representing an image.
705 | num_layers: Integer, the number of augmentation transformations to apply
706 | sequentially to an image. Represented as (N) in the paper. Usually best
707 | values will be in the range [1, 3].
708 | magnitude: Integer, shared magnitude across all augmentation operations.
709 | Represented as (M) in the paper. Usually best values are in the range 710 | [5, 30]. 711 | 712 | Returns: 713 | The augmented version of `image`. 714 | """ 715 | replace_value = [128] * 3 716 | tf.logging.info('Using RandAug.') 717 | 718 | augmentation_hparams = HParams(cutout_const=40//7, translate_const=100//7) 719 | 720 | available_ops = [ 721 | 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 722 | 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', 723 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'] 724 | 725 | for layer_num in range(num_layers): 726 | op_to_select = tf.random_uniform( 727 | [], maxval=len(available_ops), dtype=tf.int32) 728 | random_magnitude = float(magnitude) 729 | with tf.name_scope('randaug_layer_{}'.format(layer_num)): 730 | for (i, op_name) in enumerate(available_ops): 731 | prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32) 732 | func, _, args = _parse_policy_info(op_name, prob, random_magnitude, 733 | replace_value, augmentation_hparams) 734 | image = tf.cond( 735 | tf.equal(i, op_to_select), 736 | # pylint:disable=g-long-lambda 737 | lambda selected_func=func, selected_args=args: selected_func( 738 | image, *selected_args), 739 | # pylint:enable=g-long-lambda 740 | lambda: image) 741 | return image 742 | -------------------------------------------------------------------------------- /augmentation/gaussian_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | class GaussianBlur(object): 6 | # Implements Gaussian blur as described in the SimCLR paper 7 | def __init__(self, kernel_size, min=0.1, max=2.0): 8 | self.min = min 9 | self.max = max 10 | # kernel size is set to be 10% of the image height/width 11 | self.kernel_size = kernel_size 12 | 13 | def __call__(self, sample): 14 | sample = np.array(sample) 15 | 16 | # blur the image with a 50% chance 17 | prob = np.random.random_sample() 18 | 19 | if prob < 0.5: 20 | sigma = (self.max - self.min) * np.random.random_sample() + self.min 21 | sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) 22 | 23 | return sample 24 | -------------------------------------------------------------------------------- /augmentation/simclr_augment.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def color_distortion(image, s=1.0): 5 | # image is a tensor with value range in [0, 1]. 6 | # s is the strength of color distortion. 7 | 8 | def color_jitter(x): 9 | # one can also shuffle the order of following augmentations 10 | # each time they are applied. 11 | x = tf.image.random_brightness(x, max_delta=0.8 * s) 12 | x = tf.image.random_contrast(x, lower=1 - 0.8 * s, upper=1 + 0.8 * s) 13 | x = tf.image.random_saturation(x, lower=1 - 0.8 * s, upper=1 + 0.8 * s) 14 | x = tf.image.random_hue(x, max_delta=0.2 * s) 15 | x = tf.clip_by_value(x, 0, 1) 16 | return x 17 | 18 | def color_drop(x): 19 | x = tf.image.rgb_to_grayscale(x) 20 | x = tf.tile(x, [1, 1, 3]) 21 | return x 22 | 23 | rand_ = tf.random.uniform(shape=(), minval=0, maxval=1) 24 | # randomly apply transformation with probability p. 
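# (AutoGraph rewrites these data-dependent Python `if`s into tf.cond when
# this function is traced, e.g. inside tf.data.Dataset.map; in plain eager
# mode they execute directly.)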
25 | if rand_ < 0.8: 26 | image = color_jitter(image) 27 | 28 | rand_ = tf.random.uniform(shape=(), minval=0, maxval=1) 29 | if rand_ < 0.2: 30 | image = color_drop(image) 31 | return image 32 | -------------------------------------------------------------------------------- /augmentation/transforms.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from augmentation.autoaugment import distort_image_with_randaugment 3 | from augmentation.simclr_augment import color_distortion 4 | 5 | 6 | def read_images(features): 7 | return features['image'] 8 | 9 | 10 | def distort_simclr(image): 11 | image = tf.cast(image, tf.float32) 12 | v1 = color_distortion(image / 255.) 13 | v2 = color_distortion(image / 255.) 14 | return v1, v2 15 | 16 | 17 | def read_record(record, input_shape): 18 | keys_to_features = { 19 | "image_raw": tf.io.FixedLenFeature((), tf.string, default_value=""), 20 | } 21 | 22 | features = tf.io.parse_single_example(record, keys_to_features) 23 | 24 | image = tf.io.decode_raw(features['image_raw'], tf.uint8) 25 | 26 | # reshape input and annotation images 27 | image = tf.reshape(image, input_shape, name="image_reshape") 28 | return image 29 | 30 | 31 | def distort_with_rand_aug(image): 32 | image = tf.cast(image, dtype=tf.uint8) 33 | v1 = distort_image_with_randaugment(image, num_layers=3, magnitude=5) 34 | v2 = distort_image_with_randaugment(image, num_layers=3, magnitude=5) 35 | return tf.cast(v1, tf.float32) / 255., tf.cast(v2, tf.float32) / 255. 36 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 128 2 | out_dim: 256 3 | s: 1 4 | temperature: 0.5 5 | use_cosine_similarity: True 6 | epochs: 30 7 | input_shape: (96,96,3) -------------------------------------------------------------------------------- /data/create_tf_records.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import numpy as np\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 4, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "def read_stl10_dataset(path_to_data):\n", 21 | " \"\"\"\n", 22 | " :param path_to_data: the file containing the binary images from the STL-10 dataset\n", 23 | " :return: an array containing all the images\n", 24 | " \"\"\"\n", 25 | "\n", 26 | " with open(path_to_data, 'rb') as f:\n", 27 | " # read whole file in uint8 chunks\n", 28 | " everything = np.fromfile(f, dtype=np.uint8)\n", 29 | "\n", 30 | " # We force the data into 3x96x96 chunks, since the\n", 31 | " # images are stored in \"column-major order\", meaning\n", 32 | " # that \"the first 96*96 values are the red channel,\n", 33 | " # the next 96*96 are green, and the last are blue.\"\n", 34 | " # The -1 is since the size of the pictures depends\n", 35 | " # on the input file, and this way numpy determines\n", 36 | " # the size on its own.\n", 37 | "\n", 38 | " images = np.reshape(everything, (-1, 3, 96, 96))\n", 39 | "\n", 40 | " # Now transpose the images into a standard image format\n", 41 | " # readable by, for example, matplotlib.imshow\n", 42 | " # You might want to comment this line or reverse the shuffle\n", 43 | " # if you will use a learning algorithm like 
CNN, since they like\n", 44 | " # their channels separated.\n", 45 | " images = np.transpose(images, (0, 3, 2, 1))\n", 46 | " return images" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 5, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "data_path = '/home/thalles/Downloads/stl10_binary/unlabeled_X.bin'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 6, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "images = read_stl10_dataset(data_path)\n", 65 | "print(images.shape)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 7, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "TRAIN_DATASET_DIR=\"./tfrecords/\"\n", 75 | "if not os.path.exists(TRAIN_DATASET_DIR):\n", 76 | " os.mkdir(TRAIN_DATASET_DIR)\n", 77 | " \n", 78 | "TRAIN_FILE = 'train.tfrecords'\n", 79 | "writer = tf.io.TFRecordWriter(os.path.join(TRAIN_DATASET_DIR,TRAIN_FILE))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 8, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "def _bytes_feature(value):\n", 89 | " return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n", 90 | "\n", 91 | "def _int64_feature(value):\n", 92 | " return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 9, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "def create_tfrecord_dataset(images, writer):\n", 102 | "\n", 103 | " # create training tfrecord\n", 104 | " read_imgs_counter = 0\n", 105 | " for i, image in enumerate(images):\n", 106 | " \n", 107 | " read_imgs_counter += 1\n", 108 | " image_h = image.shape[0]\n", 109 | " image_w = image.shape[1]\n", 110 | "\n", 111 | " img_raw = image.tostring()\n", 112 | "\n", 113 | " example = tf.train.Example(features=tf.train.Features(feature={\n", 114 | " 'height': _int64_feature(image_h),\n", 115 | " 'width': _int64_feature(image_w),\n", 116 | " 'image_raw': _bytes_feature(img_raw)}))\n", 117 | "\n", 118 | " writer.write(example.SerializeToString())\n", 119 | " \n", 120 | " print(\"End of TfRecord. Total of image written:\", read_imgs_counter)\n", 121 | " writer.close()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 10, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "End of TfRecord. 
Total of image written: 100000\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "create_tfrecord_dataset(images, writer)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "tf2_cpu", 152 | "language": "python", 153 | "name": "tf2_cpu" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.6.10" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 2 170 | } 171 | -------------------------------------------------------------------------------- /feature_eval/contrastive_feature_eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.insert(1, '../')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import tensorflow as tf\n", 20 | "import numpy as np\n", 21 | "from sklearn.model_selection import train_test_split\n", 22 | "import os\n", 23 | "import yaml\n", 24 | "from pydoc import locate" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "def get_log_regression():\n", 34 | " return LogisticRegression(random_state=0, max_iter=1000, solver='lbfgs', C=1.0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "## define model id\n", 44 | "model_id = '20200309-230336'\n", 45 | "model_path = os.path.join('../logs', model_id, 'train', 'checkpoints')" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 5, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "def _load_stl10(prefix=\"train\"):\n", 55 | " X_train = np.fromfile('../data/stl10_binary/' + prefix + '_X.bin', dtype=np.uint8)\n", 56 | " y_train = np.fromfile('../data/stl10_binary/' + prefix + '_y.bin', dtype=np.uint8)\n", 57 | "\n", 58 | " X_train = np.reshape(X_train, (-1, 3, 96, 96))\n", 59 | " X_train = np.transpose(X_train, (0, 3, 2, 1))\n", 60 | " print(\"{} images\".format(prefix))\n", 61 | " print(X_train.shape)\n", 62 | " print(y_train.shape)\n", 63 | " y_train -= 1\n", 64 | " return X_train, y_train" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 6, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "train images\n", 77 | "(5000, 96, 96, 3)\n", 78 | "(5000,)\n", 79 | "test images\n", 80 | "(8000, 96, 96, 3)\n", 81 | "(8000,)\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "# load STL-10 train data\n", 87 | "X_train, y_train = _load_stl10(\"train\")\n", 88 | "\n", 89 | "# load STL-10 test data\n", 90 | "X_test, y_test = _load_stl10(\"test\")" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 7, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "batch_size=256" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 8, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | 
"name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "Training data\n", 112 | "(5000, 96, 96, 3)\n", 113 | "(5000,)\n", 114 | "Testing data\n", 115 | "(8000, 96, 96, 3)\n", 116 | "(8000,)\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "print(\"Training data\")\n", 122 | "print(X_train.shape)\n", 123 | "print(y_train.shape)\n", 124 | "print(\"Testing data\")\n", 125 | "print(X_test.shape)\n", 126 | "print(y_test.shape)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 9, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "0 255\n", 139 | "0 255\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "print(np.min(X_train), np.max(X_train))\n", 145 | "print(np.min(X_test), np.max(X_test))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 10, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "[1 5 1 6 3] [6 7 5 0 3]\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "print(y_train[:5], y_test[:5])\n", 163 | "# [5 1 7 2 1] [7 9 3 8 0]" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "## Test protocol #1 PCA features" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 11, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "from sklearn.decomposition import PCA\n", 180 | "from sklearn.linear_model import LogisticRegression\n", 181 | "from sklearn import preprocessing" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 12, 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "name": "stdout", 191 | "output_type": "stream", 192 | "text": [ 193 | "PCA features\n", 194 | "(5000, 128)\n", 195 | "(8000, 128)\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "scaler = preprocessing.StandardScaler()\n", 201 | "scaler.fit(X_train.reshape((X_train.shape[0], -1)))\n", 202 | "\n", 203 | "pca = PCA(n_components=128)\n", 204 | "\n", 205 | "X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n", 206 | "X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n", 207 | "\n", 208 | "print(\"PCA features\")\n", 209 | "print(X_train_pca.shape)\n", 210 | "print(X_test_pca.shape)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 13, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "PCA feature evaluation\n", 223 | "Train score: 0.429\n", 224 | "Test score: 0.3615\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "clf = get_log_regression()\n", 230 | "clf.fit(X_train_pca, y_train)\n", 231 | "print(\"PCA feature evaluation\")\n", 232 | "print(\"Train score:\", clf.score(X_train_pca, y_train))\n", 233 | "print(\"Test score:\", clf.score(X_test_pca, y_test))\n", 234 | "\n", 235 | "# PCA feature evaluation\n", 236 | "# Train score: 0.4592\n", 237 | "# Test score: 0.3632" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 14, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "data": { 247 | "text/plain": [ 248 | "'../logs/20200309-230336/train/checkpoints/baseline_encoder.py'" 249 | ] 250 | }, 251 | "execution_count": 14, 252 | "metadata": {}, 253 | "output_type": "execute_result" 254 | } 255 | ], 256 | "source": [ 257 | "os.path.join(model_path, \"baseline_encoder.py\")" 258 | ] 
259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 15, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "from models.cnn_small import SmallCNN\n", 267 | "from models.resnet_simclr import ResNetSimCLR\n", 268 | "model = ResNetSimCLR(input_shape=(96,96,3), out_dim=256)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 16, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "model(tf.ones((1,96,96,3)))\n", 278 | "model.load_weights(os.path.join(model_path, 'model.h5'))" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 17, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "def next_batch(X, y, batch_size):\n", 288 | " for i in range(0, X.shape[0], batch_size):\n", 289 | " X_batch = X[i: i+batch_size] / 255.\n", 290 | " y_batch = y[i: i+batch_size]\n", 291 | " yield X_batch.astype(np.float32), y_batch" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "## Protocol #2 Linear separability evaluation" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 18, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "Train features\n", 311 | "(5000, 2048)\n" 312 | ] 313 | } 314 | ], 315 | "source": [ 316 | "X_train_feature = []\n", 317 | "\n", 318 | "for batch_x, batch_y in next_batch(X_train, y_train, batch_size):\n", 319 | " features, _ = model(batch_x)\n", 320 | " X_train_feature.extend(features.numpy())\n", 321 | " \n", 322 | "X_train_feature = np.array(X_train_feature)\n", 323 | "\n", 324 | "print(\"Train features\")\n", 325 | "print(X_train_feature.shape)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 19, 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | "Test features\n", 338 | "(8000, 2048)\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "X_test_feature = []\n", 344 | "\n", 345 | "for batch_x, batch_y in next_batch(X_test, y_test, batch_size):\n", 346 | " features, _ = model(batch_x)\n", 347 | " X_test_feature.extend(features.numpy())\n", 348 | " \n", 349 | "X_test_feature = np.array(X_test_feature)\n", 350 | "\n", 351 | "print(\"Test features\")\n", 352 | "print(X_test_feature.shape)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 20, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "data": { 362 | "text/plain": [ 363 | "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", 364 | " intercept_scaling=1, l1_ratio=None, max_iter=1000,\n", 365 | " multi_class='auto', n_jobs=None, penalty='l2',\n", 366 | " random_state=0, solver='lbfgs', tol=0.0001, verbose=0,\n", 367 | " warm_start=False)" 368 | ] 369 | }, 370 | "execution_count": 20, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "clf = get_log_regression()\n", 377 | "\n", 378 | "scaler = preprocessing.StandardScaler()\n", 379 | "scaler.fit(X_train_feature)\n", 380 | "\n", 381 | "clf.fit(scaler.transform(X_train_feature), y_train)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 21, 387 | "metadata": {}, 388 | "outputs": [ 389 | { 390 | "name": "stdout", 391 | "output_type": "stream", 392 | "text": [ 393 | "SimCLR feature evaluation\n", 394 | "Train score: 1.0\n", 395 | "Test score: 0.5895\n" 396 | ] 397 | } 398 | ], 399 | 
"source": [ 400 | "print(\"SimCLR feature evaluation\")\n", 401 | "print(\"Train score:\", clf.score(scaler.transform(X_train_feature), y_train))\n", 402 | "print(\"Test score:\", clf.score(scaler.transform(X_test_feature), y_test))\n", 403 | "\n", 404 | "# SimCLR feature evaluation\n", 405 | "# Train score: 0.5946\n", 406 | "# Test score: 0.5202" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 22, 412 | "metadata": {}, 413 | "outputs": [ 414 | { 415 | "data": { 416 | "text/plain": [ 417 | "array([1, 5, 1, 6, 3, 9, 7, 4, 5, 8, 0, 6, 0, 8, 7], dtype=uint8)" 418 | ] 419 | }, 420 | "execution_count": 22, 421 | "metadata": {}, 422 | "output_type": "execute_result" 423 | } 424 | ], 425 | "source": [ 426 | "y_train[:15]" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 23, 432 | "metadata": {}, 433 | "outputs": [ 434 | { 435 | "data": { 436 | "text/plain": [ 437 | "array([1, 5, 1, 6, 3, 9, 7, 4, 5, 8, 0, 6, 0, 8, 7], dtype=uint8)" 438 | ] 439 | }, 440 | "execution_count": 23, 441 | "metadata": {}, 442 | "output_type": "execute_result" 443 | } 444 | ], 445 | "source": [ 446 | "clf.predict(scaler.transform(X_train_feature))[:15]" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 24, 452 | "metadata": {}, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/plain": [ 457 | "array([6, 7, 5, 0, 3, 1, 1, 1, 4, 4, 0, 0, 4, 0, 1], dtype=uint8)" 458 | ] 459 | }, 460 | "execution_count": 24, 461 | "metadata": {}, 462 | "output_type": "execute_result" 463 | } 464 | ], 465 | "source": [ 466 | "y_test[:15]" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 25, 472 | "metadata": {}, 473 | "outputs": [ 474 | { 475 | "data": { 476 | "text/plain": [ 477 | "array([6, 7, 3, 0, 6, 6, 1, 1, 4, 0, 8, 0, 7, 0, 7], dtype=uint8)" 478 | ] 479 | }, 480 | "execution_count": 25, 481 | "metadata": {}, 482 | "output_type": "execute_result" 483 | } 484 | ], 485 | "source": [ 486 | "clf.predict(scaler.transform(X_test_feature))[:15]" 487 | ] 488 | } 489 | ], 490 | "metadata": { 491 | "kernelspec": { 492 | "display_name": "tf2_cpu", 493 | "language": "python", 494 | "name": "tf2_cpu" 495 | }, 496 | "language_info": { 497 | "codemirror_mode": { 498 | "name": "ipython", 499 | "version": 3 500 | }, 501 | "file_extension": ".py", 502 | "mimetype": "text/x-python", 503 | "name": "python", 504 | "nbconvert_exporter": "python", 505 | "pygments_lexer": "ipython3", 506 | "version": "3.6.10" 507 | } 508 | }, 509 | "nbformat": 4, 510 | "nbformat_minor": 2 511 | } 512 | -------------------------------------------------------------------------------- /models/cnn_small.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class SmallCNN(tf.keras.Model): 5 | def __init__(self, out_dim): 6 | super(SmallCNN, self).__init__() 7 | self.conv1 = tf.keras.layers.Conv2D(filters=8, kernel_size=3, padding='same', strides=1) 8 | self.conv2 = tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='same', strides=1) 9 | self.conv3 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', strides=1) 10 | self.conv4 = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', strides=1) 11 | 12 | self.l1 = tf.keras.layers.Dense(units=out_dim) 13 | self.l2 = tf.keras.layers.Dense(units=out_dim) 14 | 15 | self.activation = tf.keras.layers.Activation('relu') 16 | self.max_pool = tf.keras.layers.MaxPool2D(pool_size=2, strides=2) 17 | 18 | self.global_pool = 
tf.keras.layers.GlobalAveragePooling2D() 19 | 20 | def call(self, x): 21 | x = self.conv1(x) 22 | x = self.activation(x) 23 | x = self.max_pool(x) 24 | 25 | x = self.conv2(x) 26 | x = self.activation(x) 27 | x = self.max_pool(x) 28 | 29 | x = self.conv3(x) 30 | x = self.activation(x) 31 | x = self.max_pool(x) 32 | 33 | x = self.conv4(x) 34 | x = self.activation(x) 35 | x = self.max_pool(x) 36 | 37 | h = self.global_pool(x) 38 | 39 | x = self.l1(h) 40 | x = self.activation(x) 41 | x = self.l2(x) 42 | 43 | return h, x 44 | -------------------------------------------------------------------------------- /models/resnet_simclr.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def ResNetSimCLR(input_shape, out_dim, base_model='resnet18'): 5 | inputs = tf.keras.layers.Input(shape=(input_shape)) 6 | 7 | base_encoder = tf.keras.applications.ResNet50(include_top=False, weights=None, input_tensor=None, 8 | input_shape=None, pooling='avg') 9 | base_encoder.training = True 10 | h = base_encoder(inputs) 11 | 12 | # projection head 13 | x = tf.keras.layers.Dense(units=out_dim)(h) 14 | x = tf.keras.layers.Activation('relu')(x) 15 | x = tf.keras.layers.Dense(units=out_dim)(x) 16 | 17 | return tf.keras.Model(inputs=inputs, outputs=[h, x]) 18 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import datetime 3 | import os 4 | import yaml 5 | import time 6 | import shutil 7 | 8 | print(tf.__version__) 9 | from models.cnn_small import SmallCNN 10 | from models.resnet_simclr import ResNetSimCLR 11 | from utils.losses import _dot_simililarity_dim1 as sim_func_dim1, _dot_simililarity_dim2 as sim_func_dim2 12 | from utils.helpers import get_negative_mask, gaussian_filter 13 | from augmentation.transforms import read_images, distort_simclr, read_record, distort_with_rand_aug 14 | 15 | from tensorflow.compat.v1 import ConfigProto 16 | from tensorflow.compat.v1 import InteractiveSession 17 | 18 | config = ConfigProto() 19 | config.gpu_options.allow_growth = True 20 | session = InteractiveSession(config=config) 21 | 22 | config = yaml.load(open("./config.yaml", "r"), Loader=yaml.FullLoader) 23 | input_shape = eval(config['input_shape']) 24 | 25 | train_dataset = tf.data.TFRecordDataset('./data/tfrecords/train.tfrecords') 26 | train_dataset = train_dataset.map(lambda x: read_record(x, input_shape), 27 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 28 | train_dataset = train_dataset.map(distort_simclr, num_parallel_calls=tf.data.experimental.AUTOTUNE) 29 | train_dataset = train_dataset.map(gaussian_filter, num_parallel_calls=tf.data.experimental.AUTOTUNE) 30 | train_dataset = train_dataset.repeat(config['epochs']) 31 | train_dataset = train_dataset.shuffle(4096) 32 | train_dataset = train_dataset.batch(config['batch_size'], drop_remainder=True) 33 | train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE) 34 | 35 | criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM) 36 | optimizer = tf.keras.optimizers.Adam(3e-4) 37 | 38 | # model = SmallCNN(out_dim=config['out_dim']) 39 | model = ResNetSimCLR(input_shape=input_shape, out_dim=config['out_dim']) 40 | 41 | # Mask to remove positive examples from the batch of negative samples 42 | negative_mask = get_negative_mask(config['batch_size']) 43 | 44 | current_time = 
datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 45 | train_log_dir = os.path.join('logs', current_time, 'train') 46 | train_summary_writer = tf.summary.create_file_writer(train_log_dir) 47 | 48 | 49 | @tf.function 50 | def train_step(xis, xjs): 51 | with tf.GradientTape() as tape: 52 | ris, zis = model(xis) 53 | rjs, zjs = model(xjs) 54 | 55 | # normalize projection feature vectors 56 | zis = tf.math.l2_normalize(zis, axis=1) 57 | zjs = tf.math.l2_normalize(zjs, axis=1) 58 | 59 | # tf.summary.histogram('zis', zis, step=optimizer.iterations) 60 | # tf.summary.histogram('zjs', zjs, step=optimizer.iterations) 61 | 62 | l_pos = sim_func_dim1(zis, zjs) 63 | l_pos = tf.reshape(l_pos, (config['batch_size'], 1)) 64 | l_pos /= config['temperature'] 65 | # assert l_pos.shape == (config['batch_size'], 1), "l_pos shape not valid" + str(l_pos.shape) # [N,1] 66 | 67 | negatives = tf.concat([zjs, zis], axis=0) 68 | 69 | loss = 0 70 | 71 | for positives in [zis, zjs]: 72 | l_neg = sim_func_dim2(positives, negatives) 73 | 74 | labels = tf.zeros(config['batch_size'], dtype=tf.int32) 75 | 76 | l_neg = tf.boolean_mask(l_neg, negative_mask) 77 | l_neg = tf.reshape(l_neg, (config['batch_size'], -1)) 78 | l_neg /= config['temperature'] 79 | 80 | # assert l_neg.shape == ( 81 | # config['batch_size'], 2 * (config['batch_size'] - 1)), "Shape of negatives not expected." + str( 82 | # l_neg.shape) 83 | logits = tf.concat([l_pos, l_neg], axis=1) # [N,K+1] 84 | loss += criterion(y_pred=logits, y_true=labels) 85 | 86 | loss = loss / (2 * config['batch_size']) 87 | tf.summary.scalar('loss', loss, step=optimizer.iterations) 88 | 89 | gradients = tape.gradient(loss, model.trainable_variables) 90 | optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 91 | 92 | 93 | with train_summary_writer.as_default(): 94 | for xis, xjs in train_dataset: 95 | # print(tf.reduce_min(xis), tf.reduce_max(xjs)) 96 | # fig, axs = plt.subplots(nrows=2, ncols=2, constrained_layout=False) 97 | # axs[0, 0].imshow(xis[0]) 98 | # axs[0, 1].imshow(xis[1]) 99 | # axs[1, 0].imshow(xis[2]) 100 | # axs[1, 1].imshow(xis[3]) 101 | # plt.show() 102 | # start = time.time() 103 | train_step(xis, xjs) 104 | # end = time.time() 105 | # print("Total time per batch:", end - start) 106 | 107 | model_checkpoints_folder = os.path.join(train_log_dir, 'checkpoints') 108 | if not os.path.exists(model_checkpoints_folder): 109 | os.makedirs(model_checkpoints_folder) 110 | shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml')) 111 | 112 | model.save_weights(os.path.join(model_checkpoints_folder, 'model.h5')) 113 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from augmentation.gaussian_filter import GaussianBlur 4 | 5 | 6 | def get_negative_mask(batch_size): 7 | # return a mask that removes the similarity score of equal/similar images. 
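# For example, with batch_size=2 and the 2N candidates ordered
# [zjs_0, zjs_1, zis_0, zis_1] (as in train.py), the mask is
# [[0, 1, 0, 1],
#  [1, 0, 1, 0]]: row i zeroes out columns i and i + batch_size.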
8 | # this function ensures that only distinct pairs of images get their similarity scores
9 | # passed as negative examples
10 | negative_mask = np.ones((batch_size, 2 * batch_size), dtype=bool)
11 | for i in range(batch_size):
12 | negative_mask[i, i] = 0
13 | negative_mask[i, i + batch_size] = 0
14 | return tf.constant(negative_mask)
15 | 
16 | 
17 | def gaussian_filter(v1, v2):
18 | k_size = int(v1.shape[1] * 0.1) # kernel size is set to be 10% of the image height/width
19 | gaussian_ope = GaussianBlur(kernel_size=k_size, min=0.1, max=2.0)
20 | [v1, ] = tf.py_function(gaussian_ope, [v1], [tf.float32])
21 | [v2, ] = tf.py_function(gaussian_ope, [v2], [tf.float32])
22 | return v1, v2
23 | 
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | cosine_sim_1d = tf.keras.losses.CosineSimilarity(axis=1, reduction=tf.keras.losses.Reduction.NONE)
4 | cosine_sim_2d = tf.keras.losses.CosineSimilarity(axis=2, reduction=tf.keras.losses.Reduction.NONE)
5 | 
6 | 
7 | def _cosine_simililarity_dim1(x, y):
8 | v = cosine_sim_1d(x, y)
9 | return v
10 | 
11 | 
12 | def _cosine_simililarity_dim2(x, y):
13 | # x shape: (N, 1, C)
14 | # y shape: (1, 2N, C)
15 | # v shape: (N, 2N)
16 | v = cosine_sim_2d(tf.expand_dims(x, 1), tf.expand_dims(y, 0))
17 | return v
18 | 
19 | 
20 | def _dot_simililarity_dim1(x, y):
21 | # x shape: (N, 1, C)
22 | # y shape: (N, C, 1)
23 | # v shape: (N, 1, 1)
24 | v = tf.matmul(tf.expand_dims(x, 1), tf.expand_dims(y, 2))
25 | return v
26 | 
27 | 
28 | def _dot_simililarity_dim2(x, y):
29 | # x shape: (N, 1, C)
30 | # y shape: (1, C, 2N)
31 | # v shape: (N, 2N)
32 | v = tf.tensordot(tf.expand_dims(x, 1), tf.expand_dims(tf.transpose(y), 0), axes=2)
33 | return v
34 | 
--------------------------------------------------------------------------------
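For a quick end-to-end sanity check of how these pieces fit together, here is a minimal sketch (not part of the repository; run it from the repository root with a toy batch size) that mirrors the logits construction in `train.py`'s `train_step`:

```python
import tensorflow as tf

from models.resnet_simclr import ResNetSimCLR
from utils.helpers import get_negative_mask
from utils.losses import _dot_simililarity_dim1 as sim_dim1, _dot_simililarity_dim2 as sim_dim2

BATCH_SIZE = 4    # toy value; train.py reads this from config.yaml
TEMPERATURE = 0.5
OUT_DIM = 64

# NOTE: the base_model argument is currently ignored by ResNetSimCLR --
# a ResNet-50 backbone is always built (tf.keras.applications has no ResNet-18).
model = ResNetSimCLR(input_shape=(96, 96, 3), out_dim=OUT_DIM)

# Two random "views" standing in for the augmented pairs produced by
# distort_simclr + gaussian_filter.
xis = tf.random.uniform((BATCH_SIZE, 96, 96, 3))
xjs = tf.random.uniform((BATCH_SIZE, 96, 96, 3))

_, zis = model(xis)
_, zjs = model(xjs)
zis = tf.math.l2_normalize(zis, axis=1)
zjs = tf.math.l2_normalize(zjs, axis=1)

# Positive logits: one similarity per pair -> shape (N, 1).
l_pos = tf.reshape(sim_dim1(zis, zjs), (BATCH_SIZE, 1)) / TEMPERATURE

# Negative logits: similarities against all 2N candidates, with the
# positive/self entries masked out -> shape (N, 2N - 2).
negatives = tf.concat([zjs, zis], axis=0)
negative_mask = get_negative_mask(BATCH_SIZE)
l_neg = sim_dim2(zis, negatives)
l_neg = tf.reshape(tf.boolean_mask(l_neg, negative_mask), (BATCH_SIZE, -1)) / TEMPERATURE

# Each row is a (2N - 1)-way classification problem whose correct class
# (the positive) sits in column 0 -- hence the all-zeros labels in train.py.
logits = tf.concat([l_pos, l_neg], axis=1)
print(logits.shape)  # (4, 7)
```

Feeding these logits with all-zero labels into the `SparseCategoricalCrossentropy` criterion (summed over both view orderings, then divided by 2N as in `train.py`) yields the NT-Xent loss from the SimCLR paper.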