├── .gitignore
├── .python-version
├── README.md
├── model.py
├── nets
│   ├── resnet_utils.py
│   └── resnet_v2.py
├── resources
│   ├── cpc-explanation1.png
│   └── cpc-vision.png
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .venv
3 | __pycache__
4 | 
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.6.2
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CPC-Tensorflow (Contrastive Predictive Coding)
2 | 
3 | TensorFlow implementation of **Representation Learning with Contrastive Predictive Coding**
4 | 
5 | - [Paper](https://arxiv.org/abs/1807.03748)
6 | - [My Notes](https://github.com/flrngel/understanding-ai/issues/24)
7 | 
8 | ## Features
9 | - ResNet-101 (v2) encoder
10 | 
11 | ## Explanations
12 | 
13 | ![cpc vision model](https://raw.githubusercontent.com/flrngel/cpc-tensorflow/master/resources/cpc-vision.png)
14 | ![cpc explanation](https://raw.githubusercontent.com/flrngel/cpc-tensorflow/master/resources/cpc-explanation1.png)
15 | 
16 | 
17 | ## Status
18 | 
19 | `model.py` is intended to be adaptable to various domains.
20 | 
21 | ## References
22 | - [davidtellez/contrastive-predictive-coding](https://github.com/davidtellez/contrastive-predictive-coding)
23 | - davidtellez uses a sequence generator to build sequential CPC batches
24 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | class CPC():
4 |   def __init__(self, X, X_len, Y, Y_label, n=7, code=7, k=2, code_dim=1024, cell_dimension=128):
5 |     """
6 |     Autoregressive part of CPC (context network g_ar) plus the k-step prediction losses
7 |     """
8 | 
9 |     with tf.variable_scope('CPC'):
10 |       self.X = X
11 |       self.X_len = X_len
12 |       self.Y = Y
13 |       self.Y_label = Y_label
14 |       self.batch_size = X.shape[0]
15 | 
16 | 
17 |       self.n = n
18 |       self.k = k
19 |       self.code = code
20 |       self.code_dim = code_dim
21 |       self.cell_dimension = cell_dimension
22 | 
23 |       cell = tf.contrib.rnn.GRUCell(cell_dimension, name='cell')
24 |       initial_state = cell.zero_state(self.batch_size, dtype=tf.float32)
25 | 
26 |       # Autoregressive model
27 |       with tf.variable_scope('g_ar'):
28 |         _, c_t = tf.nn.dynamic_rnn(cell, self.X, sequence_length=self.X_len, initial_state=initial_state)
29 |         self.c_t = c_t
30 | 
31 |       with tf.variable_scope('coding'):
32 |         losses = []
33 |         y = self.Y
34 |         for i in range(k):
35 |           W = tf.get_variable('x_t_'+str(i+1), shape=[cell_dimension, self.code_dim])  # bilinear prediction weights for step i+1
36 |           y_ = tf.reshape(y[:, i, :], [self.batch_size, self.code_dim])
37 |           self.cpc = tf.map_fn(lambda x: tf.squeeze(tf.transpose(W) @ tf.expand_dims(x, -1), axis=-1), c_t) * y_
38 |           nce = tf.sigmoid(tf.reduce_mean(self.cpc, -1))  # probability that Y[:, i] is the true future code
39 |           losses.append(tf.keras.losses.binary_crossentropy(self.Y_label, nce))
40 | 
41 |         losses = tf.stack(losses, axis=0)
42 | 
43 |       # Loss function
44 |       with tf.variable_scope('train'):
45 |         self.loss = tf.reduce_mean(losses)
46 | 
--------------------------------------------------------------------------------
/nets/resnet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 
82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | store_non_strided_activations=False, 128 | outputs_collections=None): 129 | """Stacks ResNet `Blocks` and controls output feature density. 130 | 131 | First, this function creates scopes for the ResNet in the form of 132 | 'block_name/unit_1', 'block_name/unit_2', etc. 133 | 134 | Second, this function allows the user to explicitly control the ResNet 135 | output_stride, which is the ratio of the input to output spatial resolution. 136 | This is useful for dense prediction tasks such as semantic segmentation or 137 | object detection. 138 | 139 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 140 | factor of 2 when transitioning between consecutive ResNet blocks. This results 141 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 142 | half the nominal network stride (e.g., output_stride=4), then we compute 143 | responses twice. 144 | 145 | Control of the output feature density is implemented by atrous convolution. 146 | 147 | Args: 148 | net: A `Tensor` of size [batch, height, width, channels]. 149 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 150 | element is a ResNet `Block` object describing the units in the `Block`. 151 | output_stride: If `None`, then the output will be computed at the nominal 152 | network stride. If output_stride is not `None`, it specifies the requested 153 | ratio of input to output spatial resolution, which needs to be equal to 154 | the product of unit strides from the start up to some level of the ResNet. 155 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 156 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 157 | is equivalent to output_stride=24). 
158 | store_non_strided_activations: If True, we compute non-strided (undecimated) 159 | activations at the last unit of each block and store them in the 160 | `outputs_collections` before subsampling them. This gives us access to 161 | higher resolution intermediate activations which are useful in some 162 | dense prediction problems but increases 4x the computation and memory cost 163 | at the last unit of each block. 164 | outputs_collections: Collection to add the ResNet block outputs. 165 | 166 | Returns: 167 | net: Output tensor with stride equal to the specified output_stride. 168 | 169 | Raises: 170 | ValueError: If the target output_stride is not valid. 171 | """ 172 | # The current_stride variable keeps track of the effective stride of the 173 | # activations. This allows us to invoke atrous convolution whenever applying 174 | # the next residual unit would result in the activations having stride larger 175 | # than the target output_stride. 176 | current_stride = 1 177 | 178 | # The atrous convolution rate parameter. 179 | rate = 1 180 | 181 | for block in blocks: 182 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 183 | block_stride = 1 184 | for i, unit in enumerate(block.args): 185 | if store_non_strided_activations and i == len(block.args) - 1: 186 | # Move stride from the block's last unit to the end of the block. 187 | block_stride = unit.get('stride', 1) 188 | unit = dict(unit, stride=1) 189 | 190 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 191 | # If we have reached the target output_stride, then we need to employ 192 | # atrous convolution with stride=1 and multiply the atrous rate by the 193 | # current unit's stride for use in subsequent layers. 194 | if output_stride is not None and current_stride == output_stride: 195 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 196 | rate *= unit.get('stride', 1) 197 | 198 | else: 199 | net = block.unit_fn(net, rate=1, **unit) 200 | current_stride *= unit.get('stride', 1) 201 | if output_stride is not None and current_stride > output_stride: 202 | raise ValueError('The target output_stride cannot be reached.') 203 | 204 | # Collect activations at the block's end before performing subsampling. 205 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 206 | 207 | # Subsampling of the block's output activations. 208 | if output_stride is not None and current_stride == output_stride: 209 | rate *= block_stride 210 | else: 211 | net = subsample(net, block_stride) 212 | current_stride *= block_stride 213 | if output_stride is not None and current_stride > output_stride: 214 | raise ValueError('The target output_stride cannot be reached.') 215 | 216 | if output_stride is not None and current_stride != output_stride: 217 | raise ValueError('The target output_stride cannot be reached.') 218 | 219 | return net 220 | 221 | 222 | def resnet_arg_scope(weight_decay=0.0001, 223 | batch_norm_decay=0.997, 224 | batch_norm_epsilon=1e-5, 225 | batch_norm_scale=True, 226 | activation_fn=tf.nn.relu, 227 | use_batch_norm=True, 228 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 229 | """Defines the default ResNet arg scope. 230 | 231 | TODO(gpapan): The batch-normalization related default values above are 232 | appropriate for use in conjunction with the reference ResNet models 233 | released at https://github.com/KaimingHe/deep-residual-networks. When 234 | training ResNets from scratch, they might need to be tuned. 
235 | 236 | Args: 237 | weight_decay: The weight decay to use for regularizing the model. 238 | batch_norm_decay: The moving average decay when estimating layer activation 239 | statistics in batch normalization. 240 | batch_norm_epsilon: Small constant to prevent division by zero when 241 | normalizing activations by their variance in batch normalization. 242 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 243 | activations in the batch normalization layer. 244 | activation_fn: The activation function which is used in ResNet. 245 | use_batch_norm: Whether or not to use batch normalization. 246 | batch_norm_updates_collections: Collection for the update ops for 247 | batch norm. 248 | 249 | Returns: 250 | An `arg_scope` to use for the resnet models. 251 | """ 252 | batch_norm_params = { 253 | 'decay': batch_norm_decay, 254 | 'epsilon': batch_norm_epsilon, 255 | 'scale': batch_norm_scale, 256 | 'updates_collections': batch_norm_updates_collections, 257 | 'fused': None, # Use fused batch norm if possible. 258 | } 259 | 260 | with slim.arg_scope( 261 | [slim.conv2d], 262 | weights_regularizer=slim.l2_regularizer(weight_decay), 263 | weights_initializer=slim.variance_scaling_initializer(), 264 | activation_fn=activation_fn, 265 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 266 | normalizer_params=batch_norm_params): 267 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 268 | # The following implies padding='SAME' for pool1, which makes feature 269 | # alignment easier for dense prediction tasks. This is also used in 270 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 271 | # code of 'Deep Residual Learning for Image Recognition' uses 272 | # padding='VALID' for pool1. You can switch to that choice by setting 273 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 274 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 275 | return arg_sc 276 | -------------------------------------------------------------------------------- /nets/resnet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the preactivation form of Residual Networks. 16 | 17 | Residual networks (ResNets) were originally proposed in: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | 21 | The full preactivation 'v2' ResNet variant implemented in this module was 22 | introduced by: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. 
arXiv: 1603.05027 25 | 26 | The key difference of the full preactivation 'v2' variant compared to the 27 | 'v1' variant in [1] is the use of batch normalization before every weight layer. 28 | 29 | Typical use: 30 | 31 | from tensorflow.contrib.slim.nets import resnet_v2 32 | 33 | ResNet-101 for image classification into 1000 classes: 34 | 35 | # inputs has shape [batch, 224, 224, 3] 36 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 37 | net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False) 38 | 39 | ResNet-101 for semantic segmentation into 21 classes: 40 | 41 | # inputs has shape [batch, 513, 513, 3] 42 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 43 | net, end_points = resnet_v2.resnet_v2_101(inputs, 44 | 21, 45 | is_training=False, 46 | global_pool=False, 47 | output_stride=16) 48 | """ 49 | from __future__ import absolute_import 50 | from __future__ import division 51 | from __future__ import print_function 52 | 53 | import tensorflow as tf 54 | 55 | from nets import resnet_utils 56 | 57 | slim = tf.contrib.slim 58 | resnet_arg_scope = resnet_utils.resnet_arg_scope 59 | 60 | 61 | @slim.add_arg_scope 62 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 63 | outputs_collections=None, scope=None): 64 | """Bottleneck residual unit variant with BN before convolutions. 65 | 66 | This is the full preactivation residual unit variant proposed in [2]. See 67 | Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck 68 | variant which has an extra bottleneck layer. 69 | 70 | When putting together two consecutive ResNet blocks that use this unit, one 71 | should use stride = 2 in the last unit of the first block. 72 | 73 | Args: 74 | inputs: A tensor of size [batch, height, width, channels]. 75 | depth: The depth of the ResNet unit output. 76 | depth_bottleneck: The depth of the bottleneck layers. 77 | stride: The ResNet unit's stride. Determines the amount of downsampling of 78 | the units output compared to its input. 79 | rate: An integer, rate for atrous convolution. 80 | outputs_collections: Collection to add the ResNet unit output. 81 | scope: Optional variable_scope. 82 | 83 | Returns: 84 | The ResNet unit's output. 85 | """ 86 | with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc: 87 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 88 | preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact') 89 | if depth == depth_in: 90 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 91 | else: 92 | shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride, 93 | normalizer_fn=None, activation_fn=None, 94 | scope='shortcut') 95 | 96 | residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1, 97 | scope='conv1') 98 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 99 | rate=rate, scope='conv2') 100 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 101 | normalizer_fn=None, activation_fn=None, 102 | scope='conv3') 103 | 104 | output = shortcut + residual 105 | 106 | return slim.utils.collect_named_outputs(outputs_collections, 107 | sc.name, 108 | output) 109 | 110 | 111 | def resnet_v2(inputs, 112 | blocks, 113 | num_classes=None, 114 | is_training=True, 115 | global_pool=True, 116 | output_stride=None, 117 | include_root_block=True, 118 | spatial_squeeze=True, 119 | reuse=None, 120 | scope=None): 121 | """Generator for v2 (preactivation) ResNet models. 122 | 123 | This function generates a family of ResNet v2 models. 
See the resnet_v2_*() 124 | methods for specific model instantiations, obtained by selecting different 125 | block instantiations that produce ResNets of various depths. 126 | 127 | Training for image classification on Imagenet is usually done with [224, 224] 128 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 129 | block for the ResNets defined in [1] that have nominal stride equal to 32. 130 | However, for dense prediction tasks we advise that one uses inputs with 131 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 132 | this case the feature maps at the ResNet output will have spatial shape 133 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 134 | and corners exactly aligned with the input image corners, which greatly 135 | facilitates alignment of the features to the image. Using as input [225, 225] 136 | images results in [8, 8] feature maps at the output of the last ResNet block. 137 | 138 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 139 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 140 | have nominal stride equal to 32 and a good choice in FCN mode is to use 141 | output_stride=16 in order to increase the density of the computed features at 142 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 143 | 144 | Args: 145 | inputs: A tensor of size [batch, height_in, width_in, channels]. 146 | blocks: A list of length equal to the number of ResNet blocks. Each element 147 | is a resnet_utils.Block object describing the units in the block. 148 | num_classes: Number of predicted classes for classification tasks. 149 | If 0 or None, we return the features before the logit layer. 150 | is_training: whether batch_norm layers are in training mode. 151 | global_pool: If True, we perform global average pooling before computing the 152 | logits. Set to True for image classification, False for dense prediction. 153 | output_stride: If None, then the output will be computed at the nominal 154 | network stride. If output_stride is not None, it specifies the requested 155 | ratio of input to output spatial resolution. 156 | include_root_block: If True, include the initial convolution followed by 157 | max-pooling, if False excludes it. If excluded, `inputs` should be the 158 | results of an activation-less convolution. 159 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 160 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 161 | To use this parameter, the input images must be smaller than 300x300 162 | pixels, in which case the output logit layer does not contain spatial 163 | information and can be removed. 164 | reuse: whether or not the network and its variables should be reused. To be 165 | able to reuse 'scope' must be given. 166 | scope: Optional variable_scope. 167 | 168 | 169 | Returns: 170 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 171 | If global_pool is False, then height_out and width_out are reduced by a 172 | factor of output_stride compared to the respective height_in and width_in, 173 | else both height_out and width_out equal one. If num_classes is 0 or None, 174 | then net is the output of the last ResNet block, potentially after global 175 | average pooling. If num_classes is a non-zero integer, net contains the 176 | pre-softmax activations. 
177 | end_points: A dictionary from components of the network to the corresponding 178 | activation. 179 | 180 | Raises: 181 | ValueError: If the target output_stride is not valid. 182 | """ 183 | with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc: 184 | end_points_collection = sc.original_name_scope + '_end_points' 185 | with slim.arg_scope([slim.conv2d, bottleneck, 186 | resnet_utils.stack_blocks_dense], 187 | outputs_collections=end_points_collection): 188 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 189 | net = inputs 190 | if include_root_block: 191 | if output_stride is not None: 192 | if output_stride % 4 != 0: 193 | raise ValueError('The output_stride needs to be a multiple of 4.') 194 | output_stride /= 4 195 | # We do not include batch normalization or activation functions in 196 | # conv1 because the first ResNet unit will perform these. Cf. 197 | # Appendix of [2]. 198 | with slim.arg_scope([slim.conv2d], 199 | activation_fn=None, normalizer_fn=None): 200 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 201 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 202 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 203 | # This is needed because the pre-activation variant does not have batch 204 | # normalization or activation functions in the residual unit output. See 205 | # Appendix of [2]. 206 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm') 207 | # Convert end_points_collection into a dictionary of end_points. 208 | end_points = slim.utils.convert_collection_to_dict( 209 | end_points_collection) 210 | 211 | if global_pool: 212 | # Global average pooling. 213 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 214 | end_points['global_pool'] = net 215 | if num_classes: 216 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 217 | normalizer_fn=None, scope='logits') 218 | end_points[sc.name + '/logits'] = net 219 | if spatial_squeeze: 220 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 221 | end_points[sc.name + '/spatial_squeeze'] = net 222 | end_points['predictions'] = slim.softmax(net, scope='predictions') 223 | return net, end_points 224 | resnet_v2.default_image_size = 224 225 | 226 | 227 | def resnet_v2_block(scope, base_depth, num_units, stride): 228 | """Helper function for creating a resnet_v2 bottleneck block. 229 | 230 | Args: 231 | scope: The scope of the block. 232 | base_depth: The depth of the bottleneck layer for each unit. 233 | num_units: The number of units in the block. 234 | stride: The stride of the block, implemented as a stride in the last unit. 235 | All other units have stride=1. 236 | 237 | Returns: 238 | A resnet_v2 bottleneck block. 239 | """ 240 | return resnet_utils.Block(scope, bottleneck, [{ 241 | 'depth': base_depth * 4, 242 | 'depth_bottleneck': base_depth, 243 | 'stride': 1 244 | }] * (num_units - 1) + [{ 245 | 'depth': base_depth * 4, 246 | 'depth_bottleneck': base_depth, 247 | 'stride': stride 248 | }]) 249 | resnet_v2.default_image_size = 224 250 | 251 | 252 | def resnet_v2_50(inputs, 253 | num_classes=None, 254 | is_training=True, 255 | global_pool=True, 256 | output_stride=None, 257 | spatial_squeeze=True, 258 | reuse=None, 259 | scope='resnet_v2_50'): 260 | """ResNet-50 model of [1]. 
See resnet_v2() for arg and return description.""" 261 | blocks = [ 262 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 263 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2), 264 | resnet_v2_block('block3', base_depth=256, num_units=6, stride=2), 265 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 266 | ] 267 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 268 | global_pool=global_pool, output_stride=output_stride, 269 | include_root_block=True, spatial_squeeze=spatial_squeeze, 270 | reuse=reuse, scope=scope) 271 | resnet_v2_50.default_image_size = resnet_v2.default_image_size 272 | 273 | 274 | def resnet_v2_101(inputs, 275 | num_classes=None, 276 | is_training=True, 277 | global_pool=True, 278 | output_stride=None, 279 | spatial_squeeze=True, 280 | reuse=None, 281 | scope='resnet_v2_101'): 282 | """ResNet-101 model of [1]. See resnet_v2() for arg and return description.""" 283 | blocks = [ 284 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 285 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2), 286 | resnet_v2_block('block3', base_depth=256, num_units=23, stride=2), 287 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 288 | ] 289 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 290 | global_pool=global_pool, output_stride=output_stride, 291 | include_root_block=True, spatial_squeeze=spatial_squeeze, 292 | reuse=reuse, scope=scope) 293 | resnet_v2_101.default_image_size = resnet_v2.default_image_size 294 | 295 | 296 | def resnet_v2_152(inputs, 297 | num_classes=None, 298 | is_training=True, 299 | global_pool=True, 300 | output_stride=None, 301 | spatial_squeeze=True, 302 | reuse=None, 303 | scope='resnet_v2_152'): 304 | """ResNet-152 model of [1]. See resnet_v2() for arg and return description.""" 305 | blocks = [ 306 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 307 | resnet_v2_block('block2', base_depth=128, num_units=8, stride=2), 308 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2), 309 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 310 | ] 311 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 312 | global_pool=global_pool, output_stride=output_stride, 313 | include_root_block=True, spatial_squeeze=spatial_squeeze, 314 | reuse=reuse, scope=scope) 315 | resnet_v2_152.default_image_size = resnet_v2.default_image_size 316 | 317 | 318 | def resnet_v2_200(inputs, 319 | num_classes=None, 320 | is_training=True, 321 | global_pool=True, 322 | output_stride=None, 323 | spatial_squeeze=True, 324 | reuse=None, 325 | scope='resnet_v2_200'): 326 | """ResNet-200 model of [2]. 
See resnet_v2() for arg and return description."""
327 |   blocks = [
328 |       resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
329 |       resnet_v2_block('block2', base_depth=128, num_units=24, stride=2),
330 |       resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
331 |       resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
332 |   ]
333 |   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
334 |                    global_pool=global_pool, output_stride=output_stride,
335 |                    include_root_block=True, spatial_squeeze=spatial_squeeze,
336 |                    reuse=reuse, scope=scope)
337 | resnet_v2_200.default_image_size = resnet_v2.default_image_size
338 | 
--------------------------------------------------------------------------------
/resources/cpc-explanation1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flrngel/cpc-tensorflow/eca44df18b3f7a1536a2497762ea05efcc95685c/resources/cpc-explanation1.png
--------------------------------------------------------------------------------
/resources/cpc-vision.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flrngel/cpc-tensorflow/eca44df18b3f7a1536a2497762ea05efcc95685c/resources/cpc-vision.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from tensorflow import keras
4 | from model import CPC
5 | from nets.resnet_v2 import resnet_v2_101 as resnet
6 | 
7 | tf.app.flags.DEFINE_string('mode', 'train', 'mode')
8 | tf.app.flags.DEFINE_integer('epochs', 10, 'epochs')
9 | tf.app.flags.DEFINE_integer('batch_size', 16, 'batch size to train in one step')
10 | tf.app.flags.DEFINE_float('learn_rate', 2e-4, 'learning rate for training optimization')
11 | tf.app.flags.DEFINE_integer('K', 2, 'hyperparameter K')
12 | 
13 | FLAGS = tf.app.flags.FLAGS
14 | 
15 | mode = FLAGS.mode
16 | epochs = FLAGS.epochs
17 | learn_rate = FLAGS.learn_rate
18 | batch_size = FLAGS.batch_size
19 | K = FLAGS.K
20 | n = 7  # each image is cut into an n x n (7x7) grid of patches
21 | 
22 | def image_preprocess(x):
23 |   x = tf.expand_dims(x, axis=-1)
24 |   x = tf.concat([x, x, x], axis=-1)
25 |   x = tf.image.resize_images(x, (256, 256))
26 |   arr = []
27 |   for i in range(7):
28 |     for j in range(7):
29 |       arr.append(x[:, i*32:i*32+64, j*32:j*32+64, :])  # 64x64 crops with stride 32 (50% overlap)
30 |   x = tf.concat(arr, axis=0)
31 |   return x
32 | 
33 | def chunks(l, n):
34 |   for i in range(0, len(l), n):
35 |     yield l[i:i+n]
36 | 
37 | # load data
38 | fashion_mnist = keras.datasets.fashion_mnist
39 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
40 | if mode == 'train':
41 |   batches = tf.data.Dataset.from_tensor_slices(train_images).repeat(epochs).shuffle(100,
42 |       reshuffle_each_iteration=True).batch(batch_size)
43 | elif mode == 'validation':
44 |   batches = tf.data.Dataset.from_tensor_slices(train_images).repeat(epochs).batch(batch_size)
45 | elif mode == 'infer':
46 |   batches = tf.data.Dataset.from_tensor_slices(test_images).repeat(epochs).shuffle(100,
47 |       reshuffle_each_iteration=True).batch(batch_size)
48 | 
49 | iterator = batches.make_initializable_iterator()
50 | items = iterator.get_next()
51 | data = image_preprocess(items)
52 | 
53 | # build graph
54 | ## resnet encoding
55 | _, features = resnet(data, is_training=False)
56 | features = features['resnet_v2_101/block3']
57 | 
58 | # mean pooling
59 | features = tf.reduce_mean(features, axis=[1, 2])
60 | features = tf.reshape(features, shape=[batch_size, 7, 7, 1024])
61 | 
62 | X = tf.reshape(features, shape=[batch_size, 7, 7, 1024])
63 | X = tf.transpose(X, perm=[0, 2, 1, 3])
64 | tmp = []
65 | for i in range(batch_size):
66 |   for j in range(n):
67 |     tmp.append(X[i][j])
68 | X = tf.stack(tmp, 0)
69 | batch_size *= n  # each image now contributes n patch sequences of length n
70 | 
71 | # random row indices used to draw negative samples
72 | nl = []
73 | nrr = []
74 | nrri = []
75 | for i in range(K):
76 |   nlist = np.arange(0, n)
77 |   nlist = nlist[nlist != (n-K+i)]
78 |   nl.append(tf.constant(nlist))
79 |   nrr.append([tf.random_shuffle(nl[i]) for j in range(batch_size)])
80 | nrri = [tf.stack([nrr[j][i][0] for j in range(K)], axis=0) for i in range(batch_size)]
81 | 
82 | Y = []
83 | Y_label = np.zeros((batch_size), dtype=np.float32)
84 | n_p = batch_size // 2  # roughly half of the sequences are treated as positives
85 | 
86 | for i in range(batch_size):
87 |   if i <= n_p:
88 |     Y.append(tf.expand_dims(features[int(i/n), -K:, i%n, :], axis=0))
89 |     Y_label[i] = 1
90 |   else:
91 |     Y.append(tf.expand_dims(tf.gather(features[int(i/n)], nrri[i])[:, i%n, :], axis=0))
92 | 
93 | Y = tf.concat(Y, axis=0)
94 | Y_label = tf.constant(Y_label, dtype=np.float32)
95 | 
96 | nr = tf.random_shuffle(tf.constant(list(range(batch_size)), dtype=tf.int32))
97 | 
98 | ## cpc
99 | X_len = [5] * batch_size  # context length n - K (5 for the default K=2)
100 | X_len = tf.constant(X_len, dtype=tf.int32)
101 | 
102 | cpc = CPC(X, X_len, Y, Y_label, k=K)
103 | train_op = tf.train.AdamOptimizer(learn_rate).minimize(cpc.loss)
104 | 
105 | saver = tf.train.Saver()
106 | 
107 | # tensorflow
108 | with tf.Session() as sess:
109 |   if mode == 'train':
110 |     sess.run(tf.global_variables_initializer())
111 |     sess.run(iterator.initializer)
112 | 
113 |     step = 0
114 |     total = int((len(train_images) * epochs * n) / batch_size)
115 | 
116 |     while True:
117 |       try:
118 |         #print(sess.run([Y_label]))
119 |         #sess.run([nr, nrr])
120 |         _, loss, _ = sess.run([train_op, cpc.loss, items])
121 |         if step % 100 == 0:
122 |           print('loss: ', loss, 'step:', step, '/', total)
123 |         step += 1
124 |       except tf.errors.OutOfRangeError:
125 |         break
126 | 
127 |     saver.save(sess, './model.ckpt')
128 | 
129 |   elif mode == 'validation':
130 |     with tf.variable_scope('validation'):
131 |       batch_size = int(batch_size / n)
132 |       features = tf.reshape(features, shape=[batch_size, 7 * 7 * 1024])
133 |       out = tf.layers.dense(features, 10)
134 |       labels = tf.placeholder(tf.int32, shape=[batch_size])
135 |       labels_onehot = tf.one_hot(labels, depth=10)
136 |       loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=out, labels=labels_onehot))
137 |       train_op = tf.train.AdamOptimizer(learn_rate).minimize(loss, var_list=[tf.trainable_variables(scope='validation')])
138 | 
139 |     i = 0
140 |     sess.run(tf.global_variables_initializer())
141 |     sess.run(iterator.initializer)
142 |     saver.restore(sess, './model.ckpt')
143 |     s = 0
144 | 
145 |     debug = tf.reduce_mean(features)
146 |     while True:
147 |       try:
148 |         _, _loss, _out, _ = sess.run([train_op, loss, out, items], feed_dict={labels: train_labels[i*batch_size:(i+1)*batch_size]})
149 |         s += np.sum(np.argmax(_out, axis=1) == train_labels[i*batch_size:(i+1)*batch_size])
150 |         if i % 100 == 0:
151 |           print(_loss, s/(batch_size*100))
152 |           s = 0
153 |         if i % 1000 == 0:
154 |           print(np.argmax(_out, axis=1), train_labels[i*batch_size:(i+1)*batch_size])
155 |         i += 1
156 |         if len(train_images)//batch_size <= i:
157 |           i = 0
158 | 
159 |       except tf.errors.OutOfRangeError:
160 |         break
161 | 
162 |     #saver.save(sess, './model_infer.ckpt')
163 | 
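# --- Illustrative note (not part of the original repository) -----------------
# image_preprocess() above builds a 7x7 grid of 64x64 patches per image, using
# a stride of 32 pixels so neighbouring patches overlap by 50%. The hypothetical
# helper below mirrors that cropping in plain NumPy, which can be handy for
# checking shapes offline; it is never called by the training graph.
import numpy as np  # already imported at the top of the script; repeated so the sketch stands alone

def _patch_grid_sketch(images):
  """images: float array of shape [batch, 256, 256, 3] (illustrative name)."""
  patches = []
  for i in range(7):
    for j in range(7):
      patches.append(images[:, i*32:i*32+64, j*32:j*32+64, :])
  # Patches are stacked patch-major along axis 0, giving shape [49*batch, 64, 64, 3].
  return np.concatenate(patches, axis=0)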
--------------------------------------------------------------------------------
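For reference, the prediction step implemented in `model.py`'s `coding` scope reduces to a bilinear score between the GRU context and a candidate patch code, squashed through a sigmoid and trained with binary cross-entropy against the positive/negative label. Below is a minimal standalone NumPy sketch of one such step; the function name, toy shapes, and random inputs are illustrative assumptions rather than part of the repository.

import numpy as np

def cpc_step_loss(c_t, W_k, y, labels, eps=1e-7):
  """Sketch of one prediction step from model.py (illustrative, not the repo's API).

  c_t:    [batch, cell_dim]    context vectors produced by the GRU (g_ar)
  W_k:    [cell_dim, code_dim] bilinear weights for prediction step k
  y:      [batch, code_dim]    candidate future patch code (real or negative)
  labels: [batch]              1.0 for positives, 0.0 for negatives
  """
  pred = c_t @ W_k                      # predicted code, shape [batch, code_dim]
  score = (pred * y).mean(axis=-1)      # mean of the elementwise product, as in model.py
  prob = 1.0 / (1.0 + np.exp(-score))   # sigmoid
  # binary cross-entropy against the positive/negative labels
  return float(-(labels * np.log(prob + eps) +
                 (1.0 - labels) * np.log(1.0 - prob + eps)).mean())

# Toy example using the repository's default sizes (cell_dimension=128, code_dim=1024).
rng = np.random.RandomState(0)
loss = cpc_step_loss(rng.randn(4, 128), rng.randn(128, 1024),
                     rng.randn(4, 1024), np.array([1.0, 1.0, 0.0, 0.0]))
print(loss)

`model.py` builds one such loss per prediction step, stacks them (`losses = tf.stack(losses, axis=0)`), and takes the overall mean as the training objective.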