├── CIFAR10
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── README.md
│   ├── data.py
│   ├── data_util.py
│   ├── imagenet_subsets
│   │   ├── 10percent.txt
│   │   └── 1percent.txt
│   ├── lars_optimizer.py
│   ├── model.py
│   ├── model_util.py
│   ├── objective.py
│   ├── requirements.txt
│   ├── resnet.py
│   └── run.py
├── LICENSE
├── MSCOCO
│   ├── README.md
│   ├── ResNet_baseline.ipynb
│   └── main.ipynb
├── Omniglot
│   ├── README.md
│   ├── compute_MI_CondEntro.py
│   ├── linear.py
│   ├── main.py
│   ├── model.py
│   └── utils.py
└── README.md

--------------------------------------------------------------------------------
/CIFAR10/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | SimCLR needs to maintain permanent compatibility with the pre-trained model
4 | files, so we do not plan to make any major changes to this library (other than
5 | what was promised in the README). However, we can accept small patches related
6 | to re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 | 
9 | ## Contributor License Agreement
10 | 
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 | 
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 | 
21 | ## Code reviews
22 | 
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 | 
28 | ## Community Guidelines
29 | 
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 | 
--------------------------------------------------------------------------------
/CIFAR10/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 | 
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 | 1. Definitions.
9 | 
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 | 
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 | 
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
--------------------------------------------------------------------------------
/CIFAR10/README.md:
--------------------------------------------------------------------------------
1 | # CIFAR10 Experiments
2 | 
3 | The code is adapted from [here](https://github.com/google-research/simclr).
4 | 
5 | 
6 | ## Usage
7 | 
8 | ### Evaluating Self-supervised Representations
9 | 
10 | Contrastive Learning Objective only
11 | ```
12 | python run.py --train_mode=pretrain --train_batch_size=512 --train_epochs=1000 \
13 | --learning_rate=1.0 --weight_decay=1e-6 --dataset=cifar10 --image_size=32 \
14 | --eval_split=test --resnet_depth=18 --use_blur=False --color_jitter_strength=0.5 \
15 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_1 --inv_pred_coeff=0.0 \
16 | --use_tpu=False --temperature=0.5 --hidden_norm=True
17 | 
18 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=4 \
19 | --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)LARSOptimizer|head)' \
20 | --global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
21 | --train_epochs=100 --train_batch_size=512 --warmup_epochs=0 --dataset=cifar10 --image_size=32 \
22 | --eval_split=test --resnet_depth=18 --checkpoint=/root/data/githubs/simclr_models/c10_cpc_1 \
23 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_1/ft --inv_pred_coeff=0.0 --use_tpu=False \
24 | --temperature=0.5 --hidden_norm=True
25 | ```
26 | 
27 | Contrastive Learning Objective + Inverse Predictive Learning Objective
28 | ```
29 | python run.py --train_mode=pretrain --train_batch_size=512 --train_epochs=1000 \
30 | --learning_rate=1.0 --weight_decay=1e-6 --dataset=cifar10 --image_size=32 \
31 | --eval_split=test --resnet_depth=18 --use_blur=False --color_jitter_strength=0.5 \
32 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_inv_1 --inv_pred_coeff=0.03 \
33 | --use_tpu=False --temperature=0.5 --hidden_norm=True
34 | 
35 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=4 \
36 | --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)LARSOptimizer|head)' \
37 | --global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
38 | --train_epochs=100 --train_batch_size=512 --warmup_epochs=0 --dataset=cifar10 --image_size=32 \
39 | --eval_split=test --resnet_depth=18 --checkpoint=/root/data/githubs/simclr_models/c10_cpc_inv_1 \
40 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_inv_1/ft --inv_pred_coeff=0.03 --use_tpu=False \
41 | --temperature=0.5 --hidden_norm=True
42 | ```
43 | 
--------------------------------------------------------------------------------
/CIFAR10/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific simclr governing permissions and
14 | # limitations under the License.
15 | # ============================================================================== 16 | """Data pipeline.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import functools 23 | from absl import flags 24 | 25 | import data_util as data_util 26 | import tensorflow.compat.v1 as tf 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | def pad_to_batch(dataset, batch_size): 32 | """Pad Tensors to specified batch size. 33 | 34 | Args: 35 | dataset: An instance of tf.data.Dataset. 36 | batch_size: The number of samples per batch of input requested. 37 | 38 | Returns: 39 | An instance of tf.data.Dataset that yields the same Tensors with the same 40 | structure as the original padded to batch_size along the leading 41 | dimension. 42 | 43 | Raises: 44 | ValueError: If the dataset does not comprise any tensors; if a tensor 45 | yielded by the dataset has an unknown number of dimensions or is a 46 | scalar; or if it can be statically determined that tensors comprising 47 | a single dataset element will have different leading dimensions. 48 | """ 49 | def _pad_to_batch(*args): 50 | """Given Tensors yielded by a Dataset, pads all to the batch size.""" 51 | flat_args = tf.nest.flatten(args) 52 | 53 | for tensor in flat_args: 54 | if tensor.shape.ndims is None: 55 | raise ValueError( 56 | 'Unknown number of dimensions for tensor %s.' % tensor.name) 57 | if tensor.shape.ndims == 0: 58 | raise ValueError('Tensor %s is a scalar.' % tensor.name) 59 | 60 | # This will throw if flat_args is empty. However, as of this writing, 61 | # tf.data.Dataset.map will throw first with an internal error, so we do 62 | # not check this case explicitly. 63 | first_tensor = flat_args[0] 64 | first_tensor_shape = tf.shape(first_tensor) 65 | first_tensor_batch_size = first_tensor_shape[0] 66 | difference = batch_size - first_tensor_batch_size 67 | 68 | for i, tensor in enumerate(flat_args): 69 | control_deps = [] 70 | if i != 0: 71 | # Check that leading dimensions of this tensor matches the first, 72 | # either statically or dynamically. (If the first dimensions of both 73 | # tensors are statically known, the we have to check the static 74 | # shapes at graph construction time or else we will never get to the 75 | # dynamic assertion.) 76 | if (first_tensor.shape[:1].is_fully_defined() and 77 | tensor.shape[:1].is_fully_defined()): 78 | if first_tensor.shape[0] != tensor.shape[0]: 79 | raise ValueError( 80 | 'Batch size of dataset tensors does not match. %s ' 81 | 'has shape %s, but %s has shape %s' % ( 82 | first_tensor.name, first_tensor.shape, 83 | tensor.name, tensor.shape)) 84 | else: 85 | curr_shape = tf.shape(tensor) 86 | control_deps = [tf.Assert( 87 | tf.equal(curr_shape[0], first_tensor_batch_size), 88 | ['Batch size of dataset tensors %s and %s do not match. ' 89 | 'Shapes are' % (tensor.name, first_tensor.name), curr_shape, 90 | first_tensor_shape])] 91 | 92 | with tf.control_dependencies(control_deps): 93 | # Pad to batch_size along leading dimension. 94 | flat_args[i] = tf.pad( 95 | tensor, [[0, difference]] + [[0, 0]] * (tensor.shape.ndims - 1)) 96 | flat_args[i].set_shape([batch_size] + tensor.shape.as_list()[1:]) 97 | 98 | return tf.nest.pack_sequence_as(args, flat_args) 99 | 100 | return dataset.map(_pad_to_batch) 101 | 102 | 103 | def build_input_fn(builder, is_training): 104 | """Build input function. 105 | 106 | Args: 107 | builder: TFDS builder for specified dataset. 
108 | is_training: Whether to build in training mode. 109 | 110 | Returns: 111 | A function that accepts a dict of params and returns a tuple of images and 112 | features, to be used as the input_fn in TPUEstimator. 113 | """ 114 | def _input_fn(params): 115 | """Inner input function.""" 116 | preprocess_fn_pretrain = get_preprocess_fn(is_training, is_pretrain=True) 117 | preprocess_fn_finetune = get_preprocess_fn(is_training, is_pretrain=False) 118 | num_classes = builder.info.features['label'].num_classes 119 | 120 | def map_fn(image, label): 121 | """Produces multiple transformations of the same batch.""" 122 | if FLAGS.train_mode == 'pretrain': 123 | xs = [] 124 | for _ in range(2): # Two transformations 125 | xs.append(preprocess_fn_pretrain(image)) 126 | image = tf.concat(xs, -1) 127 | label = tf.zeros([num_classes]) 128 | else: 129 | image = preprocess_fn_finetune(image) 130 | label = tf.one_hot(label, num_classes) 131 | return image, label, 1.0 132 | 133 | dataset = builder.as_dataset( 134 | split=FLAGS.train_split if is_training else FLAGS.eval_split, 135 | shuffle_files=is_training, as_supervised=True) 136 | if FLAGS.cache_dataset: 137 | dataset = dataset.cache() 138 | if is_training: 139 | buffer_multiplier = 50 if FLAGS.image_size <= 32 else 10 140 | dataset = dataset.shuffle(params['batch_size'] * buffer_multiplier) 141 | dataset = dataset.repeat(-1) 142 | dataset = dataset.map(map_fn, 143 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 144 | dataset = dataset.batch(params['batch_size'], drop_remainder=is_training) 145 | dataset = pad_to_batch(dataset, params['batch_size']) 146 | images, labels, mask = tf.data.make_one_shot_iterator(dataset).get_next() 147 | 148 | return images, {'labels': labels, 'mask': mask} 149 | return _input_fn 150 | 151 | 152 | def get_preprocess_fn(is_training, is_pretrain): 153 | """Get function that accepts an image and returns a preprocessed image.""" 154 | # Disable test cropping for small images (e.g. CIFAR) 155 | if FLAGS.image_size <= 32: 156 | test_crop = False 157 | else: 158 | test_crop = True 159 | return functools.partial( 160 | data_util.preprocess_image, 161 | height=FLAGS.image_size, 162 | width=FLAGS.image_size, 163 | is_training=is_training, 164 | color_distort=is_pretrain, 165 | test_crop=test_crop) 166 | -------------------------------------------------------------------------------- /CIFAR10/data_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Data preprocessing and augmentation.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import functools 23 | from absl import flags 24 | 25 | import tensorflow.compat.v1 as tf 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | CROP_PROPORTION = 0.875 # Standard for ImageNet. 30 | 31 | 32 | def random_apply(func, p, x): 33 | """Randomly apply function func to x with probability p.""" 34 | return tf.cond( 35 | tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), 36 | tf.cast(p, tf.float32)), 37 | lambda: func(x), 38 | lambda: x) 39 | 40 | 41 | def to_grayscale(image, keep_channels=True): 42 | image = tf.image.rgb_to_grayscale(image) 43 | if keep_channels: 44 | image = tf.tile(image, [1, 1, 3]) 45 | return image 46 | 47 | 48 | def color_jitter(image, 49 | strength, 50 | random_order=True): 51 | """Distorts the color of the image. 52 | 53 | Args: 54 | image: The input image tensor. 55 | strength: the floating number for the strength of the color augmentation. 56 | random_order: A bool, specifying whether to randomize the jittering order. 57 | 58 | Returns: 59 | The distorted image tensor. 60 | """ 61 | brightness = 0.8 * strength 62 | contrast = 0.8 * strength 63 | saturation = 0.8 * strength 64 | hue = 0.2 * strength 65 | if random_order: 66 | return color_jitter_rand(image, brightness, contrast, saturation, hue) 67 | else: 68 | return color_jitter_nonrand(image, brightness, contrast, saturation, hue) 69 | 70 | 71 | def color_jitter_nonrand(image, brightness=0, contrast=0, saturation=0, hue=0): 72 | """Distorts the color of the image (jittering order is fixed). 73 | 74 | Args: 75 | image: The input image tensor. 76 | brightness: A float, specifying the brightness for color jitter. 77 | contrast: A float, specifying the contrast for color jitter. 78 | saturation: A float, specifying the saturation for color jitter. 79 | hue: A float, specifying the hue for color jitter. 80 | 81 | Returns: 82 | The distorted image tensor. 83 | """ 84 | with tf.name_scope('distort_color'): 85 | def apply_transform(i, x, brightness, contrast, saturation, hue): 86 | """Apply the i-th transformation.""" 87 | if brightness != 0 and i == 0: 88 | x = tf.image.random_brightness(x, max_delta=brightness) 89 | elif contrast != 0 and i == 1: 90 | x = tf.image.random_contrast( 91 | x, lower=1-contrast, upper=1+contrast) 92 | elif saturation != 0 and i == 2: 93 | x = tf.image.random_saturation( 94 | x, lower=1-saturation, upper=1+saturation) 95 | elif hue != 0: 96 | x = tf.image.random_hue(x, max_delta=hue) 97 | return x 98 | 99 | for i in range(4): 100 | image = apply_transform(i, image, brightness, contrast, saturation, hue) 101 | image = tf.clip_by_value(image, 0., 1.) 102 | return image 103 | 104 | 105 | def color_jitter_rand(image, brightness=0, contrast=0, saturation=0, hue=0): 106 | """Distorts the color of the image (jittering order is random). 107 | 108 | Args: 109 | image: The input image tensor. 110 | brightness: A float, specifying the brightness for color jitter. 111 | contrast: A float, specifying the contrast for color jitter. 112 | saturation: A float, specifying the saturation for color jitter. 113 | hue: A float, specifying the hue for color jitter. 114 | 115 | Returns: 116 | The distorted image tensor. 
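
  Note: the four transforms are applied in a random order. A permutation of
  range(4) is drawn with tf.random_shuffle and, for each position, nested
  tf.cond calls dispatch to the matching brightness/contrast/saturation/hue
  op, so the jitter order differs from call to call.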
117 | """ 118 | with tf.name_scope('distort_color'): 119 | def apply_transform(i, x): 120 | """Apply the i-th transformation.""" 121 | def brightness_foo(): 122 | if brightness == 0: 123 | return x 124 | else: 125 | return tf.image.random_brightness(x, max_delta=brightness) 126 | def contrast_foo(): 127 | if contrast == 0: 128 | return x 129 | else: 130 | return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast) 131 | def saturation_foo(): 132 | if saturation == 0: 133 | return x 134 | else: 135 | return tf.image.random_saturation( 136 | x, lower=1-saturation, upper=1+saturation) 137 | def hue_foo(): 138 | if hue == 0: 139 | return x 140 | else: 141 | return tf.image.random_hue(x, max_delta=hue) 142 | x = tf.cond(tf.less(i, 2), 143 | lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo), 144 | lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo)) 145 | return x 146 | 147 | perm = tf.random_shuffle(tf.range(4)) 148 | for i in range(4): 149 | image = apply_transform(perm[i], image) 150 | image = tf.clip_by_value(image, 0., 1.) 151 | return image 152 | 153 | 154 | def _compute_crop_shape( 155 | image_height, image_width, aspect_ratio, crop_proportion): 156 | """Compute aspect ratio-preserving shape for central crop. 157 | 158 | The resulting shape retains `crop_proportion` along one side and a proportion 159 | less than or equal to `crop_proportion` along the other side. 160 | 161 | Args: 162 | image_height: Height of image to be cropped. 163 | image_width: Width of image to be cropped. 164 | aspect_ratio: Desired aspect ratio (width / height) of output. 165 | crop_proportion: Proportion of image to retain along the less-cropped side. 166 | 167 | Returns: 168 | crop_height: Height of image after cropping. 169 | crop_width: Width of image after cropping. 170 | """ 171 | image_width_float = tf.cast(image_width, tf.float32) 172 | image_height_float = tf.cast(image_height, tf.float32) 173 | 174 | def _requested_aspect_ratio_wider_than_image(): 175 | crop_height = tf.cast(tf.rint( 176 | crop_proportion / aspect_ratio * image_width_float), tf.int32) 177 | crop_width = tf.cast(tf.rint( 178 | crop_proportion * image_width_float), tf.int32) 179 | return crop_height, crop_width 180 | 181 | def _image_wider_than_requested_aspect_ratio(): 182 | crop_height = tf.cast( 183 | tf.rint(crop_proportion * image_height_float), tf.int32) 184 | crop_width = tf.cast(tf.rint( 185 | crop_proportion * aspect_ratio * 186 | image_height_float), tf.int32) 187 | return crop_height, crop_width 188 | 189 | return tf.cond( 190 | aspect_ratio > image_width_float / image_height_float, 191 | _requested_aspect_ratio_wider_than_image, 192 | _image_wider_than_requested_aspect_ratio) 193 | 194 | 195 | def center_crop(image, height, width, crop_proportion): 196 | """Crops to center of image and rescales to desired size. 197 | 198 | Args: 199 | image: Image Tensor to crop. 200 | height: Height of image to be cropped. 201 | width: Width of image to be cropped. 202 | crop_proportion: Proportion of image to retain along the less-cropped side. 203 | 204 | Returns: 205 | A `height` x `width` x channels Tensor holding a central crop of `image`. 
206 | """ 207 | shape = tf.shape(image) 208 | image_height = shape[0] 209 | image_width = shape[1] 210 | crop_height, crop_width = _compute_crop_shape( 211 | image_height, image_width, height / width, crop_proportion) 212 | offset_height = ((image_height - crop_height) + 1) // 2 213 | offset_width = ((image_width - crop_width) + 1) // 2 214 | image = tf.image.crop_to_bounding_box( 215 | image, offset_height, offset_width, crop_height, crop_width) 216 | 217 | image = tf.image.resize_bicubic([image], [height, width])[0] 218 | 219 | return image 220 | 221 | 222 | def distorted_bounding_box_crop(image, 223 | bbox, 224 | min_object_covered=0.1, 225 | aspect_ratio_range=(0.75, 1.33), 226 | area_range=(0.05, 1.0), 227 | max_attempts=100, 228 | scope=None): 229 | """Generates cropped_image using one of the bboxes randomly distorted. 230 | 231 | See `tf.image.sample_distorted_bounding_box` for more documentation. 232 | 233 | Args: 234 | image: `Tensor` of image data. 235 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` 236 | where each coordinate is [0, 1) and the coordinates are arranged 237 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole 238 | image. 239 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 240 | area of the image must contain at least this fraction of any bounding 241 | box supplied. 242 | aspect_ratio_range: An optional list of `float`s. The cropped area of the 243 | image must have an aspect ratio = width / height within this range. 244 | area_range: An optional list of `float`s. The cropped area of the image 245 | must contain a fraction of the supplied image within in this range. 246 | max_attempts: An optional `int`. Number of attempts at generating a cropped 247 | region of the image of the specified constraints. After `max_attempts` 248 | failures, return the entire image. 249 | scope: Optional `str` for name scope. 250 | Returns: 251 | (cropped image `Tensor`, distorted bbox `Tensor`). 252 | """ 253 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): 254 | shape = tf.shape(image) 255 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 256 | shape, 257 | bounding_boxes=bbox, 258 | min_object_covered=min_object_covered, 259 | aspect_ratio_range=aspect_ratio_range, 260 | area_range=area_range, 261 | max_attempts=max_attempts, 262 | use_image_if_no_bounding_boxes=True) 263 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 264 | 265 | # Crop the image to the specified bounding box. 266 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 267 | target_height, target_width, _ = tf.unstack(bbox_size) 268 | image = tf.image.crop_to_bounding_box( 269 | image, offset_y, offset_x, target_height, target_width) 270 | 271 | return image 272 | 273 | 274 | def crop_and_resize(image, height, width): 275 | """Make a random crop and resize it to height `height` and width `width`. 276 | 277 | Args: 278 | image: Tensor representing the image. 279 | height: Desired image height. 280 | width: Desired image width. 281 | 282 | Returns: 283 | A `height` x `width` x channels Tensor holding a random crop of `image`. 284 | """ 285 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 286 | aspect_ratio = width / height 287 | image = distorted_bounding_box_crop( 288 | image, 289 | bbox, 290 | min_object_covered=0.1, 291 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. 
* aspect_ratio), 292 | area_range=(0.08, 1.0), 293 | max_attempts=100, 294 | scope=None) 295 | return tf.image.resize_bicubic([image], [height, width])[0] 296 | 297 | 298 | def gaussian_blur(image, kernel_size, sigma, padding='SAME'): 299 | """Blurs the given image with separable convolution. 300 | 301 | 302 | Args: 303 | image: Tensor of shape [height, width, channels] and dtype float to blur. 304 | kernel_size: Integer Tensor for the size of the blur kernel. This is should 305 | be an odd number. If it is an even number, the actual kernel size will be 306 | size + 1. 307 | sigma: Sigma value for gaussian operator. 308 | padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'. 309 | 310 | Returns: 311 | A Tensor representing the blurred image. 312 | """ 313 | radius = tf.to_int32(kernel_size / 2) 314 | kernel_size = radius * 2 + 1 315 | x = tf.to_float(tf.range(-radius, radius + 1)) 316 | blur_filter = tf.exp( 317 | -tf.pow(x, 2.0) / (2.0 * tf.pow(tf.to_float(sigma), 2.0))) 318 | blur_filter /= tf.reduce_sum(blur_filter) 319 | # One vertical and one horizontal filter. 320 | blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1]) 321 | blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1]) 322 | num_channels = tf.shape(image)[-1] 323 | blur_h = tf.tile(blur_h, [1, 1, num_channels, 1]) 324 | blur_v = tf.tile(blur_v, [1, 1, num_channels, 1]) 325 | expand_batch_dim = image.shape.ndims == 3 326 | if expand_batch_dim: 327 | # Tensorflow requires batched input to convolutions, which we can fake with 328 | # an extra dimension. 329 | image = tf.expand_dims(image, axis=0) 330 | blurred = tf.nn.depthwise_conv2d( 331 | image, blur_h, strides=[1, 1, 1, 1], padding=padding) 332 | blurred = tf.nn.depthwise_conv2d( 333 | blurred, blur_v, strides=[1, 1, 1, 1], padding=padding) 334 | if expand_batch_dim: 335 | blurred = tf.squeeze(blurred, axis=0) 336 | return blurred 337 | 338 | 339 | def random_crop_with_resize(image, height, width, p=1.0): 340 | """Randomly crop and resize an image. 341 | 342 | Args: 343 | image: `Tensor` representing an image of arbitrary size. 344 | height: Height of output image. 345 | width: Width of output image. 346 | p: Probability of applying this transformation. 347 | 348 | Returns: 349 | A preprocessed image `Tensor`. 350 | """ 351 | def _transform(image): # pylint: disable=missing-docstring 352 | image = crop_and_resize(image, height, width) 353 | return image 354 | return random_apply(_transform, p=p, x=image) 355 | 356 | 357 | def random_color_jitter(image, p=1.0): 358 | def _transform(image): 359 | color_jitter_t = functools.partial( 360 | color_jitter, strength=FLAGS.color_jitter_strength) 361 | image = random_apply(color_jitter_t, p=0.8, x=image) 362 | return random_apply(to_grayscale, p=0.2, x=image) 363 | return random_apply(_transform, p=p, x=image) 364 | 365 | 366 | def random_blur(image, height, width, p=1.0): 367 | """Randomly blur an image. 368 | 369 | Args: 370 | image: `Tensor` representing an image of arbitrary size. 371 | height: Height of output image. 372 | width: Width of output image. 373 | p: probability of applying this transformation. 374 | 375 | Returns: 376 | A preprocessed image `Tensor`. 
377 | """ 378 | del width 379 | def _transform(image): 380 | sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32) 381 | return gaussian_blur( 382 | image, kernel_size=height//10, sigma=sigma, padding='SAME') 383 | return random_apply(_transform, p=p, x=image) 384 | 385 | 386 | def batch_random_blur(images_list, height, width, blur_probability=0.5): 387 | """Apply efficient batch data transformations. 388 | 389 | Args: 390 | images_list: a list of image tensors. 391 | height: the height of image. 392 | width: the width of image. 393 | blur_probability: the probaility to apply the blur operator. 394 | 395 | Returns: 396 | Preprocessed feature list. 397 | """ 398 | def generate_selector(p, bsz): 399 | shape = [bsz, 1, 1, 1] 400 | selector = tf.cast( 401 | tf.less(tf.random_uniform(shape, 0, 1, dtype=tf.float32), p), 402 | tf.float32) 403 | return selector 404 | 405 | new_images_list = [] 406 | for images in images_list: 407 | images_new = random_blur(images, height, width, p=1.) 408 | selector = generate_selector(blur_probability, tf.shape(images)[0]) 409 | images = images_new * selector + images * (1 - selector) 410 | images = tf.clip_by_value(images, 0., 1.) 411 | new_images_list.append(images) 412 | 413 | return new_images_list 414 | 415 | 416 | def preprocess_for_train(image, height, width, 417 | color_distort=True, crop=True, flip=True): 418 | """Preprocesses the given image for training. 419 | 420 | Args: 421 | image: `Tensor` representing an image of arbitrary size. 422 | height: Height of output image. 423 | width: Width of output image. 424 | color_distort: Whether to apply the color distortion. 425 | crop: Whether to crop the image. 426 | flip: Whether or not to flip left and right of an image. 427 | 428 | Returns: 429 | A preprocessed image `Tensor`. 430 | """ 431 | if crop: 432 | image = random_crop_with_resize(image, height, width) 433 | if flip: 434 | image = tf.image.random_flip_left_right(image) 435 | if color_distort: 436 | image = random_color_jitter(image) 437 | image = tf.reshape(image, [height, width, 3]) 438 | image = tf.clip_by_value(image, 0., 1.) 439 | return image 440 | 441 | 442 | def preprocess_for_eval(image, height, width, crop=True): 443 | """Preprocesses the given image for evaluation. 444 | 445 | Args: 446 | image: `Tensor` representing an image of arbitrary size. 447 | height: Height of output image. 448 | width: Width of output image. 449 | crop: Whether or not to (center) crop the test images. 450 | 451 | Returns: 452 | A preprocessed image `Tensor`. 453 | """ 454 | if crop: 455 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION) 456 | image = tf.reshape(image, [height, width, 3]) 457 | image = tf.clip_by_value(image, 0., 1.) 458 | return image 459 | 460 | 461 | def preprocess_image(image, height, width, is_training=False, 462 | color_distort=True, test_crop=True): 463 | """Preprocesses the given image. 464 | 465 | Args: 466 | image: `Tensor` representing an image of arbitrary size. 467 | height: Height of output image. 468 | width: Width of output image. 469 | is_training: `bool` for whether the preprocessing is for training. 470 | color_distort: whether to apply the color distortion. 471 | test_crop: whether or not to extract a central crop of the images 472 | (as for standard ImageNet evaluation) during the evaluation. 473 | 474 | Returns: 475 | A preprocessed image `Tensor` of range [0, 1]. 
476 | """ 477 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 478 | if is_training: 479 | return preprocess_for_train(image, height, width, color_distort) 480 | else: 481 | return preprocess_for_eval(image, height, width, test_crop) 482 | -------------------------------------------------------------------------------- /CIFAR10/lars_optimizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Functions and classes related to optimization (weight updates).""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import re 23 | 24 | import tensorflow.compat.v1 as tf 25 | 26 | EETA_DEFAULT = 0.001 27 | 28 | 29 | class LARSOptimizer(tf.train.Optimizer): 30 | """Layer-wise Adaptive Rate Scaling for large batch training. 31 | 32 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You, 33 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888) 34 | """ 35 | 36 | def __init__(self, 37 | learning_rate, 38 | momentum=0.9, 39 | use_nesterov=False, 40 | weight_decay=0.0, 41 | exclude_from_weight_decay=None, 42 | exclude_from_layer_adaptation=None, 43 | classic_momentum=True, 44 | eeta=EETA_DEFAULT, 45 | name="LARSOptimizer"): 46 | """Constructs a LARSOptimizer. 47 | 48 | Args: 49 | learning_rate: A `float` for learning rate. 50 | momentum: A `float` for momentum. 51 | use_nesterov: A 'Boolean' for whether to use nesterov momentum. 52 | weight_decay: A `float` for weight decay. 53 | exclude_from_weight_decay: A list of `string` for variable screening, if 54 | any of the string appears in a variable's name, the variable will be 55 | excluded for computing weight decay. For example, one could specify 56 | the list like ['batch_normalization', 'bias'] to exclude BN and bias 57 | from weight decay. 58 | exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but 59 | for layer adaptation. If it is None, it will be defaulted the same as 60 | exclude_from_weight_decay. 61 | classic_momentum: A `boolean` for whether to use classic (or popular) 62 | momentum. The learning rate is applied during momeuntum update in 63 | classic momentum, but after momentum for popular momentum. 64 | eeta: A `float` for scaling of learning rate when computing trust ratio. 65 | name: The name for the scope. 
66 | """ 67 | super(LARSOptimizer, self).__init__(False, name) 68 | 69 | self.learning_rate = learning_rate 70 | self.momentum = momentum 71 | self.weight_decay = weight_decay 72 | self.use_nesterov = use_nesterov 73 | self.classic_momentum = classic_momentum 74 | self.eeta = eeta 75 | self.exclude_from_weight_decay = exclude_from_weight_decay 76 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the 77 | # arg is None. 78 | if exclude_from_layer_adaptation: 79 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation 80 | else: 81 | self.exclude_from_layer_adaptation = exclude_from_weight_decay 82 | 83 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 84 | if global_step is None: 85 | global_step = tf.train.get_or_create_global_step() 86 | new_global_step = global_step + 1 87 | 88 | assignments = [] 89 | for (grad, param) in grads_and_vars: 90 | if grad is None or param is None: 91 | continue 92 | 93 | param_name = param.op.name 94 | 95 | v = tf.get_variable( 96 | name=param_name + "/Momentum", 97 | shape=param.shape.as_list(), 98 | dtype=tf.float32, 99 | trainable=False, 100 | initializer=tf.zeros_initializer()) 101 | 102 | if self._use_weight_decay(param_name): 103 | grad += self.weight_decay * param 104 | 105 | if self.classic_momentum: 106 | trust_ratio = 1.0 107 | if self._do_layer_adaptation(param_name): 108 | w_norm = tf.norm(param, ord=2) 109 | g_norm = tf.norm(grad, ord=2) 110 | trust_ratio = tf.where( 111 | tf.greater(w_norm, 0), tf.where( 112 | tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 113 | 1.0), 114 | 1.0) 115 | scaled_lr = self.learning_rate * trust_ratio 116 | 117 | next_v = tf.multiply(self.momentum, v) + scaled_lr * grad 118 | if self.use_nesterov: 119 | update = tf.multiply(self.momentum, next_v) + scaled_lr * grad 120 | else: 121 | update = next_v 122 | next_param = param - update 123 | else: 124 | next_v = tf.multiply(self.momentum, v) + grad 125 | if self.use_nesterov: 126 | update = tf.multiply(self.momentum, next_v) + grad 127 | else: 128 | update = next_v 129 | 130 | trust_ratio = 1.0 131 | if self._do_layer_adaptation(param_name): 132 | w_norm = tf.norm(param, ord=2) 133 | v_norm = tf.norm(update, ord=2) 134 | trust_ratio = tf.where( 135 | tf.greater(w_norm, 0), tf.where( 136 | tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm), 137 | 1.0), 138 | 1.0) 139 | scaled_lr = trust_ratio * self.learning_rate 140 | next_param = param - scaled_lr * update 141 | 142 | assignments.extend( 143 | [param.assign(next_param), 144 | v.assign(next_v), 145 | global_step.assign(new_global_step)]) 146 | return tf.group(*assignments, name=name) 147 | 148 | def _use_weight_decay(self, param_name): 149 | """Whether to use L2 weight decay for `param_name`.""" 150 | if not self.weight_decay: 151 | return False 152 | if self.exclude_from_weight_decay: 153 | for r in self.exclude_from_weight_decay: 154 | if re.search(r, param_name) is not None: 155 | return False 156 | return True 157 | 158 | def _do_layer_adaptation(self, param_name): 159 | """Whether to do layer-wise learning rate adaptation for `param_name`.""" 160 | if self.exclude_from_layer_adaptation: 161 | for r in self.exclude_from_layer_adaptation: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | -------------------------------------------------------------------------------- /CIFAR10/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The 
SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Model specification for SimCLR.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import flags 23 | 24 | import data_util as data_util 25 | from lars_optimizer import LARSOptimizer 26 | import model_util as model_util 27 | import objective as obj_lib 28 | 29 | import tensorflow.compat.v1 as tf 30 | import tensorflow.compat.v2 as tf2 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | 35 | def build_model_fn(model, num_classes, num_train_examples): 36 | """Build model function.""" 37 | def model_fn(features, labels, mode, params=None): 38 | """Build model and optimizer.""" 39 | is_training = mode == tf.estimator.ModeKeys.TRAIN 40 | 41 | # Check training mode. 42 | if FLAGS.train_mode == 'pretrain': 43 | num_transforms = 2 44 | if FLAGS.fine_tune_after_block > -1: 45 | raise ValueError('Does not support layer freezing during pretraining,' 46 | 'should set fine_tune_after_block<=-1 for safety.') 47 | elif FLAGS.train_mode == 'finetune': 48 | num_transforms = 1 49 | else: 50 | raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode)) 51 | 52 | # Split channels, and optionally apply extra batched augmentation. 53 | features_list = tf.split( 54 | features, num_or_size_splits=num_transforms, axis=-1) 55 | if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain': 56 | features_list = data_util.batch_random_blur( 57 | features_list, FLAGS.image_size, FLAGS.image_size) 58 | features = tf.concat(features_list, 0) # (num_transforms * bsz, h, w, c) 59 | 60 | # Base network forward pass. 61 | with tf.variable_scope('base_model'): 62 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4: 63 | # Finetune just supervised (linear) head will not update BN stats. 64 | model_train_mode = False 65 | else: 66 | # Pretrain or finetuen anything else will update BN stats. 67 | model_train_mode = is_training 68 | hiddens = model(features, is_training=model_train_mode) 69 | 70 | # Add head and loss. 
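    # In pretrain mode, the contrastive loss is computed on the projection-head
    # outputs while the inverse predictive loss is computed directly on the
    # backbone representation, with its weight controlled by
    # FLAGS.inv_pred_coeff; in finetune mode only the supervised head's
    # classification loss is added. The total training loss (including weight
    # decay) is read below via tf.losses.get_total_loss().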
71 | if FLAGS.train_mode == 'pretrain': 72 | tpu_context = params['context'] if 'context' in params else None 73 | hiddens_proj = model_util.projection_head(hiddens, is_training) 74 | contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss( 75 | hiddens_proj, 76 | hidden_norm=FLAGS.hidden_norm, 77 | temperature=FLAGS.temperature, 78 | tpu_context=tpu_context if is_training else None) 79 | # note the inv_pred_loss is performed on the representation, not on the projection head 80 | inv_pred_loss = obj_lib.add_inv_pred_loss( 81 | hiddens, 82 | inv_pred_coeff=FLAGS.inv_pred_coeff) 83 | logits_sup = tf.zeros([params['batch_size'], num_classes]) 84 | else: 85 | contrast_loss = tf.zeros([]) 86 | inv_pred_loss = tf.zeros([]) 87 | logits_con = tf.zeros([params['batch_size'], 10]) 88 | labels_con = tf.zeros([params['batch_size'], 10]) 89 | logits_sup = model_util.supervised_head( 90 | hiddens, num_classes, is_training) 91 | obj_lib.add_supervised_loss( 92 | labels=labels['labels'], 93 | logits=logits_sup, 94 | weights=labels['mask']) 95 | 96 | # Add weight decay to loss, for non-LARS optimizers. 97 | model_util.add_weight_decay(adjust_per_optimizer=True) 98 | loss = tf.losses.get_total_loss() 99 | 100 | if FLAGS.train_mode == 'pretrain': 101 | variables_to_train = tf.trainable_variables() 102 | else: 103 | collection_prefix = 'trainable_variables_inblock_' 104 | variables_to_train = [] 105 | for j in range(FLAGS.fine_tune_after_block + 1, 6): 106 | variables_to_train += tf.get_collection(collection_prefix + str(j)) 107 | assert variables_to_train, 'variables_to_train shouldn\'t be empty!' 108 | 109 | tf.logging.info('===============Variables to train (begin)===============') 110 | tf.logging.info(variables_to_train) 111 | tf.logging.info('================Variables to train (end)================') 112 | 113 | learning_rate = model_util.learning_rate_schedule( 114 | FLAGS.learning_rate, num_train_examples) 115 | 116 | if is_training: 117 | if FLAGS.train_summary_steps > 0: 118 | # Compute stats for the summary. 119 | prob_con = tf.nn.softmax(logits_con) 120 | entropy_con = - tf.reduce_mean( 121 | tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1)) 122 | 123 | summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir) 124 | # TODO(iamtingchen): remove this control_dependencies in the future. 
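        # The tf2 summary writer is created inside a tf1 graph, so its init op
        # must run (via this control dependency) before any scalar is written;
        # record_if(should_record) then limits writes to every
        # FLAGS.train_summary_steps global steps.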
125 | with tf.control_dependencies([summary_writer.init()]): 126 | with summary_writer.as_default(): 127 | should_record = tf.math.equal( 128 | tf.math.floormod(tf.train.get_global_step(), 129 | FLAGS.train_summary_steps), 0) 130 | with tf2.summary.record_if(should_record): 131 | contrast_acc = tf.equal( 132 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1)) 133 | contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32)) 134 | label_acc = tf.equal( 135 | tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1)) 136 | label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32)) 137 | tf2.summary.scalar( 138 | 'train_contrast_loss', 139 | contrast_loss, 140 | step=tf.train.get_global_step()) 141 | tf2.summary.scalar( 142 | 'train_inv_pred_loss', 143 | inv_pred_loss, 144 | step=tf.train.get_global_step()) 145 | tf2.summary.scalar( 146 | 'train_contrast_acc', 147 | contrast_acc, 148 | step=tf.train.get_global_step()) 149 | tf2.summary.scalar( 150 | 'train_label_accuracy', 151 | label_acc, 152 | step=tf.train.get_global_step()) 153 | tf2.summary.scalar( 154 | 'contrast_entropy', 155 | entropy_con, 156 | step=tf.train.get_global_step()) 157 | tf2.summary.scalar( 158 | 'learning_rate', learning_rate, 159 | step=tf.train.get_global_step()) 160 | tf2.summary.scalar( 161 | 'input_mean', 162 | tf.reduce_mean(features), 163 | step=tf.train.get_global_step()) 164 | tf2.summary.scalar( 165 | 'input_max', 166 | tf.reduce_max(features), 167 | step=tf.train.get_global_step()) 168 | tf2.summary.scalar( 169 | 'input_min', 170 | tf.reduce_min(features), 171 | step=tf.train.get_global_step()) 172 | tf2.summary.scalar( 173 | 'num_labels', 174 | tf.reduce_mean(tf.reduce_sum(labels['labels'], -1)), 175 | step=tf.train.get_global_step()) 176 | 177 | if FLAGS.optimizer == 'momentum': 178 | optimizer = tf.train.MomentumOptimizer( 179 | learning_rate, FLAGS.momentum, use_nesterov=True) 180 | elif FLAGS.optimizer == 'adam': 181 | optimizer = tf.train.AdamOptimizer( 182 | learning_rate) 183 | elif FLAGS.optimizer == 'lars': 184 | optimizer = LARSOptimizer( 185 | learning_rate, 186 | momentum=FLAGS.momentum, 187 | weight_decay=FLAGS.weight_decay, 188 | exclude_from_weight_decay=['batch_normalization', 'bias']) 189 | else: 190 | raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer)) 191 | 192 | if FLAGS.use_tpu: 193 | optimizer = tf.tpu.CrossShardOptimizer(optimizer) 194 | 195 | control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 196 | if FLAGS.train_summary_steps > 0: 197 | control_deps.extend(tf.summary.all_v2_summary_ops()) 198 | with tf.control_dependencies(control_deps): 199 | train_op = optimizer.minimize( 200 | loss, global_step=tf.train.get_or_create_global_step(), 201 | var_list=variables_to_train) 202 | 203 | if FLAGS.checkpoint: 204 | def scaffold_fn(): 205 | """Scaffold function to restore non-logits vars from checkpoint.""" 206 | tf.train.init_from_checkpoint( 207 | FLAGS.checkpoint, 208 | {v.op.name: v.op.name 209 | for v in tf.global_variables(FLAGS.variable_schema)}) 210 | 211 | if FLAGS.zero_init_logits_layer: 212 | # Init op that initializes output layer parameters to zeros. 
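          # The global initializer first restores every variable mapped by
          # init_from_checkpoint above; the init op then overwrites the
          # 'head_supervised' variables with zeros so the finetuned logits
          # layer starts from scratch.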
213 | output_layer_parameters = [ 214 | var for var in tf.trainable_variables() if var.name.startswith( 215 | 'head_supervised')] 216 | tf.logging.info('Initializing output layer parameters %s to zero', 217 | [x.op.name for x in output_layer_parameters]) 218 | with tf.control_dependencies([tf.global_variables_initializer()]): 219 | init_op = tf.group([ 220 | tf.assign(x, tf.zeros_like(x)) 221 | for x in output_layer_parameters]) 222 | return tf.train.Scaffold(init_op=init_op) 223 | else: 224 | return tf.train.Scaffold() 225 | else: 226 | scaffold_fn = None 227 | 228 | return tf.estimator.tpu.TPUEstimatorSpec( 229 | mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn) 230 | else: 231 | 232 | def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask, 233 | **kws): 234 | """Inner metric function.""" 235 | metrics = {k: tf.metrics.mean(v, weights=mask) 236 | for k, v in kws.items()} 237 | metrics['label_top_1_accuracy'] = tf.metrics.accuracy( 238 | tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1), 239 | weights=mask) 240 | metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k( 241 | tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask) 242 | metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy( 243 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1), 244 | weights=mask) 245 | metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k( 246 | tf.argmax(labels_con, 1), logits_con, k=5, weights=mask) 247 | return metrics 248 | 249 | metrics = { 250 | 'logits_sup': logits_sup, 251 | 'labels_sup': labels['labels'], 252 | 'logits_con': logits_con, 253 | 'labels_con': labels_con, 254 | 'mask': labels['mask'], 255 | 'contrast_loss': tf.fill((params['batch_size'],), contrast_loss), 256 | 'inv_pred_loss': tf.fill((params['batch_size'],), inv_pred_loss), 257 | 'regularization_loss': tf.fill((params['batch_size'],), 258 | tf.losses.get_regularization_loss()), 259 | } 260 | 261 | return tf.estimator.tpu.TPUEstimatorSpec( 262 | mode=mode, 263 | loss=loss, 264 | eval_metrics=(metric_fn, metrics), 265 | scaffold_fn=None) 266 | 267 | return model_fn 268 | -------------------------------------------------------------------------------- /CIFAR10/model_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Network architectures related functions used in SimCLR.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import flags 23 | 24 | import resnet 25 | 26 | import tensorflow.compat.v1 as tf 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | def add_weight_decay(adjust_per_optimizer=True): 32 | """Compute weight decay from flags.""" 33 | if adjust_per_optimizer and 'lars' in FLAGS.optimizer: 34 | # Weight decay are taking care of by optimizer for these cases. 35 | return 36 | 37 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables() 38 | if 'batch_normalization' not in v.name] 39 | tf.losses.add_loss( 40 | FLAGS.weight_decay * tf.add_n(l2_losses), 41 | tf.GraphKeys.REGULARIZATION_LOSSES) 42 | 43 | 44 | def get_train_steps(num_examples): 45 | """Determine the number of training steps.""" 46 | return FLAGS.train_steps or ( 47 | num_examples * FLAGS.train_epochs // FLAGS.train_batch_size + 1) 48 | 49 | 50 | def learning_rate_schedule(base_learning_rate, num_examples): 51 | """Build learning rate schedule.""" 52 | global_step = tf.train.get_or_create_global_step() 53 | warmup_steps = int(round( 54 | FLAGS.warmup_epochs * num_examples // FLAGS.train_batch_size)) 55 | scaled_lr = base_learning_rate * FLAGS.train_batch_size / 256. 56 | learning_rate = (tf.to_float(global_step) / int(warmup_steps) * scaled_lr 57 | if warmup_steps else scaled_lr) 58 | 59 | # Cosine decay learning rate schedule 60 | total_steps = get_train_steps(num_examples) 61 | learning_rate = tf.where( 62 | global_step < warmup_steps, learning_rate, 63 | tf.train.cosine_decay( 64 | scaled_lr, 65 | global_step - warmup_steps, 66 | total_steps - warmup_steps)) 67 | 68 | return learning_rate 69 | 70 | 71 | def linear_layer(x, 72 | is_training, 73 | num_classes, 74 | use_bias=True, 75 | use_bn=False, 76 | name='linear_layer'): 77 | """Linear head for linear evaluation. 78 | 79 | Args: 80 | x: hidden state tensor of shape (bsz, dim). 81 | is_training: boolean indicator for training or test. 82 | num_classes: number of classes. 83 | use_bias: whether or not to use bias. 84 | use_bn: whether or not to use BN for output units. 85 | name: the name for variable scope. 
86 | 87 | Returns: 88 | logits of shape (bsz, num_classes) 89 | """ 90 | assert x.shape.ndims == 2, x.shape 91 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 92 | x = tf.layers.dense( 93 | inputs=x, 94 | units=num_classes, 95 | use_bias=use_bias and not use_bn, 96 | kernel_initializer=tf.random_normal_initializer(stddev=.01)) 97 | if use_bn: 98 | x = resnet.batch_norm_relu(x, is_training, relu=False, center=use_bias) 99 | x = tf.identity(x, '%s_out' % name) 100 | return x 101 | 102 | 103 | def projection_head(hiddens, is_training, name='head_contrastive'): 104 | """Head for projecting hiddens fo contrastive loss.""" 105 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 106 | if FLAGS.head_proj_mode == 'none': 107 | pass # directly use the output hiddens as hiddens 108 | elif FLAGS.head_proj_mode == 'linear': 109 | hiddens = linear_layer( 110 | hiddens, is_training, FLAGS.head_proj_dim, 111 | use_bias=False, use_bn=True, name='l_0') 112 | elif FLAGS.head_proj_mode == 'nonlinear': 113 | hiddens = linear_layer( 114 | hiddens, is_training, hiddens.shape[-1], 115 | use_bias=True, use_bn=True, name='nl_0') 116 | for j in range(1, FLAGS.num_nlh_layers + 1): 117 | hiddens = tf.nn.relu(hiddens) 118 | hiddens = linear_layer( 119 | hiddens, is_training, FLAGS.head_proj_dim, 120 | use_bias=False, use_bn=True, name='nl_%d'%j) 121 | else: 122 | raise ValueError('Unknown head projection mode {}'.format( 123 | FLAGS.head_proj_mode)) 124 | return hiddens 125 | 126 | 127 | def supervised_head(hiddens, num_classes, is_training, name='head_supervised'): 128 | """Add supervised head & also add its variables to inblock collection.""" 129 | with tf.variable_scope(name): 130 | logits = linear_layer(hiddens, is_training, num_classes) 131 | for var in tf.trainable_variables(): 132 | if var.name.startswith(name): 133 | tf.add_to_collection('trainable_variables_inblock_5', var) 134 | return logits 135 | -------------------------------------------------------------------------------- /CIFAR10/objective.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Contrastive loss functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import flags 23 | 24 | import tensorflow.compat.v1 as tf 25 | 26 | from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | LARGE_NUM = 1e9 31 | 32 | 33 | def add_supervised_loss(labels, logits, weights, **kwargs): 34 | """Compute loss for model and add it to loss collection.""" 35 | return tf.losses.softmax_cross_entropy(labels, logits, weights, **kwargs) 36 | 37 | def add_inv_pred_loss(repres, inv_pred_coeff=0.0, weights=1.0): 38 | """Compute inv predictive loss for model, not sure if it is compatible with TPU. 39 | 40 | Args: 41 | repres: representation vector (`Tensor`) of shape (bsz, dim). 42 | inv_pred_coeff: coeff of inverse predictive loss. 43 | 44 | Returns: 45 | A loss scalar. 46 | """ 47 | # Get representation1 and representation2. 48 | representation1, representation2 = tf.split(repres, 2, 0) 49 | inv_pred_loss = tf.losses.mean_squared_error(representation1, representation2, weights=weights*inv_pred_coeff) 50 | 51 | return inv_pred_loss 52 | 53 | def tpu_cross_replica_concat(tensor, tpu_context=None): 54 | """Reduce a concatenation of the `tensor` across TPU cores. 55 | 56 | Args: 57 | tensor: tensor to concatenate. 58 | tpu_context: A `TPUContext`. If not set, CPU execution is assumed. 59 | 60 | Returns: 61 | Tensor of the same rank as `tensor` with first dimension `num_replicas` 62 | times larger. 63 | """ 64 | if tpu_context is None or tpu_context.num_replicas <= 1: 65 | return tensor 66 | 67 | num_replicas = tpu_context.num_replicas 68 | 69 | with tf.name_scope('tpu_cross_replica_concat'): 70 | # This creates a tensor that is like the input tensor but has an added 71 | # replica dimension as the outermost dimension. On each replica it will 72 | # contain the local values and zeros for all other values that need to be 73 | # fetched from other replicas. 74 | ext_tensor = tf.scatter_nd( 75 | indices=[[xla.replica_id()]], 76 | updates=[tensor], 77 | shape=[num_replicas] + tensor.shape.as_list()) 78 | 79 | # As every value is only present on one replica and 0 in all others, adding 80 | # them all together will result in the full tensor on all replicas. 81 | ext_tensor = tf.tpu.cross_replica_sum(ext_tensor) 82 | 83 | # Flatten the replica dimension. 84 | # The first dimension size will be: tensor.shape[0] * num_replicas 85 | # Using [-1] trick to support also scalar input. 86 | return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:]) 87 | 88 | 89 | def add_contrastive_loss(hidden, 90 | hidden_norm=True, 91 | temperature=1.0, 92 | tpu_context=None, 93 | weights=1.0): 94 | """Compute loss for model. 95 | 96 | Args: 97 | hidden: hidden vector (`Tensor`) of shape (bsz, dim). 98 | hidden_norm: whether or not to use normalization on the hidden vector. 99 | temperature: a `floating` number for temperature scaling. 100 | tpu_context: context information for tpu. 101 | weights: a weighting number or vector. 102 | 103 | Returns: 104 | A loss scalar. 105 | The logits for contrastive prediction task. 106 | The labels for contrastive prediction task. 107 | """ 108 | # Get (normalized) hidden1 and hidden2. 
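  # NT-Xent mechanics (SimCLR): `hidden` stacks the projections of the two
  # augmented views of each image, so the split below yields one half per view.
  # `labels` points every view-1 example at its positive (the same image in
  # view 2) among the 2*batch candidates, and `masks` is used to subtract
  # LARGE_NUM from the view1-view1 / view2-view2 logits so an example never
  # counts itself as a candidate. For example, with a per-replica batch of 4
  # and no TPU context: hidden is (8, d), hidden1/hidden2 are (4, d), labels
  # is (4, 8), and logits_ab/logits_aa are each (4, 4).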
109 | if hidden_norm: 110 | hidden = tf.math.l2_normalize(hidden, -1) 111 | hidden1, hidden2 = tf.split(hidden, 2, 0) 112 | batch_size = tf.shape(hidden1)[0] 113 | 114 | # Gather hidden1/hidden2 across replicas and create local labels. 115 | if tpu_context is not None: 116 | hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context) 117 | hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context) 118 | enlarged_batch_size = tf.shape(hidden1_large)[0] 119 | # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id. 120 | replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32) 121 | labels_idx = tf.range(batch_size) + replica_id * batch_size 122 | labels = tf.one_hot(labels_idx, enlarged_batch_size * 2) 123 | masks = tf.one_hot(labels_idx, enlarged_batch_size) 124 | else: 125 | hidden1_large = hidden1 126 | hidden2_large = hidden2 127 | labels = tf.one_hot(tf.range(batch_size), batch_size * 2) # (bz, bz*2) 128 | masks = tf.one_hot(tf.range(batch_size), batch_size) # (bz, bz) 129 | 130 | logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature 131 | logits_aa = logits_aa - masks * LARGE_NUM # (bz, bz) 132 | logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature 133 | logits_bb = logits_bb - masks * LARGE_NUM # (bz, bz) 134 | logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature # (bz, bz) 135 | logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature # (bz, bz) 136 | 137 | loss_a = tf.losses.softmax_cross_entropy( 138 | labels, tf.concat([logits_ab, logits_aa], 1), weights=weights) 139 | loss_b = tf.losses.softmax_cross_entropy( 140 | labels, tf.concat([logits_ba, logits_bb], 1), weights=weights) 141 | loss = loss_a + loss_b 142 | 143 | return loss, logits_ab, labels 144 | -------------------------------------------------------------------------------- /CIFAR10/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow-datasets==2.1.0 3 | tensorflow-hub==0.7.0 4 | -------------------------------------------------------------------------------- /CIFAR10/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """The main training pipeline.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import json 23 | import math 24 | import os 25 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 26 | from absl import app 27 | from absl import flags 28 | 29 | import resnet 30 | import data as data_lib 31 | import model as model_lib 32 | import model_util as model_util 33 | 34 | import tensorflow.compat.v1 as tf 35 | import tensorflow_datasets as tfds 36 | import tensorflow_hub as hub 37 | 38 | tf.logging.set_verbosity(tf.logging.ERROR) 39 | 40 | FLAGS = flags.FLAGS 41 | 42 | 43 | flags.DEFINE_float( 44 | 'learning_rate', 0.3, 45 | 'Initial learning rate per batch size of 256.') 46 | 47 | flags.DEFINE_float( 48 | 'warmup_epochs', 10, 49 | 'Number of epochs of warmup.') 50 | 51 | flags.DEFINE_float( 52 | 'weight_decay', 1e-6, 53 | 'Amount of weight decay to use.') 54 | 55 | flags.DEFINE_float( 56 | 'batch_norm_decay', 0.9, 57 | 'Batch norm decay parameter.') 58 | 59 | flags.DEFINE_integer( 60 | 'train_batch_size', 512, 61 | 'Batch size for training.') 62 | 63 | flags.DEFINE_string( 64 | 'train_split', 'train', 65 | 'Split for training.') 66 | 67 | flags.DEFINE_integer( 68 | 'train_epochs', 100, 69 | 'Number of epochs to train for.') 70 | 71 | flags.DEFINE_integer( 72 | 'train_steps', 0, 73 | 'Number of steps to train for. If provided, overrides train_epochs.') 74 | 75 | flags.DEFINE_integer( 76 | 'eval_batch_size', 256, 77 | 'Batch size for eval.') 78 | 79 | flags.DEFINE_integer( 80 | 'train_summary_steps', 100, 81 | 'Steps before saving training summaries. If 0, will not save.') 82 | 83 | flags.DEFINE_integer( 84 | 'checkpoint_epochs', 1, 85 | 'Number of epochs between checkpoints/summaries.') 86 | 87 | flags.DEFINE_integer( 88 | 'checkpoint_steps', 0, 89 | 'Number of steps between checkpoints/summaries. If provided, overrides ' 90 | 'checkpoint_epochs.') 91 | 92 | flags.DEFINE_string( 93 | 'eval_split', 'validation', 94 | 'Split for evaluation.') 95 | 96 | flags.DEFINE_string( 97 | 'dataset', 'imagenet2012', 98 | 'Name of a dataset.') 99 | 100 | flags.DEFINE_bool( 101 | 'cache_dataset', False, 102 | 'Whether to cache the entire dataset in memory. If the dataset is ' 103 | 'ImageNet, this is a very bad idea, but for smaller datasets it can ' 104 | 'improve performance.') 105 | 106 | flags.DEFINE_enum( 107 | 'mode', 'train', ['train', 'eval', 'train_then_eval'], 108 | 'Whether to perform training or evaluation.') 109 | 110 | flags.DEFINE_enum( 111 | 'train_mode', 'pretrain', ['pretrain', 'finetune'], 112 | 'The train mode controls different objectives and trainable components.') 113 | 114 | flags.DEFINE_string( 115 | 'checkpoint', None, 116 | 'Loading from the given checkpoint for continued training or fine-tuning.') 117 | 118 | flags.DEFINE_string( 119 | 'variable_schema', '?!global_step', 120 | 'This defines whether some variable from the checkpoint should be loaded.') 121 | 122 | flags.DEFINE_bool( 123 | 'zero_init_logits_layer', False, 124 | 'If True, zero initialize layers after avg_pool for supervised learning.') 125 | 126 | flags.DEFINE_integer( 127 | 'fine_tune_after_block', -1, 128 | 'The layers after which block that we will fine-tune. -1 means fine-tuning ' 129 | 'everything. 0 means fine-tuning after stem block. 
4 means fine-tuning ' 130 | 'just the linera head.') 131 | 132 | flags.DEFINE_string( 133 | 'master', None, 134 | 'Address/name of the TensorFlow master to use. By default, use an ' 135 | 'in-process master.') 136 | 137 | flags.DEFINE_string( 138 | 'model_dir', None, 139 | 'Model directory for training.') 140 | 141 | flags.DEFINE_string( 142 | 'data_dir', None, 143 | 'Directory where dataset is stored.') 144 | 145 | flags.DEFINE_bool( 146 | 'use_tpu', True, 147 | 'Whether to run on TPU.') 148 | 149 | tf.flags.DEFINE_string( 150 | 'tpu_name', None, 151 | 'The Cloud TPU to use for training. This should be either the name ' 152 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 ' 153 | 'url.') 154 | 155 | tf.flags.DEFINE_string( 156 | 'tpu_zone', None, 157 | '[Optional] GCE zone where the Cloud TPU is located in. If not ' 158 | 'specified, we will attempt to automatically detect the GCE project from ' 159 | 'metadata.') 160 | 161 | tf.flags.DEFINE_string( 162 | 'gcp_project', None, 163 | '[Optional] Project name for the Cloud TPU-enabled project. If not ' 164 | 'specified, we will attempt to automatically detect the GCE project from ' 165 | 'metadata.') 166 | 167 | flags.DEFINE_enum( 168 | 'optimizer', 'lars', ['momentum', 'adam', 'lars'], 169 | 'Optimizer to use.') 170 | 171 | flags.DEFINE_float( 172 | 'momentum', 0.9, 173 | 'Momentum parameter.') 174 | 175 | flags.DEFINE_string( 176 | 'eval_name', None, 177 | 'Name for eval.') 178 | 179 | flags.DEFINE_integer( 180 | 'keep_checkpoint_max', 5, 181 | 'Maximum number of checkpoints to keep.') 182 | 183 | flags.DEFINE_integer( 184 | 'keep_hub_module_max', 1, 185 | 'Maximum number of Hub modules to keep.') 186 | 187 | flags.DEFINE_float( 188 | 'temperature', 0.1, 189 | 'Temperature parameter for contrastive loss.') 190 | 191 | flags.DEFINE_boolean( 192 | 'hidden_norm', True, 193 | 'Whether normalize hidden representation in contrastive loss.') 194 | 195 | flags.DEFINE_enum( 196 | 'head_proj_mode', 'nonlinear', ['none', 'linear', 'nonlinear'], 197 | 'How the head projection is done.') 198 | 199 | flags.DEFINE_integer( 200 | 'head_proj_dim', 128, 201 | 'Number of head projection dimension.') 202 | 203 | flags.DEFINE_integer( 204 | 'num_nlh_layers', 1, 205 | 'Number of non-linear head layers.') 206 | 207 | flags.DEFINE_boolean( 208 | 'global_bn', True, 209 | 'Whether to aggregate BN statistics across distributed cores.') 210 | 211 | flags.DEFINE_integer( 212 | 'width_multiplier', 1, 213 | 'Multiplier to change width of network.') 214 | 215 | flags.DEFINE_integer( 216 | 'resnet_depth', 50, 217 | 'Depth of ResNet.') 218 | 219 | flags.DEFINE_integer( 220 | 'image_size', 224, 221 | 'Input image size.') 222 | 223 | flags.DEFINE_float( 224 | 'color_jitter_strength', 1.0, 225 | 'The strength of color jittering.') 226 | 227 | flags.DEFINE_float( 228 | 'inv_pred_coeff', 0.05, 229 | 'Coeff for inverse predictive coding.') 230 | 231 | flags.DEFINE_boolean( 232 | 'use_blur', True, 233 | 'Whether or not to use Gaussian blur for augmentation during pretraining.') 234 | 235 | 236 | def build_hub_module(model, num_classes, global_step, checkpoint_path): 237 | """Create TF-Hub module.""" 238 | 239 | tags_and_args = [ 240 | # The default graph is built with batch_norm, dropout etc. in inference 241 | # mode. This graph version is good for inference, not training. 242 | ([], {'is_training': False}), 243 | # A separate "train" graph builds batch_norm, dropout etc. in training 244 | # mode. 
245 | (['train'], {'is_training': True}), 246 | ] 247 | 248 | def module_fn(is_training): 249 | """Function that builds TF-Hub module.""" 250 | endpoints = {} 251 | inputs = tf.placeholder( 252 | tf.float32, [None, FLAGS.image_size, FLAGS.image_size, 3]) 253 | with tf.variable_scope('base_model', reuse=tf.AUTO_REUSE): 254 | hiddens = model(inputs, is_training) 255 | for v in ['initial_conv', 'initial_max_pool', 'block_group1', 256 | 'block_group2', 'block_group3', 'block_group4', 257 | 'final_avg_pool']: 258 | endpoints[v] = tf.get_default_graph().get_tensor_by_name( 259 | 'base_model/{}:0'.format(v)) 260 | if FLAGS.train_mode == 'pretrain': 261 | hiddens_proj = model_util.projection_head(hiddens, is_training) 262 | endpoints['proj_head_input'] = hiddens 263 | endpoints['proj_head_output'] = hiddens_proj 264 | else: 265 | logits_sup = model_util.supervised_head( 266 | hiddens, num_classes, is_training) 267 | endpoints['logits_sup'] = logits_sup 268 | hub.add_signature(inputs=dict(images=inputs), 269 | outputs=dict(endpoints, default=hiddens)) 270 | 271 | # Drop the non-supported non-standard graph collection. 272 | drop_collections = ['trainable_variables_inblock_%d'%d for d in range(6)] 273 | spec = hub.create_module_spec(module_fn, tags_and_args, drop_collections) 274 | hub_export_dir = os.path.join(FLAGS.model_dir, 'hub') 275 | checkpoint_export_dir = os.path.join(hub_export_dir, str(global_step)) 276 | if tf.io.gfile.exists(checkpoint_export_dir): 277 | # Do not save if checkpoint already saved. 278 | tf.io.gfile.rmtree(checkpoint_export_dir) 279 | spec.export( 280 | checkpoint_export_dir, 281 | checkpoint_path=checkpoint_path, 282 | name_transform_fn=None) 283 | 284 | if FLAGS.keep_hub_module_max > 0: 285 | # Delete old exported Hub modules. 286 | exported_steps = [] 287 | for subdir in tf.io.gfile.listdir(hub_export_dir): 288 | if not subdir.isdigit(): 289 | continue 290 | exported_steps.append(int(subdir)) 291 | exported_steps.sort() 292 | for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]: 293 | tf.io.gfile.rmtree(os.path.join(hub_export_dir, str(step_to_delete))) 294 | 295 | 296 | def perform_evaluation(estimator, input_fn, eval_steps, model, num_classes, 297 | checkpoint_path=None): 298 | """Perform evaluation. 299 | 300 | Args: 301 | estimator: TPUEstimator instance. 302 | input_fn: Input function for estimator. 303 | eval_steps: Number of steps for evaluation. 304 | model: Instance of transfer_learning.models.Model. 305 | num_classes: Number of classes to build model for. 306 | checkpoint_path: Path of checkpoint to evaluate. 307 | 308 | Returns: 309 | result: A Dict of metrics and their values. 310 | """ 311 | if not checkpoint_path: 312 | checkpoint_path = estimator.latest_checkpoint() 313 | result = estimator.evaluate( 314 | input_fn, eval_steps, checkpoint_path=checkpoint_path, 315 | name=FLAGS.eval_name) 316 | 317 | # Record results as JSON. 318 | result_json_path = os.path.join(FLAGS.model_dir, 'result.json') 319 | with tf.io.gfile.GFile(result_json_path, 'w') as f: 320 | json.dump({k: float(v) for k, v in result.items()}, f) 321 | result_json_path = os.path.join( 322 | FLAGS.model_dir, 'result_%d.json'%result['global_step']) 323 | with tf.io.gfile.GFile(result_json_path, 'w') as f: 324 | json.dump({k: float(v) for k, v in result.items()}, f) 325 | flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json') 326 | with tf.io.gfile.GFile(flag_json_path, 'w') as f: 327 | json.dump(FLAGS.flag_values_dict(), f) 328 | 329 | # Save Hub module. 
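  # The exported module bundles the just-evaluated checkpoint, so the encoder
  # can later be reloaded through tensorflow_hub without the Estimator pipeline
  # (e.g. for linear evaluation on another dataset).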
330 | build_hub_module(model, num_classes, 331 | global_step=result['global_step'], 332 | checkpoint_path=checkpoint_path) 333 | 334 | return result 335 | 336 | 337 | def main(argv): 338 | if len(argv) > 1: 339 | raise app.UsageError('Too many command-line arguments.') 340 | 341 | # Enable training summary. 342 | if FLAGS.train_summary_steps > 0: 343 | tf.config.set_soft_device_placement(True) 344 | 345 | 346 | builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir) 347 | builder.download_and_prepare() 348 | num_train_examples = builder.info.splits[FLAGS.train_split].num_examples 349 | num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples 350 | num_classes = builder.info.features['label'].num_classes 351 | 352 | train_steps = model_util.get_train_steps(num_train_examples) 353 | eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size)) 354 | epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size)) 355 | 356 | resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay 357 | model = resnet.resnet_v1( 358 | resnet_depth=FLAGS.resnet_depth, 359 | width_multiplier=FLAGS.width_multiplier, 360 | cifar_stem=FLAGS.image_size <= 32) 361 | 362 | checkpoint_steps = ( 363 | FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps)) 364 | 365 | cluster = None 366 | if FLAGS.use_tpu and FLAGS.master is None: 367 | if FLAGS.tpu_name: 368 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver( 369 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 370 | else: 371 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver() 372 | tf.config.experimental_connect_to_cluster(cluster) 373 | tf.tpu.experimental.initialize_tpu_system(cluster) 374 | 375 | sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1#SLICED 376 | run_config = tf.estimator.tpu.RunConfig( 377 | tpu_config=tf.estimator.tpu.TPUConfig( 378 | iterations_per_loop=checkpoint_steps, 379 | eval_training_input_configuration=sliced_eval_mode), 380 | model_dir=FLAGS.model_dir, 381 | save_summary_steps=checkpoint_steps, 382 | save_checkpoints_steps=checkpoint_steps, 383 | keep_checkpoint_max=FLAGS.keep_checkpoint_max, 384 | master=FLAGS.master, 385 | cluster=cluster) 386 | estimator = tf.estimator.tpu.TPUEstimator( 387 | model_lib.build_model_fn(model, num_classes, num_train_examples), 388 | config=run_config, 389 | train_batch_size=FLAGS.train_batch_size, 390 | eval_batch_size=FLAGS.eval_batch_size, 391 | use_tpu=FLAGS.use_tpu) 392 | 393 | if FLAGS.mode == 'eval': 394 | for ckpt in tf.train.checkpoints_iterator( 395 | run_config.model_dir, min_interval_secs=15): 396 | try: 397 | result = perform_evaluation( 398 | estimator=estimator, 399 | input_fn=data_lib.build_input_fn(builder, False), 400 | eval_steps=eval_steps, 401 | model=model, 402 | num_classes=num_classes, 403 | checkpoint_path=ckpt) 404 | except tf.errors.NotFoundError: 405 | continue 406 | if result['global_step'] >= train_steps: 407 | return 408 | else: 409 | estimator.train( 410 | data_lib.build_input_fn(builder, True), max_steps=train_steps) 411 | if FLAGS.mode == 'train_then_eval': 412 | perform_evaluation( 413 | estimator=estimator, 414 | input_fn=data_lib.build_input_fn(builder, False), 415 | eval_steps=eval_steps, 416 | model=model, 417 | num_classes=num_classes) 418 | 419 | 420 | if __name__ == '__main__': 421 | tf.disable_eager_execution() # Disable eager mode when running with TF2. 
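  # Illustrative, untested invocation for single-machine CIFAR-10 pretraining;
  # the flag values and model_dir path below are placeholders, not a canonical
  # recipe from the authors:
  #   python run.py --mode=train --train_mode=pretrain --dataset=cifar10 \
  #     --eval_split=test --image_size=32 --train_batch_size=512 \
  #     --use_tpu=False --model_dir=/tmp/simclr_cifar10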
422 | app.run(main) 423 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yao-Hung Hubert Tsai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MSCOCO/README.md: -------------------------------------------------------------------------------- 1 | # MS COCO Experiments 2 | 3 | Download MSCOCO2017 dataset from [here](http://cocodataset.org/) 4 | 5 | Run ```main.ipynb``` to reproduce experiments. 6 | 7 | Run ```ResNet_baseline.ipynb``` to reproduce baseline results. -------------------------------------------------------------------------------- /MSCOCO/ResNet_baseline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, time" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "TRIAL_NAME='ResNet_baseline'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "CONF={\n", 28 | " 'niter':200,\n", 29 | " 'GPU':0,\n", 30 | " 'BS':128,\n", 31 | " 'test_BS':256,\n", 32 | " 'N_neg':3,\n", 33 | " 'name':TRIAL_NAME,\n", 34 | " 'tb_dir':os.path.join('./runs', TRIAL_NAME),\n", 35 | " 'nz':64,\n", 36 | " 'seed':10708,\n", 37 | " 'data_dir':'/DataSet/COCO',\n", 38 | " 'dataType':'train2017',\n", 39 | " 'valType':'val2017',\n", 40 | " 'LAMBDA':0.5,\n", 41 | " 'use_super':False,\n", 42 | " 'test_classes':80\n", 43 | "}" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=str(CONF['GPU'])" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# Clear any logs from previous runs\n", 62 | "time.sleep(2)\n", 63 | "import shutil\n", 64 | "shutil.rmtree(CONF['tb_dir'], ignore_errors=True)\n", 65 | "time.sleep(5)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "import 
texar.torch as tx\n", 75 | "import random\n", 76 | "import torch\n", 77 | "import torch.nn as nn\n", 78 | "import torch.nn.parallel\n", 79 | "from torch.nn import functional as F\n", 80 | "import torch.backends.cudnn as cudnn\n", 81 | "import torch.optim as optim\n", 82 | "import torch.utils.data\n", 83 | "import torchvision\n", 84 | "from torch.utils.tensorboard import SummaryWriter\n", 85 | "import torchvision.datasets as dset\n", 86 | "import torchvision.transforms as transforms\n", 87 | "import torchvision.utils as vutils\n", 88 | "import numpy as np\n", 89 | "from torch import autograd\n", 90 | "import multiprocessing\n", 91 | "from PIL import Image\n", 92 | "from sklearn import metrics" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "device = torch.device(\"cuda:0\" if True else \"cpu\")" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "writer = SummaryWriter(log_dir=CONF['tb_dir'])" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "random.seed(CONF['seed'])\n", 120 | "torch.manual_seed(CONF['seed'])\n", 121 | "np.random.seed(CONF['seed'])\n", 122 | "cudnn.benchmark = True" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "batch_size = CONF['BS']" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "T = transforms.Compose([\n", 141 | " transforms.RandomResizedCrop((256,256), scale=(0.3, 1.0), ratio=(0.75, 1.3333333333333333)),\n", 142 | " transforms.ColorJitter(brightness=.1, contrast=.05, saturation=.05, hue=.05),\n", 143 | " transforms.RandomHorizontalFlip(),\n", 144 | " transforms.ToTensor()\n", 145 | "])\n", 146 | "T_test = transforms.Compose([\n", 147 | " transforms.Resize((256,256)),\n", 148 | " transforms.ToTensor()\n", 149 | "])" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "from torchvision.datasets.vision import VisionDataset\n", 159 | "class CocoClassification(VisionDataset):\n", 160 | " \"\"\"`MS Coco Detection `_ Dataset.\n", 161 | "\n", 162 | " Args:\n", 163 | " root (string): Root directory where images are downloaded to.\n", 164 | " annFile (string): Path to json annotation file.\n", 165 | " transform (callable, optional): A function/transform that takes in an PIL image\n", 166 | " and returns a transformed version. 
E.g, ``transforms.ToTensor``\n", 167 | " target_transform (callable, optional): A function/transform that takes in the\n", 168 | " target and transforms it.\n", 169 | " transforms (callable, optional): A function/transform that takes input sample and its target as entry\n", 170 | " and returns a transformed version.\n", 171 | " \"\"\"\n", 172 | "\n", 173 | " def sample_class(self, k):\n", 174 | " if CONF['use_super']:\n", 175 | " self.classes = np.array([\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"])#np.arange(12)+1\n", 176 | " self.class_description = [\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"]\n", 177 | " return\n", 178 | " class_list = self.coco.getCatIds()\n", 179 | " self.classes = np.sort(np.random.choice(class_list, size=k, replace=False))\n", 180 | " self.class_description = self.coco.loadCats(self.classes)\n", 181 | " arr = []\n", 182 | " for catId in self.classes:\n", 183 | " arr+=self.coco.getImgIds(catIds=[catId])\n", 184 | " self.ids = sorted(list(set(arr)))\n", 185 | " \n", 186 | " def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):\n", 187 | " super(CocoClassification, self).__init__(root, transforms, transform, target_transform)\n", 188 | " from pycocotools.coco import COCO\n", 189 | " self.coco = COCO(annFile)\n", 190 | " self.sample_class(len(self.coco.getCatIds()))\n", 191 | "\n", 192 | " def __getitem__(self, index):\n", 193 | " \"\"\"\n", 194 | " Args:\n", 195 | " index (int): Index\n", 196 | "\n", 197 | " Returns:\n", 198 | " tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.\n", 199 | " \"\"\"\n", 200 | " coco = self.coco\n", 201 | " img_id = self.ids[index]\n", 202 | " ann_ids = coco.getAnnIds(imgIds=img_id)\n", 203 | " cat_ids = [ann['category_id'] for ann in coco.loadAnns(ann_ids)]\n", 204 | " target = coco.loadCats(cat_ids)\n", 205 | " if CONF['use_super']:\n", 206 | " target = np.array([x['supercategory'] for x in target])\n", 207 | " else:\n", 208 | " target = np.array([x['id'] for x in target if x['id'] in self.classes])\n", 209 | " targets = torch.FloatTensor([1 if (c in target) else 0 for c in self.classes])\n", 210 | " path = coco.loadImgs(img_id)[0]['file_name']\n", 211 | " img = Image.open(os.path.join(self.root, path)).convert('RGB')\n", 212 | " if self.transforms is not None:\n", 213 | " img, targets = self.transforms(img, targets)\n", 214 | "\n", 215 | " return img, targets\n", 216 | "\n", 217 | "\n", 218 | " def __len__(self):\n", 219 | " return len(self.ids)\n" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "\n", 229 | "clas_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['dataType']),\n", 230 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['dataType']),\n", 231 | " transform=T)\n", 232 | "val_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['valType']),\n", 233 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['valType']),\n", 234 | " transform=T_test)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "clas_set.sample_class(CONF['test_classes'])\n", 244 | 
"train_loader = torch.utils.data.DataLoader(clas_set, batch_size=CONF['BS'],shuffle=True, num_workers=8, pin_memory=True)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=CONF['test_BS'],shuffle=False, num_workers=8, pin_memory=False)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "class Flatten(nn.Module):\n", 263 | " def forward(self, x):\n", 264 | " x = x.view(x.size()[0], -1)\n", 265 | " return x" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "resnet18 = torchvision.models.resnet18(pretrained=True)\n", 275 | "modules=list(resnet18.children())[:-1]\n", 276 | "modules.append(Flatten())\n", 277 | "modules.append(nn.Linear(512, CONF['test_classes']))\n", 278 | "resnet = nn.Sequential(*modules)\n", 279 | "resnet.cuda()" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "opt = optim.SGD(resnet.parameters(), lr=0.02, momentum=0.9)\n", 289 | "scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.2, patience=10)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "c=0" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "criterion = nn.BCEWithLogitsLoss()" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "from tqdm.notebook import tqdm, trange" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):\n", 326 | " '''\n", 327 | " Compute the Hamming score (a.k.a. 
label-based accuracy) for the multi-label case\n", 328 | " https://stackoverflow.com/q/32239577/395857\n", 329 | " '''\n", 330 | " acc_list = []\n", 331 | " for i in range(y_true.shape[0]):\n", 332 | " set_true = set( np.where(y_true[i])[0] )\n", 333 | " set_pred = set( np.where(y_pred[i])[0] )\n", 334 | " tmp_a = None\n", 335 | " if len(set_true) == 0 and len(set_pred) == 0:\n", 336 | " tmp_a = 1\n", 337 | " else:\n", 338 | " tmp_a = len(set_true.intersection(set_pred))/\\\n", 339 | " float( len(set_true.union(set_pred)) )\n", 340 | " acc_list.append(tmp_a)\n", 341 | " return np.mean(acc_list)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "for it in trange(CONF['niter']):\n", 351 | "\n", 352 | " l=[]\n", 353 | " LBL=[]\n", 354 | " activation = []\n", 355 | " Y = []\n", 356 | " resnet.eval()\n", 357 | " with torch.no_grad():\n", 358 | " count = 0\n", 359 | " corrects = torch.zeros(CONF['test_classes'])\n", 360 | " for img,lbl in tqdm(val_loader, leave=False, desc=\"testing\"):\n", 361 | " LBL.append(lbl.numpy())\n", 362 | " lbl=lbl.cuda()\n", 363 | " count += img.shape[0]\n", 364 | " pred = resnet(img.cuda())\n", 365 | " l.append(criterion(pred, lbl).item())\n", 366 | " activation.append(pred.data.cpu().numpy())\n", 367 | " pred = torch.sigmoid(pred)>.5\n", 368 | " Y.append(pred.data.cpu().numpy())\n", 369 | " corrects += torch.sum(torch.eq(pred, lbl), dim=0).cpu()\n", 370 | " acc = (corrects/float(count))\n", 371 | " writer.add_scalar(\"sup_acc/val_avg\", torch.mean(acc).item(), global_step=it)\n", 372 | " writer.add_histogram('baseline/acc_val', acc.data.cpu().numpy(), global_step=it)\n", 373 | " writer.add_scalar(\"sup_loss/val_loss\", np.average(l), global_step=it)\n", 374 | " writer.flush()\n", 375 | "\n", 376 | " Y, LBL, activation = np.concatenate(Y,axis=0), np.concatenate(LBL,axis=0), np.concatenate(activation,axis=0)\n", 377 | " writer.add_scalar(\"sup_metrics/subset_acc\", metrics.accuracy_score(LBL, Y, normalize=True), global_step=it)\n", 378 | " writer.add_scalar(\"sup_metrics/hamming_loss\", metrics.hamming_loss(LBL, Y), global_step=it)\n", 379 | " writer.add_scalar(\"sup_metrics/hamming_score\", hamming_score(LBL, Y), global_step=it)\n", 380 | " writer.add_scalar(\"sup_metrics/micro_f1\", metrics.f1_score(LBL, Y, average='micro'), global_step=it)\n", 381 | " writer.add_scalar(\"sup_metrics/macro_f1\", metrics.f1_score(LBL, Y, average='macro'), global_step=it)\n", 382 | " writer.add_scalar(\"sup_metrics/micro_roc_auc\", metrics.roc_auc_score(LBL, Y, average='micro'), global_step=it)\n", 383 | " writer.add_scalar(\"sup_metrics/macro_roc_auc\", metrics.roc_auc_score(LBL, Y, average='macro'), global_step=it)\n", 384 | " writer.add_scalar(\"sup_metrics/micro_precision\", metrics.precision_score(LBL,Y,average='micro'), global_step=it)\n", 385 | " writer.add_scalar(\"sup_metrics/macro_precision\", metrics.precision_score(LBL,Y,average='macro'), global_step=it)\n", 386 | " writer.add_scalar(\"sup_metrics/micro_recall\", metrics.recall_score(LBL,Y,average='micro'), global_step=it)\n", 387 | " writer.add_scalar(\"sup_metrics/macro_recall\", metrics.recall_score(LBL,Y,average='macro'), global_step=it)\n", 388 | " writer.add_scalar(\"sup_metrics/avg_acc\", np.average(np.sum((Y==LBL), axis=0)/Y.shape[0]), global_step=it)\n", 389 | " scheduler.step(np.average(l))\n", 390 | " \n", 391 | " count = 0\n", 392 | " corrects = torch.zeros(CONF['test_classes'])\n", 393 | " resnet.train()\n", 
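    "    # One training epoch of the supervised baseline: multi-label BCE-with-logits\n",
    "    # over the CONF['test_classes'] (80) COCO categories, logged to TensorBoard.\n",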
394 | " for img,lbl in tqdm(train_loader, leave=False, desc=\"training\"):\n", 395 | " count += img.shape[0]\n", 396 | " img = img.cuda()\n", 397 | " pred = resnet(img)\n", 398 | " loss = criterion(pred,lbl.cuda())\n", 399 | " with torch.no_grad():\n", 400 | " corrects += torch.sum(torch.eq((torch.sigmoid(pred)>.5).cpu(), lbl), dim=0)\n", 401 | " opt.zero_grad()\n", 402 | " loss.backward()\n", 403 | " opt.step()\n", 404 | " writer.add_scalar(\"sup_loss/loss\", loss.item(), global_step=c)\n", 405 | " c+=1\n", 406 | " acc = (corrects/float(count))\n", 407 | " writer.add_scalar(\"sup_acc/train_avg\", torch.mean(acc).item(), global_step=it)\n", 408 | " writer.add_histogram('baseline/acc_train', acc.data.cpu().numpy(), global_step=it)\n", 409 | "\n", 410 | "writer.close()" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [] 454 | } 455 | ], 456 | "metadata": { 457 | "kernelspec": { 458 | "display_name": "Python 3", 459 | "language": "python", 460 | "name": "python3" 461 | }, 462 | "language_info": { 463 | "codemirror_mode": { 464 | "name": "ipython", 465 | "version": 3 466 | }, 467 | "file_extension": ".py", 468 | "mimetype": "text/x-python", 469 | "name": "python", 470 | "nbconvert_exporter": "python", 471 | "pygments_lexer": "ipython3", 472 | "version": "3.7.6" 473 | } 474 | }, 475 | "nbformat": 4, 476 | "nbformat_minor": 4 477 | } 478 | -------------------------------------------------------------------------------- /MSCOCO/main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, time" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "TRIAL_NAME='trial'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "CONF={\n", 28 | " 'niter':5,\n", 29 | " 'ntest':50,\n", 30 | " 'GPU':0,\n", 31 | " 'BS':4,\n", 32 | " 'test_BS':256,\n", 33 | " 'N_neg':3,\n", 34 | " 'name':TRIAL_NAME,\n", 35 | " 'nz':2048,\n", 36 | " 'ng':128,\n", 37 | " 'seed':10715,\n", 38 | " 'data_dir':'/DataSet/COCO', #Set it to where your COCO dataset is\n", 39 | " 'dataType':'train2017',\n", 40 | " 'valType':'val2017',\n", 41 | " 'testType':'test2017',\n", 42 | " 'max_len':512,\n", 43 | " 'hidden_size':768,\n", 44 | " 'bert_hdim':3072,\n", 45 | " 'LAMBDAs':0.005,\n", 46 | " 'use_super':False,\n", 47 | " 'test_classes':80,\n", 48 | " 'n_trials':5,\n", 49 | " 'bert_pretrained':True, #pretrain bert\n", 50 | " 'resnet_pretrained':True #pretrain resnet\n", 51 | "}" 52 | ] 53 | }, 54 | { 55 | "cell_type": 
"code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=str(CONF['GPU'])" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "import texar.torch as tx\n", 77 | "import random\n", 78 | "import torch\n", 79 | "import torch.nn as nn\n", 80 | "import torch.nn.parallel\n", 81 | "from torch.nn import functional as F\n", 82 | "import torch.backends.cudnn as cudnn\n", 83 | "import torch.optim as optim\n", 84 | "import torch.utils.data\n", 85 | "import torchvision\n", 86 | "from torch.utils.tensorboard import SummaryWriter\n", 87 | "import torchvision.datasets as dset\n", 88 | "import torchvision.transforms as transforms\n", 89 | "import torchvision.utils as vutils\n", 90 | "import numpy as np\n", 91 | "from torch import autograd\n", 92 | "import multiprocessing\n", 93 | "from PIL import Image\n", 94 | "from sklearn import metrics" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "device = torch.device(\"cuda:0\" if True else \"cpu\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "time.sleep(2)\n", 113 | "import shutil\n", 114 | "tb_dir=os.path.join('./runs', CONF['name'] + \"_GLOBAL\")\n", 115 | "shutil.rmtree(tb_dir, ignore_errors=True)\n", 116 | "time.sleep(5)\n", 117 | "global_writer = SummaryWriter(log_dir=tb_dir)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "random.seed(CONF['seed'])\n", 127 | "torch.manual_seed(CONF['seed'])\n", 128 | "np.random.seed(CONF['seed'])\n", 129 | "cudnn.benchmark = True" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "batch_size = CONF['BS']" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "T = transforms.Compose([\n", 148 | " transforms.RandomResizedCrop((224,224), scale=(0.3, 1.0), ratio=(0.75, 1.3333333333333333)),\n", 149 | " transforms.ColorJitter(brightness=.1, contrast=.05, saturation=.05, hue=.05),\n", 150 | " transforms.RandomHorizontalFlip(),\n", 151 | " transforms.ToTensor(),\n", 152 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 153 | "])\n", 154 | "T_test = transforms.Compose([\n", 155 | " transforms.Resize((224,224)),\n", 156 | " transforms.ToTensor(),\n", 157 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 158 | "])" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "from torchvision.datasets.vision import VisionDataset\n", 168 | "class CocoClassification(VisionDataset):\n", 169 | " \"\"\"`MS Coco Detection `_ Dataset.\n", 170 | "\n", 171 | " Args:\n", 172 | " root (string): Root directory where images are downloaded to.\n", 173 | " annFile (string): Path to json annotation file.\n", 174 | " transform (callable, optional): A function/transform that takes in an PIL image\n", 175 
| " and returns a transformed version. E.g, ``transforms.ToTensor``\n", 176 | " target_transform (callable, optional): A function/transform that takes in the\n", 177 | " target and transforms it.\n", 178 | " transforms (callable, optional): A function/transform that takes input sample and its target as entry\n", 179 | " and returns a transformed version.\n", 180 | " \"\"\"\n", 181 | "\n", 182 | " def sample_class(self, k):\n", 183 | " if CONF['use_super']:\n", 184 | " self.classes = np.array([\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"])\n", 185 | " self.class_description = [\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"]\n", 186 | " return\n", 187 | " class_list = self.coco.getCatIds()\n", 188 | " self.classes = np.sort(np.random.choice(class_list, size=k, replace=False))\n", 189 | " self.class_description = self.coco.loadCats(self.classes)\n", 190 | " arr = []\n", 191 | " for catId in self.classes:\n", 192 | " arr+=self.coco.getImgIds(catIds=[catId])\n", 193 | " self.ids = sorted(list(set(arr)))\n", 194 | " \n", 195 | " def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):\n", 196 | " super(CocoClassification, self).__init__(root, transforms, transform, target_transform)\n", 197 | " from pycocotools.coco import COCO\n", 198 | " self.coco = COCO(annFile)\n", 199 | " self.sample_class(len(self.coco.getCatIds()))\n", 200 | "\n", 201 | " def __getitem__(self, index):\n", 202 | " \"\"\"\n", 203 | " Args:\n", 204 | " index (int): Index\n", 205 | "\n", 206 | " Returns:\n", 207 | " tuple: Tuple (image, target). 
target is the object returned by ``coco.loadAnns``.\n", 208 | " \"\"\"\n", 209 | " coco = self.coco\n", 210 | " img_id = self.ids[index]\n", 211 | " ann_ids = coco.getAnnIds(imgIds=img_id)\n", 212 | " cat_ids = [ann['category_id'] for ann in coco.loadAnns(ann_ids)]\n", 213 | " target = coco.loadCats(cat_ids)\n", 214 | " if CONF['use_super']:\n", 215 | " target = np.array([x['supercategory'] for x in target])\n", 216 | " else:\n", 217 | " target = np.array([x['id'] for x in target if x['id'] in self.classes])\n", 218 | " targets = torch.FloatTensor([1 if (c in target) else 0 for c in self.classes])\n", 219 | " path = coco.loadImgs(img_id)[0]['file_name']\n", 220 | " img = Image.open(os.path.join(self.root, path)).convert('RGB')\n", 221 | " if self.transforms is not None:\n", 222 | " img, targets = self.transforms(img, targets)\n", 223 | "\n", 224 | " return img, targets\n", 225 | "\n", 226 | "\n", 227 | " def __len__(self):\n", 228 | " return len(self.ids)\n" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "dataset = dset.CocoCaptions(root = '{}/{}'.format(CONF['data_dir'],CONF['dataType']),\n", 238 | " annFile = '{}/annotations/captions_{}.json'.format(CONF['data_dir'],CONF['dataType']),\n", 239 | " transform=T)\n", 240 | "\n", 241 | "clas_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['dataType']),\n", 242 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['dataType']),\n", 243 | " transform=T_test)\n", 244 | "val_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['valType']),\n", 245 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['valType']),\n", 246 | " transform=T_test)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size*(CONF['N_neg']+1),\n", 256 | " shuffle=True, num_workers=2, pin_memory=True, drop_last=True)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=CONF['test_BS'],shuffle=False, num_workers=8, pin_memory=False)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "hparams = {\n", 275 | " \"pretrained_model_name\": \"bert-base-uncased\",\n", 276 | " \"vocab_file\": None,\n", 277 | " \"max_len\": CONF['max_len'],\n", 278 | " \"unk_token\": \"[UNK]\",\n", 279 | " \"sep_token\": \"[SEP]\",\n", 280 | " \"pad_token\": \"[PAD]\",\n", 281 | " \"cls_token\": \"[CLS]\",\n", 282 | " \"mask_token\": \"[MASK]\",\n", 283 | " \"tokenize_chinese_chars\": True,\n", 284 | " \"do_lower_case\": True,\n", 285 | " \"do_basic_tokenize\": True,\n", 286 | " \"non_split_tokens\": None,\n", 287 | " \"name\": \"bert_tokenizer\",\n", 288 | "}" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "tokenizer = tx.data.BERTTokenizer(hparams=hparams, pretrained_model_name='bert-base-uncased')" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "bert_hparams = {\n", 307 | " 'embed': {'dim': 
CONF['hidden_size'], 'name': 'word_embeddings'},\n", 308 | " 'vocab_size': 30522,\n", 309 | " 'segment_embed': {'dim': CONF['hidden_size'], 'name': 'token_type_embeddings'},\n", 310 | " 'type_vocab_size': 2,\n", 311 | " 'position_embed': {'dim': CONF['hidden_size'], 'name': 'position_embeddings'},\n", 312 | " 'position_size': CONF['max_len'],\n", 313 | " 'encoder': {'dim': CONF['hidden_size'],\n", 314 | " 'embedding_dropout': 0.1,\n", 315 | " 'multihead_attention': {'dropout_rate': 0.1,\n", 316 | " 'name': 'self',\n", 317 | " 'num_heads': 6,\n", 318 | " 'num_units': CONF['hidden_size'],\n", 319 | " 'output_dim': CONF['hidden_size'],\n", 320 | " 'use_bias': True},\n", 321 | " 'name': 'encoder',\n", 322 | " 'num_blocks': 4,\n", 323 | " 'poswise_feedforward': {'layers': [{'kwargs': {'in_features': CONF['hidden_size'],\n", 324 | " 'out_features': CONF['bert_hdim'],\n", 325 | " 'bias': True},\n", 326 | " 'type': 'Linear'},\n", 327 | " {'type': 'BertGELU'},\n", 328 | " {'kwargs': {'in_features': CONF['bert_hdim'], 'out_features': CONF['hidden_size'], 'bias': True},\n", 329 | " 'type': 'Linear'}]},\n", 330 | " 'residual_dropout': 0.1,\n", 331 | " 'use_bert_config': True},\n", 332 | " 'hidden_size': CONF['hidden_size'],\n", 333 | " 'initializer': None,\n", 334 | " 'name': 'bert_encoder',\n", 335 | " 'pretrained_model_name':None}\n", 336 | "\n", 337 | "if CONF['bert_pretrained']:\n", 338 | " bert_hparams[\"pretrained_model_name\"]=\"bert-base-uncased\"" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "class Flatten(nn.Module):\n", 348 | " def forward(self, x):\n", 349 | " x = x.view(x.size()[0], -1)\n", 350 | " return x" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "class PC_Embedder(nn.Module):\n", 360 | "\n", 361 | "\n", 362 | " def __init__(self, k, nz, hparams):\n", 363 | " super(PC_Embedder, self).__init__()\n", 364 | " self.bert = tx.modules.BERTEncoder(hparams=hparams)\n", 365 | " self.q = nn.Linear(CONF['hidden_size'], nz)\n", 366 | " resnet50 = torchvision.models.resnet50(pretrained=CONF['resnet_pretrained'])\n", 367 | " modules=list(resnet50.children())[:-1]\n", 368 | " modules.append(Flatten())\n", 369 | " self.p = nn.Sequential(*modules)\n", 370 | " self.k = k+1\n", 371 | " self.G = nn.Sequential(\n", 372 | " nn.Linear(2048, 512, bias=True),\n", 373 | " nn.ReLU(),\n", 374 | " nn.Linear(512, nz, bias=True),\n", 375 | " )\n", 376 | "\n", 377 | " def forward(self, v):\n", 378 | " return self.p(v)\n", 379 | "\n", 380 | " def get_loss(self, v, v1, LAMBDA):\n", 381 | " inputs, segment_ids = v1\n", 382 | " _, v1 = self.bert(inputs=inputs, segment_ids=segment_ids)\n", 383 | " batch_size = v.shape[0]//self.k\n", 384 | " z = self.G(self.p(v))\n", 385 | " z = z.squeeze().view(batch_size, self.k, z.shape[1])[:,0,:]\n", 386 | " z_l = z.unsqueeze(1).expand(z.shape[0],self.k,z.shape[1]).contiguous()\n", 387 | " z_p = self.q(v1).view(z_l.shape).contiguous()\n", 388 | " l1 = F.log_softmax(torch.sum(z_l*z_p, dim = -1), dim=1)[:,0]\n", 389 | " l2 = torch.sum((z - z_p[:,0,:])**2, dim=-1)\n", 390 | " return torch.mean(- l1 + LAMBDA*l2, dim=0),l1,l2" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "class Tracker(object):\n", 400 | "\n", 401 | " def __init__(self, VARS, ranks):\n", 402 | " self.var_dict = 
dict(zip(VARS, [(1000000.0 if ranks[i] else -1000000.0, int(ranks[i])) for i in range(len(VARS))]))\n", 403 | " \n", 404 | " def update(self, d):\n", 405 | " for k,v in d.items():\n", 406 | " o,r = self.var_dict[k]\n", 407 | " self.var_dict[k] = (np.minimum(v,o) if r else np.maximum(v,o), r)\n", 408 | " \n", 409 | " def return_dict(self):\n", 410 | " D = dict()\n", 411 | " for k,(v,_) in self.var_dict.items():\n", 412 | " D[k]=v\n", 413 | " return D" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "test_class = nn.Linear(CONF['nz'],CONF['test_classes']).cuda()" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "criterion = nn.BCEWithLogitsLoss()" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "from tqdm import notebook\n", 441 | "tnrange=notebook.tnrange\n", 442 | "tqdm_notebook = notebook.tqdm" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "metadata": {}, 449 | "outputs": [], 450 | "source": [ 451 | "def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):\n", 452 | " '''\n", 453 | " Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case\n", 454 | " https://stackoverflow.com/q/32239577/395857\n", 455 | " '''\n", 456 | " acc_list = []\n", 457 | " for i in range(y_true.shape[0]):\n", 458 | " set_true = set( np.where(y_true[i])[0] )\n", 459 | " set_pred = set( np.where(y_pred[i])[0] )\n", 460 | " tmp_a = None\n", 461 | " if len(set_true) == 0 and len(set_pred) == 0:\n", 462 | " tmp_a = 1\n", 463 | " else:\n", 464 | " tmp_a = len(set_true.intersection(set_pred))/\\\n", 465 | " float( len(set_true.union(set_pred)) )\n", 466 | " acc_list.append(tmp_a)\n", 467 | " return np.mean(acc_list)" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "METRICS = ['train_avg', 'val_avg', 'subset_acc', 'hamming_loss', 'hamming_score',\n", 477 | " 'micro_f1', 'macro_f1', 'micro_roc_auc', 'macro_roc_auc',\n", 478 | " 'micro_precision', 'macro_precision', 'micro_recall', 'macro_recall']\n", 479 | "RANKS = [0, 0, 0, 1, 0,\n", 480 | " 0, 0, 0, 0,\n", 481 | " 0, 0, 1, 1]\n", 482 | "\n", 483 | "assert(len(METRICS)==len(RANKS))" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [ 492 | "def trial(enc, opt, LAMBDA, trial_number):\n", 493 | " c=0\n", 494 | " # Clear any logs from previous runs\n", 495 | " import shutil\n", 496 | " tb_dir=os.path.join('./runs', CONF['name'] + \"_LAMBDA_{}_{}\".format(LAMBDA, trial_number))\n", 497 | " shutil.rmtree(tb_dir, ignore_errors=True)\n", 498 | " time.sleep(5)\n", 499 | " tracker = Tracker(METRICS, RANKS)\n", 500 | " writer = SummaryWriter(log_dir=tb_dir)\n", 501 | " for it in tnrange(CONF['niter'], desc=\"training with LAMBDA:{}\".format(LAMBDA)):\n", 502 | "\n", 503 | " test_class.reset_parameters()\n", 504 | " clas_set.sample_class(CONF['test_classes'])\n", 505 | " clas_loader = torch.utils.data.DataLoader(clas_set, batch_size=CONF['test_BS'],shuffle=True, num_workers=8, pin_memory=False)\n", 506 | " with torch.no_grad():\n", 507 | " enc.eval()\n", 508 | " test = []\n", 509 | " for img,lbl in 
tqdm_notebook(clas_loader,leave=False,desc=\"generating train set\"):\n", 510 | " test.append((enc(img.cuda()).data.cpu(), lbl))\n", 511 | " enc.train()\n", 512 | "\n", 513 | " opt_test = optim.Adam(test_class.parameters(), lr=1e-2)\n", 514 | " for test_it in tnrange(CONF['ntest'],leave=False,desc=\"training linear classifier\"):\n", 515 | " random.shuffle(test)\n", 516 | " for z,lbl in test:\n", 517 | " pred = test_class(z.cuda())\n", 518 | " loss = criterion(pred,lbl.cuda())\n", 519 | " opt_test.zero_grad()\n", 520 | " loss.backward()\n", 521 | " opt_test.step()\n", 522 | "\n", 523 | " with torch.no_grad():\n", 524 | " M = dict()\n", 525 | " count = 0\n", 526 | " corrects = torch.zeros(CONF['test_classes'])\n", 527 | " for z,lbl in test:\n", 528 | " count += z.shape[0]\n", 529 | " pred = torch.sigmoid(test_class(z.cuda()))>.5\n", 530 | " corrects += torch.sum(torch.eq(pred.cpu(), lbl), dim=0)\n", 531 | " acc = (corrects/float(count))\n", 532 | " writer.add_scalar(\"acc/train_avg\", torch.mean(acc).item(), global_step=it)\n", 533 | " writer.add_histogram('acc_train', acc.data.cpu().numpy(), global_step=it)\n", 534 | " M['train_avg'] = torch.mean(acc).item()\n", 535 | "\n", 536 | " count = 0\n", 537 | " corrects = torch.zeros(CONF['test_classes'])\n", 538 | " LBL, Y = [], []\n", 539 | " for img,lbl in val_loader:\n", 540 | " LBL.append(lbl.numpy())\n", 541 | " count += img.shape[0]\n", 542 | " z = enc(img.cuda())\n", 543 | " pred = torch.sigmoid(test_class(z))>.5\n", 544 | " Y.append(pred.data.cpu().numpy())\n", 545 | " corrects += torch.sum(torch.eq(pred.cpu(), lbl), dim=0)\n", 546 | " acc = (corrects/float(count))\n", 547 | " writer.add_scalar(\"acc/val_avg\", torch.mean(acc).item(), global_step=it)\n", 548 | " writer.add_histogram('acc_val', acc.data.cpu().numpy(), global_step=it)\n", 549 | "\n", 550 | " M['val_avg'] = torch.mean(acc).item()\n", 551 | " Y, LBL = np.concatenate(Y,axis=0), np.concatenate(LBL,axis=0)\n", 552 | " M['subset_acc'] = metrics.accuracy_score(LBL, Y, normalize=True)\n", 553 | " M['hamming_loss'] = metrics.hamming_loss(LBL, Y)\n", 554 | " M['hamming_score'] = hamming_score(LBL, Y)\n", 555 | " M['micro_f1'] = metrics.f1_score(LBL, Y, average='micro')\n", 556 | " M['macro_f1'] = metrics.f1_score(LBL, Y, average='macro')\n", 557 | " M['micro_roc_auc'] = metrics.roc_auc_score(LBL, Y, average='micro')\n", 558 | " M['macro_roc_auc'] = metrics.roc_auc_score(LBL, Y, average='macro')\n", 559 | " M['micro_precision'] = metrics.precision_score(LBL,Y,average='micro')\n", 560 | " M['macro_precision'] = metrics.precision_score(LBL,Y,average='macro')\n", 561 | " M['micro_recall'] = metrics.recall_score(LBL,Y,average='micro')\n", 562 | " M['macro_recall'] = metrics.recall_score(LBL,Y,average='macro')\n", 563 | " \n", 564 | " for k,v in M.items():\n", 565 | " writer.add_scalar(\"metrics/{}\".format(k), v, global_step=it)\n", 566 | " tracker.update(M)\n", 567 | " writer.flush()\n", 568 | "\n", 569 | " for img,ann in tqdm_notebook(train_loader, leave=False):\n", 570 | " ann = np.take_along_axis(np.array(ann), np.random.randint(0, len(ann)-1, size=(1,img.shape[0])),0).squeeze()\n", 571 | " inputs, segment_ids = [],[]\n", 572 | " for s in ann:\n", 573 | " x,y,_ = tokenizer.encode_text(s)\n", 574 | " inputs.append(x)\n", 575 | " segment_ids.append(y)\n", 576 | "\n", 577 | " inputs = torch.LongTensor(inputs).cuda()\n", 578 | " segment_ids = torch.LongTensor(segment_ids).cuda()\n", 579 | " img = img.cuda()\n", 580 | " loss, l1, l2 = enc.get_loss(img,(inputs,segment_ids), LAMBDA)\n", 581 
| " opt.zero_grad()\n", 582 | " loss.backward()\n", 583 | " opt.step()\n", 584 | " writer.add_scalar(\"loss/loss\", loss.item(), global_step=c)\n", 585 | " writer.add_scalars(\"loss/parts\", {'l1':- l1.data.mean().item(), 'l2':l2.data.mean().item()}, global_step=c)\n", 586 | " c+=1\n", 587 | " torch.save(enc.state_dict(), \"./models/{}.pth\".format(CONF['name'] + \"_trial_{}_LAMBDA_{}_it_{}\".format(trial_number,LAMBDA,it)))\n", 588 | " writer.close()\n", 589 | " return tracker.return_dict()" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "enc = PC_Embedder(CONF['N_neg'], CONF['nz'], bert_hparams)\n", 599 | "enc.to(device)\n", 600 | "opt = optim.Adam(enc.parameters(), lr=1e-4)\n", 601 | "global_writer.add_hparams(hparam_dict={'lambda': CONF['LAMBDAs'], 'trial':t}, metric_dict=trial(enc, opt, CONF['LAMBDAs'], t))\n", 602 | "global_writer.close()" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "execution_count": null, 615 | "metadata": {}, 616 | "outputs": [], 617 | "source": [] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": null, 622 | "metadata": {}, 623 | "outputs": [], 624 | "source": [] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "execution_count": null, 629 | "metadata": {}, 630 | "outputs": [], 631 | "source": [] 632 | } 633 | ], 634 | "metadata": { 635 | "kernelspec": { 636 | "display_name": "Python 3", 637 | "language": "python", 638 | "name": "python3" 639 | }, 640 | "language_info": { 641 | "codemirror_mode": { 642 | "name": "ipython", 643 | "version": 3 644 | }, 645 | "file_extension": ".py", 646 | "mimetype": "text/x-python", 647 | "name": "python", 648 | "nbconvert_exporter": "python", 649 | "pygments_lexer": "ipython3", 650 | "version": "3.7.6" 651 | } 652 | }, 653 | "nbformat": 4, 654 | "nbformat_minor": 4 655 | } 656 | -------------------------------------------------------------------------------- /Omniglot/README.md: -------------------------------------------------------------------------------- 1 | # Omniglot Experiments 2 | 3 | The code is adapted from [here](https://github.com/leftthomas/SimCLR) 4 | 5 | 6 | ## Usage 7 | 8 | ### Evaluating Self-supervised Representations 9 | 10 | Contrastive Learning Objective only 11 | ``` 12 | python main.py --loss_type 1 --recon_param 0.0 --inver_param 0.0 --epochs 1000 13 | ``` 14 | 15 | Others please refer to `main.py` 16 | 17 | ### Measuring Information 18 | 19 | + Step 1: Train an Auto-Encoder (we assume the encoded features do not lose any information) 20 | ``` 21 | python compute_MI_CondEntro.py --stage AE 22 | ``` 23 | 24 | + Step 2: Estimate the raw information 25 | ``` 26 | python compute_MI_CondEntro.py --stage Raw_Information 27 | ``` 28 | 29 | + Step 3: Estimate the information for the SSL learned representations 30 | 31 | - for Contrastive Learning Objective only 32 | ``` 33 | python main.py --loss_type 1 --recon_param 0.0 --inver_param 0.0 --epochs 1000 --with_info 34 | ``` 35 | - for Contrastive Learning Objective + Inverse Predictive Learning Objective 36 | ``` 37 | python main.py --loss_type 2 --recon_param 0.0 --inver_param 1.0 --epochs 1000 --with_info 38 | ``` 39 | -------------------------------------------------------------------------------- /Omniglot/compute_MI_CondEntro.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models.resnet import resnet18, resnet34, resnet50 5 | import math 6 | import utils 7 | import argparse 8 | from torch.utils.data import DataLoader 9 | import torch.optim as optim 10 | from tqdm import tqdm 11 | import pandas as pd 12 | 13 | 14 | # ## Utility Functions 15 | 16 | class Lambda(nn.Module): 17 | def __init__(self, func): 18 | super(Lambda, self).__init__() 19 | self.func = func 20 | 21 | def forward(self, x): 22 | return self.func(x) 23 | def mlp(dim, hidden_dim, output_dim, layers=1, batch_norm=False): 24 | if batch_norm: 25 | seq = [nn.Linear(dim, hidden_dim), nn.BatchNorm1d(num_features=hidden_dim),\ 26 | nn.ReLU(inplace=True)] 27 | for _ in range(layers): 28 | seq += [nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(num_features=hidden_dim),\ 29 | nn.ReLU(inplace=True)] 30 | else: 31 | seq = [nn.Linear(dim, hidden_dim), nn.ReLU(inplace=True)] 32 | for _ in range(layers): 33 | seq += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] 34 | seq += [nn.Linear(hidden_dim, output_dim)] 35 | 36 | return nn.Sequential(*seq) 37 | 38 | 39 | def init_models(stage, additional_encoder=False): 40 | if stage == 'AE': 41 | norm_encoder = Encoder(normalize=True).cuda() 42 | norm_decoder = Decoder().cuda() 43 | encoder = Encoder(normalize=False).cuda() 44 | decoder = Decoder().cuda() 45 | 46 | params = list(norm_encoder.parameters()) + list(norm_decoder.parameters()) +\ 47 | list(encoder.parameters()) + list(decoder.parameters()) 48 | 49 | norm_encoder.train() 50 | norm_decoder.train() 51 | encoder.train() 52 | decoder.train() 53 | 54 | models = { 55 | 'norm_encoder': norm_encoder, 56 | 'norm_decoder': norm_decoder, 57 | 'encoder': encoder, 58 | 'decoder': decoder, 59 | } 60 | elif stage == 'Raw_Information' or stage == 'Feature_Information': 61 | if stage == 'Raw_Information': 62 | norm_encoder = Encoder(normalize=True).cuda() 63 | norm_encoder.load_state_dict(torch.load('results/norm_encoder.pth')) 64 | norm_encoder.eval() 65 | if additional_encoder: 66 | feat_encoder = Encoder(normalize=True).cuda() 67 | feat_encoder.load_state_dict(torch.load('results/norm_encoder.pth')) 68 | feat_encoder.train() 69 | params = list(feat_encoder.parameters()) 70 | else: 71 | feat_encoder = None 72 | params = [] 73 | else: 74 | norm_encoder = None 75 | feat_encoder = None 76 | params = [] 77 | 78 | encoder = Encoder(normalize=False).cuda() 79 | encoder.load_state_dict(torch.load('results/encoder.pth')) 80 | 81 | mi_z_z_model = MI_Z_Z_Model().cuda() 82 | mi_z_t_model = MI_Z_T_Model().cuda() 83 | cond_z_t_model = Cond_Z_T_Model().cuda() 84 | cond_z_z_model = Cond_Z_Z_Model().cuda() 85 | 86 | params = params + list(mi_z_z_model.parameters()) + list(mi_z_t_model.parameters()) +\ 87 | list(cond_z_t_model.parameters()) + list(cond_z_z_model.parameters()) 88 | 89 | encoder.eval() 90 | mi_z_z_model.train() 91 | mi_z_t_model.train() 92 | cond_z_t_model.train() 93 | cond_z_z_model.train() 94 | 95 | models = { 96 | 'encoder': encoder, 97 | 'norm_encoder': norm_encoder, 98 | 'feat_encoder': feat_encoder, 99 | 'mi_z_z_model': mi_z_z_model, 100 | 'mi_z_t_model': mi_z_t_model, 101 | 'cond_z_t_model': cond_z_t_model, 102 | 'cond_z_z_model': cond_z_z_model, 103 | } 104 | optimizer = optim.Adam(params, lr=1e-3) 105 | 106 | return models, optimizer 107 | 108 | 109 | # ## Auto-Encoding Structure (for one-to-one mapping) 110 | 111 | class 
Encoder(nn.Module): 112 | def __init__(self, normalize=False): 113 | super(Encoder, self).__init__() 114 | self.f = nn.Sequential( 115 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28 116 | #nn.BatchNorm2d(num_features=128), 117 | nn.ReLU(inplace=True), 118 | #nn.MaxPool2d(kernel_size=2, stride=2), 119 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28 120 | #nn.BatchNorm2d(num_features=128), 121 | nn.ReLU(inplace=True), 122 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 14 123 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 14 124 | #nn.BatchNorm2d(num_features=128), 125 | nn.ReLU(inplace=True), 126 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 7 127 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 7 128 | #nn.BatchNorm2d(num_features=128), 129 | nn.ReLU(inplace=True), 130 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3 131 | nn.Flatten(), 132 | nn.Linear(9*128, 1024), 133 | ) 134 | self.normalize = normalize 135 | 136 | def forward(self, _input): 137 | feature = self.f(_input) 138 | if self.normalize: 139 | return F.normalize(feature, dim=-1) 140 | else: 141 | return feature 142 | class Decoder(nn.Module): 143 | def __init__(self): 144 | super(Decoder, self).__init__() 145 | self.f = nn.Sequential( 146 | nn.Linear(1024, 9*128, bias=False), 147 | #nn.BatchNorm1d(num_features=9*128), 148 | nn.ReLU(inplace=True), # (9*128 -> 3*3*128) 149 | Lambda(lambda x: x.view(-1, 128, 3, 3)), 150 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=0, 151 | output_padding=0, bias=False), # out: 7 152 | #nn.BatchNorm2d(num_features=128), 153 | nn.ReLU(inplace=True), 154 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, 155 | output_padding=1, bias=False), # out: 14 156 | #nn.BatchNorm2d(num_features=128), 157 | nn.ReLU(inplace=True), 158 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, 159 | output_padding=1, bias=False), # out: 28 160 | #nn.BatchNorm2d(num_features=128), 161 | nn.ReLU(inplace=True), 162 | nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1, 163 | output_padding=0, bias=True), # out: 28 164 | ) 165 | 166 | def forward(self, _input): 167 | return self.f(_input) 168 | 169 | 170 | # ## Infomation Functions 171 | 172 | class MI_Z_T_Model(nn.Module): 173 | def __init__(self): 174 | super(MI_Z_T_Model, self).__init__() 175 | self._g = mlp(1024, 512, 512) 176 | self._h = mlp(964, 512, 512) 177 | 178 | def forward(self, z, t): 179 | t = F.one_hot(t, num_classes=964) 180 | 181 | scores = torch.matmul(self._g(z), self._h(t.float()).t()) 182 | return scores 183 | 184 | 185 | class MI_Z_Z_Model(nn.Module): 186 | def __init__(self): 187 | super(MI_Z_Z_Model, self).__init__() 188 | self._g = mlp(1024, 512, 512) 189 | self._h = mlp(1024, 512, 512) 190 | 191 | def forward(self, z1, z2): 192 | scores = torch.matmul(self._g(z1), self._h(z2).t()) 193 | return scores 194 | 195 | 196 | class Cond_Z_T_Model(nn.Module): 197 | def __init__(self): 198 | super(Cond_Z_T_Model, self).__init__() 199 | self._g = mlp(964, 512, 1024) 200 | 201 | def forward(self, t): 202 | t = F.one_hot(t, num_classes=964) 203 | recon_z = self._g(t.float()) 204 | return recon_z 205 | 206 | 207 | class Cond_Z_Z_Model(nn.Module): 208 | def __init__(self): 209 | super(Cond_Z_Z_Model, self).__init__() 210 | 
self._g = mlp(1024, 512, 1024) 211 | 212 | def forward(self, z): 213 | recon_z = self._g(z) 214 | return recon_z 215 | 216 | 217 | # ## Training 218 | 219 | def AE_loss(x, s, encoder, decoder): 220 | zx = encoder(x) 221 | zs = encoder(s) 222 | hat_x = decoder(zx) 223 | hat_s = decoder(zs) 224 | return F.binary_cross_entropy_with_logits(hat_x, x) +\ 225 | F.binary_cross_entropy_with_logits(hat_s, s) 226 | def Enc_z(x, s, encoder): 227 | return encoder(x), encoder(s) 228 | def AE_Step(pos_1, pos_2, models, optimizer): 229 | norm_encoder, norm_decoder = models['norm_encoder'], models['norm_decoder'] 230 | encoder, decoder = models['encoder'], models['decoder'] 231 | 232 | optimizer.zero_grad() 233 | loss = 0.5*AE_loss(pos_1, pos_2, norm_encoder, norm_decoder) +\ 234 | 0.5*AE_loss(pos_1, pos_2, encoder, decoder) 235 | loss.backward() 236 | optimizer.step() 237 | 238 | return loss.item() 239 | 240 | 241 | # Maximization 242 | def MI_Estimator(ft1, ft2, model): 243 | ''' 244 | ft1, ft2: r.v.s 245 | model: takes ft1 and ft2, output batch_size x batch_size 246 | ''' 247 | scores = model(ft1, ft2) 248 | 249 | # optimal critic f(ft1, ft2) = log {p(ft1, ft2)/p(ft1)p(ft2)} 250 | def js_fgan_lower_bound_obj(scores): 251 | """Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016).""" 252 | scores_diag = scores.diag() 253 | first_term = -F.softplus(-scores_diag).mean() 254 | n = scores.size(0) 255 | second_term = (torch.sum(F.softplus(scores)) - 256 | torch.sum(F.softplus(scores_diag))) / (n * (n - 1.)) 257 | return first_term - second_term 258 | # if the input is in log form 259 | def direct_log_density_ratio_mi(scores): 260 | return scores.diag().mean() 261 | 262 | train_val = js_fgan_lower_bound_obj(scores) 263 | eval_val = direct_log_density_ratio_mi(scores) 264 | 265 | with torch.no_grad(): 266 | eval_train = eval_val - train_val 267 | 268 | return train_val + eval_train 269 | 270 | 271 | # Minimization 272 | def Conditional_Entropy(ft1, ft2, model): 273 | ''' 274 | Calculating H(ft2|ft1) by min_Q H[P(ft2|ft1), Q(ft2|ft1)] 275 | ft1: discrete or continuous 276 | ft2: continuous (k-dim.) 277 | We assume Q(ft2|ft1) is Gaussian. 
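    With Q = Normal(mu(ft1), I), -log Q(ft2|ft1) = 0.5*k*log(2*pi) + 0.5*||ft2 - mu(ft1)||^2,
    so the scaled MSE returned below equals this cross-entropy up to an additive constant
    and a factor of 2 (see the commented-out exact form).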
278 | model (Q): takes ft1, out the reconstructed ft2 279 | ''' 280 | hat_ft2 = model(ft1) 281 | 282 | # sigma = l2_norm of ft2 (we let it be 1) 283 | # when Q = Normal(mu(ft1), sigma^2I) -> -logQ = log(sqrt((2*pi)^k sigma^(2k))) + 284 | # 0.5*1/(sigma^2)*(y-mu)^T(y-mu) 285 | # H[P(ft2|(ft1), Q(ft2|(ft1)] = E_{P_{(ft1,ft2}} [-logQ] 286 | dim = ft2.shape[1] 287 | bsz = ft2.shape[0] 288 | 289 | #cond_entropy = 0.5*dim*math.log(2*math.pi) + 0.5*(F.mse_loss(hat_ft2, ft2, reduction='sum')/bsz) 290 | #return cond_entropy 291 | scaled_cond_entropy = F.mse_loss(hat_ft2, ft2, reduction='sum')/bsz 292 | return scaled_cond_entropy 293 | 294 | 295 | 296 | def Information_Step(pos_1, pos_2, t, models, optimizer, zx=None, zs=None): 297 | feat_encoder, norm_encoder, encoder,\ 298 | mi_z_z_model, mi_z_t_model, cond_z_t_model, cond_z_z_model =\ 299 | models['feat_encoder'], models['norm_encoder'], models['encoder'], models['mi_z_z_model'],\ 300 | models['mi_z_t_model'], models['cond_z_t_model'], models['cond_z_z_model'] 301 | 302 | ae_zx, ae_zs = Enc_z(pos_1, pos_2, encoder) 303 | if zx is None and zs is None: 304 | if feat_encoder is not None: 305 | I_zx, I_zs = Enc_z(pos_1, pos_2, feat_encoder) 306 | else: 307 | I_zx, I_zs = Enc_z(pos_1, pos_2, norm_encoder) 308 | H_zx, H_zs = Enc_z(pos_1, pos_2, norm_encoder) 309 | else: 310 | I_zx, I_zs = zx, zs 311 | H_zx, H_zs = zx, zs 312 | 313 | I_Z_T = 0.5*MI_Estimator(I_zx, t, mi_z_t_model) +\ 314 | 0.5*MI_Estimator(I_zs, t, mi_z_t_model) 315 | 316 | I_Z_S = 0.5*MI_Estimator(I_zx, ae_zs, mi_z_z_model) +\ 317 | 0.5*MI_Estimator(I_zs, ae_zx, mi_z_z_model) 318 | 319 | H_Z_T = 0.5*Conditional_Entropy(t, H_zx, cond_z_t_model) +\ 320 | 0.5*Conditional_Entropy(t, H_zs, cond_z_t_model) 321 | 322 | H_Z_S = 0.5*Conditional_Entropy(ae_zs, H_zx, cond_z_z_model) +\ 323 | 0.5*Conditional_Entropy(ae_zx, H_zs, cond_z_z_model) 324 | 325 | optimizer.zero_grad() 326 | loss = -I_Z_S - I_Z_T + H_Z_T + H_Z_S 327 | loss.backward() 328 | optimizer.step() 329 | 330 | return I_Z_S.item(), I_Z_T.item(), H_Z_T.item(), H_Z_S.item() 331 | 332 | 333 | # ## Script for Calculating I(X;S), I(X;S|T), I(X;T), H(X|T), H(X|S) 334 | 335 | if __name__ == '__main__': 336 | parser = argparse.ArgumentParser(description='Train SimCLR') 337 | parser.add_argument('--batch_size', default=482, type=int, help='Number of images in each mini-batch\ 338 | (964/2=482 for omniglot and 512 for cifar)') 339 | parser.add_argument('--epochs', default=23000, type=int, help='Number of sweeps over the dataset to train') 340 | parser.add_argument('--stage', default='AE', type=str, help='AE or Raw_Information') 341 | parser.add_argument('--additional_encoder', default=False, action='store_true') 342 | 343 | # args parse 344 | args = parser.parse_args() 345 | batch_size, epochs, stage, additional_encoder = args.batch_size, args.epochs, args.stage, args.additional_encoder 346 | 347 | train_data = utils.Our_Omniglot(root='data', background=True, transform=utils.omniglot_train_transform, 348 | character_target_transform=None, alphabet_target_transform=None, download=True, 349 | contrast_training=True) 350 | 351 | 352 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, 353 | drop_last=True) 354 | 355 | # model setup and optimizer config 356 | models, optimizer = init_models(stage, additional_encoder) 357 | 358 | if stage == 'Raw_Information': 359 | results = {'I(X;S)': [], 'I(X;T)': [], 'H(X|T)': [], 'H(X|S)': []} 360 | 361 | for epoch in range(1, epochs + 1): 362 | 
train_bar = tqdm(train_loader) 363 | total_num = 0 364 | if stage == 'Raw_Information': 365 | I_X_S_total, I_X_T_total, H_X_T_total, H_X_S_total =\ 366 | 0.0, 0.0, 0.0, 0.0 367 | elif stage == 'AE': 368 | AE_Loss_total = 0.0 369 | 370 | for pos_1, pos_2, target in train_bar: 371 | pos_1, pos_2, target = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True),\ 372 | target.cuda(non_blocking=True) 373 | total_num += 1 374 | if stage == 'Raw_Information': 375 | I_X_S, I_X_T, H_X_T, H_X_S = \ 376 | Information_Step(pos_1, pos_2, target, models, optimizer, zx=None, zs=None) 377 | 378 | I_X_S_total += I_X_S 379 | I_X_T_total += I_X_T 380 | H_X_T_total += H_X_T 381 | H_X_S_total += H_X_S 382 | elif stage == 'AE': 383 | AE_Loss_total += AE_Step(pos_1, pos_2, models, optimizer) 384 | 385 | if stage == 'Raw_Information': 386 | print('Epoch: {}, I(X;S): {}, I(X;T): {}, H(X|T): {}, H(X|S): {}'\ 387 | .format(epoch, I_X_S_total / total_num, I_X_T_total / total_num,\ 388 | H_X_T_total / total_num, H_X_S_total / total_num)) 389 | 390 | results['I(X;S)'].append(I_X_S_total / total_num) 391 | results['I(X;T)'].append(I_X_T_total / total_num) 392 | results['H(X|T)'].append(H_X_T_total / total_num) 393 | results['H(X|S)'].append(H_X_S_total / total_num) 394 | 395 | # save statistics 396 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) 397 | 398 | if additional_encoder: 399 | data_frame.to_csv('results/Raw_Information_additional_encoder.csv', index_label='epoch') 400 | else: 401 | data_frame.to_csv('results/Raw_Information.csv', index_label='epoch') 402 | elif stage == 'AE': 403 | print('Epoch: {}, AE_Loss: {}'.format(epoch, AE_Loss_total / total_num)) 404 | 405 | # save encoder 406 | torch.save(models['norm_encoder'].state_dict(), 'results/norm_encoder.pth') 407 | torch.save(models['encoder'].state_dict(), 'results/encoder.pth') 408 | 409 | 410 | def information(epoch, train_loader, inner_epochs, net, models, optimizer): 411 | net.eval() 412 | I_Z_S_total, I_Z_T_total, H_Z_T_total, H_Z_S_total =\ 413 | 0.0, 0.0, 0.0, 0.0 414 | total_num = 0 415 | 416 | for _in in range(inner_epochs): 417 | train_bar = tqdm(train_loader) 418 | for pos_1, pos_2, target in train_bar: 419 | pos_1, pos_2, target = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True),\ 420 | target.cuda(non_blocking=True) 421 | total_num += 1 422 | zx, _ = net(pos_1) 423 | zs, _ = net(pos_2) 424 | I_Z_S, I_Z_T, H_Z_T, H_Z_S = \ 425 | Information_Step(pos_1, pos_2, target, models, optimizer, zx=zx, zs=zs) 426 | 427 | I_Z_S_total += I_Z_S 428 | I_Z_T_total += I_Z_T 429 | H_Z_T_total += H_Z_T 430 | H_Z_S_total += H_Z_S 431 | 432 | print('Epoch: {}, Inner Epoch: {}, I(Z;S): {}, I(Z;T): {}, H(Z|T): {}, H(Z|S): {}'\ 433 | .format(epoch, _in, I_Z_S, I_Z_T, H_Z_T, H_Z_S)) 434 | 435 | return I_Z_S_total/total_num , I_Z_T_total/total_num, H_Z_T_total/total_num,\ 436 | H_Z_S_total/total_num 437 | -------------------------------------------------------------------------------- /Omniglot/linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pandas as pd 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from thop import profile, clever_format 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import CIFAR10 11 | from tqdm import tqdm 12 | 13 | import utils 14 | from model import Model, Omniglot_Model 15 | 16 | 17 | class Net(nn.Module): 18 | def __init__(self, 
num_class, pretrained_path, resnet_depth=18): 19 | super(Net, self).__init__() 20 | 21 | if resnet_depth == 18: 22 | resnet_output_dim = 512 23 | elif resnet_depth == 34: 24 | resnet_output_dim = 512 25 | elif resnet_depth == 50: 26 | resnet_output_dim = 2048 27 | 28 | # encoder 29 | self.f = Model(resnet_depth=resnet_depth).f 30 | # classifier 31 | self.fc = nn.Linear(resnet_output_dim, num_class, bias=True) 32 | self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False) 33 | 34 | def forward(self, x): 35 | x = self.f(x) 36 | feature = torch.flatten(x, start_dim=1) 37 | out = self.fc(feature) 38 | return out 39 | 40 | 41 | class Omniglot_Net(nn.Module): 42 | def __init__(self, num_class, pretrained_path): 43 | super(Omniglot_Net, self).__init__() 44 | 45 | # encoder 46 | self.f = Omniglot_Model().f 47 | # classifier 48 | self.fc = nn.Sequential( 49 | nn.Linear(1024, num_class, bias=False), 50 | #nn.BatchNorm1d(256), 51 | #nn.ReLU(inplace=True), 52 | #nn.Linear(256, 256, bias=True), 53 | #nn.BatchNorm1d(256), 54 | #nn.ReLU(inplace=True), 55 | #nn.Linear(256, num_class, bias=True) 56 | ) 57 | 58 | 59 | 60 | self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False) 61 | 62 | def forward(self, x): 63 | feature = self.f(x) 64 | feature = F.normalize(feature, dim=-1) 65 | out = self.fc(feature) 66 | return out 67 | 68 | 69 | # train or test for one epoch 70 | def train_val(net, data_loader, train_optimizer): 71 | is_train = train_optimizer is not None 72 | net.train() if is_train else net.eval() 73 | 74 | total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader) 75 | with (torch.enable_grad() if is_train else torch.no_grad()): 76 | for data, target in data_bar: 77 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 78 | out = net(data) 79 | loss = loss_criterion(out, target) 80 | 81 | if is_train: 82 | train_optimizer.zero_grad() 83 | loss.backward() 84 | train_optimizer.step() 85 | 86 | total_num += data.size(0) 87 | total_loss += loss.item() * data.size(0) 88 | prediction = torch.argsort(out, dim=-1, descending=True) 89 | total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 90 | total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 91 | 92 | data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%' 93 | .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num, 94 | total_correct_1 / total_num * 100, total_correct_5 / total_num * 100)) 95 | 96 | return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser(description='Linear Evaluation') 101 | parser.add_argument('--model_path', type=str, default='results/128_0.5_200_512_500_model.pth', 102 | help='The pretrained model path') 103 | parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch') 104 | parser.add_argument('--epochs', type=int, default=1000, help='Number of sweeps over the dataset to train') 105 | parser.add_argument('--resnet_depth', default=18, type=int, help='The depth of the resnet') 106 | parser.add_argument('--dataset', default='omniglot', type=str, help='omniglot or cifar') 107 | 108 | args = parser.parse_args() 109 | model_path, batch_size, epochs = args.model_path, args.batch_size, 
args.epochs 110 | if args.dataset == 'cifar': 111 | resnet_depth = args.resnet_depth 112 | 113 | train_data = CIFAR10(root='data', train=True, transform=utils.train_transform, download=True) 114 | test_data = CIFAR10(root='data', train=False, transform=utils.test_transform, download=True) 115 | else: 116 | train_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_train_transform, 117 | character_target_transform=None, alphabet_target_transform=None, download=True, 118 | eval_split_train=True, out_character=False, contrast_training=False) 119 | test_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_test_transform, 120 | character_target_transform=None, alphabet_target_transform=None, download=True, 121 | eval_split_train=False, out_character=False, contrast_training=False) 122 | 123 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True) 124 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True) 125 | 126 | 127 | if args.dataset == 'cifar': 128 | model = Net(num_class=len(train_data.classes), pretrained_path=model_path, resnet_depth=resnet_depth ).cuda() 129 | else: 130 | #model = Omniglot_Net(num_class=20, pretrained_path=model_path).cuda() 131 | model = Omniglot_Net(num_class=659, pretrained_path=model_path).cuda() 132 | for param in model.f.parameters(): 133 | param.requires_grad = False 134 | 135 | if args.dataset == 'cifar': 136 | flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),)) 137 | else: 138 | flops, params = profile(model, inputs=(torch.randn(1, 1, 28, 28).cuda(),)) 139 | flops, params = clever_format([flops, params]) 140 | print('# Model Params: {} FLOPs: {}'.format(params, flops)) 141 | optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6) 142 | #optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=1e-6) 143 | loss_criterion = nn.CrossEntropyLoss() 144 | results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [], 145 | 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []} 146 | 147 | save_name_pre = model_path.split('.pth')[0] 148 | 149 | best_acc = 0.0 150 | for epoch in range(1, epochs + 1): 151 | train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer) 152 | results['train_loss'].append(train_loss) 153 | results['train_acc@1'].append(train_acc_1) 154 | results['train_acc@5'].append(train_acc_5) 155 | test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None) 156 | results['test_loss'].append(test_loss) 157 | results['test_acc@1'].append(test_acc_1) 158 | results['test_acc@5'].append(test_acc_5) 159 | # save statistics 160 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) 161 | data_frame.to_csv('{}_linear_statistics.csv'.format(save_name_pre), index_label='epoch') 162 | if test_acc_1 > best_acc: 163 | best_acc = test_acc_1 164 | torch.save(model.state_dict(), '{}_linear_model.pth'.format(save_name_pre)) 165 | -------------------------------------------------------------------------------- /Omniglot/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import pandas as pd 5 | import torch 6 | import torch.optim as optim 7 | from thop import profile, clever_format 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | import torch.nn.functional as F 11 | import math 12 | 13 | import 
utils 14 | from model import Model, Omniglot_Model, Recon_Omniglot_Model 15 | 16 | from compute_MI_CondEntro import init_models, information 17 | 18 | 19 | def contrastive_loss(out_1, out_2, _type='NCE'): 20 | # compute loss 21 | if _type == 'NCE': 22 | # [2*B, D] 23 | out = torch.cat([out_1, out_2], dim=0) 24 | # [2*B, 2*B] 25 | sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature) 26 | mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool() 27 | # [2*B, 2*B-1] 28 | sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1) 29 | 30 | pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature) 31 | # [2*B] 32 | pos_sim = torch.cat([pos_sim, pos_sim], dim=0) 33 | loss = (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean() 34 | elif _type == 'JS': 35 | temperature_JS = temperature 36 | 37 | scores_12 = torch.mm(out_1, out_2.t().contiguous()) / temperature_JS 38 | first_term = -F.softplus(-scores_12.diag()).mean() 39 | 40 | n = scores_12.size(0) 41 | second_term_12 = (torch.sum(F.softplus(scores_12)) - 42 | torch.sum(F.softplus(scores_12.diag()))) / (n * (n - 1.)) 43 | scores_11 = torch.mm(out_1, out_1.t().contiguous()) / temperature_JS 44 | second_term_11 = (torch.sum(F.softplus(scores_11)) - 45 | torch.sum(F.softplus(scores_11.diag()))) / (n * (n - 1.)) 46 | scores_22 = torch.mm(out_2, out_2.t().contiguous()) / temperature_JS 47 | second_term_22 = (torch.sum(F.softplus(scores_22)) - 48 | torch.sum(F.softplus(scores_22.diag()))) / (n * (n - 1.)) 49 | second_term = (second_term_11 + second_term_22 + second_term_12*2.) / 4. 50 | loss = -1. * (first_term - second_term) 51 | return loss 52 | 53 | 54 | def inverse_perdictive_loss(feature_1, feature_2): 55 | # symmetric 56 | return F.mse_loss(feature_1, feature_2) + F.mse_loss(feature_2, feature_1) 57 | 58 | 59 | # Use MSE_Loss here (assuming Gaussian) 60 | # Other losses can be binary cross_entropy is assuming Bernoulli 61 | def forward_predictive_loss(target, recon, _type='RevBCE'): 62 | if _type == 'RevBCE': 63 | # empirically good 64 | return F.binary_cross_entropy_with_logits(target, recon) 65 | elif _type == 'BCE': 66 | # assuming factorized bernoulli 67 | return F.binary_cross_entropy_with_logits(recon, target) 68 | elif _type == 'MSE': 69 | # assuming diagnonal Gaussian 70 | # target has [0,1] 71 | # change it to [-\infty, \infty] 72 | # inverse of sigmoid: x = ln(y/(1-y)) 73 | target = target.clamp(min=1e-4, max=1. 
- 1e-4) 74 | target = torch.log(target / (1.-target)) 75 | return F.mse_loss(recon, target) 76 | 77 | 78 | # train for one epoch to learn unique features 79 | def train(epoch, net, data_loader, train_optimizer, recon_net, loss_type, recon_optimizer, info_dct=None): 80 | net.train() 81 | if recon_net is not None: 82 | recon_net.train() 83 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader) 84 | 85 | for pos_1, pos_2, target in train_bar: 86 | pos_1, pos_2, target = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True),\ 87 | target.cuda(non_blocking=True) 88 | #norm = False if loss_type == 7 else True 89 | norm = True 90 | 91 | feature_1, out_1 = net(pos_1, norm) 92 | feature_2, out_2 = net(pos_2, norm) 93 | 94 | if loss_type == 1 or loss_type == 2 or loss_type == 5 or loss_type == 6 or loss_type == 7\ 95 | or loss_type == 11 or loss_type == 12: 96 | contr_type = 'NCE' if not loss_type == 7 else 'JS' 97 | contr_loss = contrastive_loss(out_1, out_2, _type = contr_type) 98 | if loss_type == 2 or loss_type == 4 or loss_type == 6 or loss_type == 10\ 99 | or loss_type == 12: 100 | inver_loss = inverse_perdictive_loss(feature_1, feature_2) 101 | if loss_type == 3 or loss_type == 4 or loss_type == 5 or loss_type == 6 or loss_type == 8\ 102 | or loss_type == 9 or loss_type == 10 or loss_type == 11 or loss_type == 12: 103 | recon_for_2 = recon_net(feature_1) 104 | recon_for_1 = recon_net(feature_2) 105 | if loss_type == 3: 106 | recon_type = 'BCE' 107 | elif loss_type == 4 or loss_type == 5 or loss_type == 6 or loss_type == 8: 108 | recon_type = 'RevBCE' 109 | elif loss_type == 9 or loss_type == 10 or loss_type == 11 or loss_type == 12: 110 | recon_type = 'MSE' 111 | recon_loss = forward_predictive_loss(pos_1, recon_for_1, _type=recon_type) +\ 112 | forward_predictive_loss(pos_2, recon_for_2, _type=recon_type) 113 | recon_optimizer.zero_grad() 114 | train_optimizer.zero_grad() 115 | 116 | if loss_type == 1 or loss_type == 7: 117 | loss = contr_loss 118 | elif loss_type == 2: 119 | loss = contr_loss + inver_param*inver_loss 120 | elif loss_type == 3 or loss_type == 8 or loss_type == 9: 121 | loss = recon_param*recon_loss 122 | elif loss_type == 4 or loss_type == 10: 123 | loss = recon_param*recon_loss + inver_param*inver_loss 124 | elif loss_type == 5 or loss_type == 11: 125 | loss = contr_loss + recon_param*recon_loss 126 | elif loss_type == 6 or loss_type == 12: 127 | loss = contr_loss + recon_param*recon_loss +\ 128 | inver_param*inver_loss 129 | 130 | 131 | loss.backward() 132 | 133 | train_optimizer.step() 134 | if recon_optimizer is not None: 135 | recon_optimizer.step() 136 | 137 | total_num += batch_size 138 | total_loss += loss.item() * batch_size 139 | train_bar.set_description('Train Epoch: [{}/{}], loss_type: {}, Loss: {:.4f}'.format(\ 140 | epoch, epochs, loss_type, total_loss / total_num)) 141 | 142 | 143 | 144 | if info_dct is not None: 145 | inner_epochs = 200 if epoch < 100 else 80 146 | I_Z_S, I_Z_T, H_Z_T, H_Z_S = information(epoch, data_loader, inner_epochs, net,\ 147 | info_dct['info_models'], info_dct['info_optimizer']) 148 | 149 | info_dct['info_results']['I(Z;S)'].append(I_Z_S) 150 | info_dct['info_results']['I(Z;T)'].append(I_Z_T) 151 | info_dct['info_results']['H(Z|T)'].append(H_Z_T) 152 | info_dct['info_results']['H(Z|S)'].append(H_Z_S) 153 | 154 | return total_loss / total_num 155 | 156 | 157 | # test for one epoch, use weighted knn to find the most similar images' label to assign the test image 158 | def test(net, memory_data_loader, test_data_loader): 
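    # Protocol: embed the whole memory set into a feature bank, then classify each test image by
    # a weighted vote over its k nearest neighbours. Features are L2-normalized inside the model,
    # so the matrix product below is cosine similarity; neighbour votes are weighted by
    # exp(similarity / temperature), and top-1 / top-5 accuracy is reported.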
159 | net.eval() 160 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] 161 | with torch.no_grad(): 162 | # generate feature bank 163 | for data, _, target in tqdm(memory_data_loader, desc='Feature extracting'): 164 | feature, out = net(data.cuda(non_blocking=True)) 165 | feature_bank.append(feature) 166 | # [D, N] 167 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 168 | # [N] 169 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device) 170 | # loop test data to predict the label by weighted knn search 171 | test_bar = tqdm(test_data_loader) 172 | for data, _, target in test_bar: 173 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 174 | feature, out = net(data) 175 | 176 | total_num += data.size(0) 177 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 178 | sim_matrix = torch.mm(feature, feature_bank) 179 | # [B, K] 180 | sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1) 181 | # [B, K] 182 | sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices) 183 | sim_weight = (sim_weight / temperature).exp() 184 | 185 | # counts for each class 186 | one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device) 187 | # [B*K, C] 188 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) 189 | # weighted score ---> [B, C] 190 | pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1) 191 | 192 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 193 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 194 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 195 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%' 196 | .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100)) 197 | 198 | return total_top1 / total_num * 100, total_top5 / total_num * 100 199 | 200 | 201 | # test for one epoch, use weighted knn to find the most similar images' label to assign the test image 202 | def omniglot_test(net, memory_data_loader, test_data_loader): 203 | net.eval() 204 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] 205 | with torch.no_grad(): 206 | # generate feature bank 207 | for data, target in tqdm(memory_data_loader, desc='Feature extracting'): 208 | feature, out = net(data.cuda(non_blocking=True)) 209 | feature_bank.append(feature) 210 | # [D, N] 211 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 212 | # [N] 213 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device) 214 | # loop test data to predict the label by weighted knn search 215 | test_bar = tqdm(test_data_loader) 216 | for data, target in test_bar: 217 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 218 | feature, out = net(data) 219 | 220 | total_num += data.size(0) 221 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 222 | sim_matrix = torch.mm(feature, feature_bank) 223 | # [B, K] 224 | sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1) 225 | # [B, K] 226 | sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices) 227 | sim_weight = (sim_weight / temperature).exp() 228 | 229 | # counts for each class 230 | one_hot_label = 
torch.zeros(data.size(0) * k, c, device=sim_labels.device) 231 | # [B*K, C] 232 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) 233 | # weighted score ---> [B, C] 234 | pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1) 235 | 236 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 237 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 238 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 239 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%' 240 | .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100)) 241 | 242 | return total_top1 / total_num * 100, total_top5 / total_num * 100 243 | 244 | 245 | if __name__ == '__main__': 246 | parser = argparse.ArgumentParser(description='Omniglot Experiments') 247 | parser.add_argument('--temperature', default=0.1, type=float, help='Temperature used in softmax\ 248 | (0.1 for omniglot and 0.5 for cifar)') 249 | parser.add_argument('--k', default=1, type=int, help='Top k most similar images used to predict the label\ 250 | (1 for omniglot and 200 for cifar)') 251 | parser.add_argument('--batch_size', default=482, type=int, help='Number of images in each mini-batch\ 252 | (964/2=482 for omniglot and 512 for cifar)') 253 | parser.add_argument('--epochs', default=1000, type=int, help='Number of sweeps over the dataset to train') 254 | parser.add_argument('--resnet_depth', default=18, type=int, help='The depth of the resnet\ 255 | (only for cifar)') 256 | parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for latent vector\ 257 | (only for cifar)') 258 | parser.add_argument('--dataset', default='omniglot', type=str, help='omniglot or cifar') 259 | parser.add_argument('--trial', default=99, type=int, help='number of trial') 260 | parser.add_argument('--loss_type', default=1, type=int, help='1: only contrast (NCE),\ 261 | 2: contrast (NCE) + inverse_pred, 3: only forward_pred (BCE),\ 262 | 4: forward_pred (RevBCE) + inverse_pred, 5: contrast (NCE) + forward_pred (RevBCE),\ 263 | 6: contrast (NCE) + forward_pred (RevBCE) + inverse_pred,\ 264 | 7: only contrast (JS), 8: only forward_pred (RevBCE),\ 265 | 9: only forward_pred (MSE), 10: forward_pred (MSE) + inverse_pred,\ 266 | 11: contrast (NCE) + forward_pred (MSE),\ 267 | 12: contrast (NCE) + forward_pred (MSE) + inverse_pred') 268 | parser.add_argument('--inver_param', default=0.001, type=float, help='Hyper_param for inverse_pred') 269 | parser.add_argument('--recon_param', default=0.001, type=float, help='Hyper_param for forward_pred') 270 | parser.add_argument('--with_info', default=False, action='store_true') 271 | 272 | # args parse 273 | args = parser.parse_args() 274 | feature_dim, temperature, k = args.feature_dim, args.temperature, args.k 275 | batch_size, epochs = args.batch_size, args.epochs 276 | resnet_depth = args.resnet_depth 277 | trial = args.trial 278 | 279 | recon_param = args.recon_param 280 | inver_param = args.inver_param 281 | 282 | # data prepare 283 | if args.dataset == 'cifar': 284 | train_data = utils.CIFAR10Pair(root='data', train=True, transform=utils.train_transform, download=True) 285 | else: 286 | # our self-supervised signal construction strategy 287 | train_data = utils.Our_Omniglot(root='data', background=True, transform=utils.omniglot_train_transform, 288 | character_target_transform=None, 
alphabet_target_transform=None, download=True, 289 | contrast_training=True) 290 | # self-supervised signal construction strategy in SimCLR 291 | #train_data = utils.Our_Omniglot_v2(root='data', background=True, transform=utils.omniglot_train_transform, 292 | # character_target_transform=None, alphabet_target_transform=None, download=True, 293 | # contrast_training=True) 294 | 295 | 296 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, 297 | drop_last=True) 298 | 299 | if args.dataset == 'cifar': 300 | memory_data = utils.CIFAR10Pair(root='data', train=True, transform=utils.test_transform, download=True) 301 | else: 302 | memory_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_test_transform, 303 | character_target_transform=None, alphabet_target_transform=None, download=True, 304 | eval_split_train=True, out_character=True, contrast_training=False) 305 | memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 306 | 307 | 308 | if args.dataset == 'cifar': 309 | test_data = utils.CIFAR10Pair(root='data', train=False, transform=utils.test_transform, download=True) 310 | else: 311 | test_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_test_transform, 312 | character_target_transform=None, alphabet_target_transform=None, download=True, 313 | eval_split_train=False, out_character=True, contrast_training=False) 314 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 315 | 316 | # calculating information 317 | if args.with_info: 318 | info_models, info_optimizer = init_models('Feature_Information') 319 | info_results = {'I(Z;S)': [], 'I(Z;T)': [], 'H(Z|T)': [], 'H(Z|S)': []} 320 | info_dct = { 321 | 'info_models': info_models, 322 | 'info_optimizer': info_optimizer, 323 | 'info_results': info_results, 324 | } 325 | else: 326 | info_dct = None 327 | 328 | # model setup and optimizer config 329 | if args.dataset == 'cifar': 330 | model = Model(feature_dim, resnet_depth=resnet_depth).cuda() 331 | recon_model = None 332 | else: 333 | model = Omniglot_Model().cuda() 334 | recon_model = Recon_Omniglot_Model().cuda() if args.loss_type >= 3 else None 335 | 336 | if args.dataset == 'cifar': 337 | flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),)) 338 | else: 339 | flops, params = profile(model, inputs=(torch.randn(1, 1, 28, 28).cuda(),)) 340 | #flops, params = profile(model, inputs=(torch.randn(1, 1, 56, 56).cuda(),)) 341 | #flops, params = profile(model, inputs=(torch.randn(1, 1, 105, 105).cuda(),)) 342 | if recon_model is not None: 343 | recon_flops, recon_params = profile(recon_model, inputs=(torch.randn(1, 1024).cuda(),)) 344 | flops, params = clever_format([flops, params]) 345 | print('# Model Params: {} FLOPs: {}'.format(params, flops)) 346 | if recon_model is not None: 347 | recon_flops, recon_params = clever_format([recon_flops, recon_params]) 348 | print('# Recon_Model Params: {} FLOPs: {}'.format(recon_params, recon_flops)) 349 | optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6) 350 | recon_optimizer = optim.Adam(recon_model.parameters(), lr=1e-3, weight_decay=1e-6) if recon_model is not None \ 351 | else None 352 | #optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=1e-6) 353 | #milestone1, milestone2 = int(args.epochs*0.4), int(args.epochs*0.7) 354 | #lr_decay = 
optim.lr_scheduler.MultiStepLR(optimizer, milestones=[milestone1, milestone2], gamma=0.1) 355 | if args.dataset == 'cifar': 356 | c = len(memory_data.classes) 357 | else: 358 | c = 659 #c = 20 359 | 360 | # training loop 361 | results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []} 362 | if args.dataset == 'cifar': 363 | save_name_pre = '{}_{}_{}_{}_{}_{}_{}'.format(args.dataset, resnet_depth, feature_dim, temperature, k, \ 364 | batch_size, epochs) 365 | else: 366 | save_name_pre = '{}_{}_{}_{}_{}'.format(args.dataset, args.loss_type, recon_param, inver_param, trial) 367 | if not os.path.exists('results'): 368 | os.mkdir('results') 369 | best_acc = 0.0 370 | for epoch in range(1, epochs + 1): 371 | if(epoch%4)==1: 372 | train_loss = train(epoch, model, train_loader, optimizer, recon_model, args.loss_type, recon_optimizer, info_dct) 373 | else: 374 | train_loss = train(epoch, model, train_loader, optimizer, recon_model, args.loss_type, recon_optimizer, None) 375 | #lr_decay.step() 376 | results['train_loss'].append(train_loss) 377 | 378 | if args.dataset == 'cifar': 379 | test_acc_1, test_acc_5 = test(model, memory_loader, test_loader) 380 | else: 381 | test_acc_1, test_acc_5 = omniglot_test(model, memory_loader, test_loader) 382 | 383 | results['test_acc@1'].append(test_acc_1) 384 | results['test_acc@5'].append(test_acc_5) 385 | 386 | # save statistics 387 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) 388 | data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch') 389 | 390 | if (epoch%4)==1 and info_dct is not None: 391 | info_data_frame = pd.DataFrame(data=info_dct['info_results'], index=range(1, epoch + 1, 4)) 392 | if args.loss_type==1: 393 | info_data_frame.to_csv('results/Feature_Information.csv', index_label=info_dct['epoch']) 394 | elif args.loss_type==2: 395 | info_data_frame.to_csv('results/Feature_Information_min_H.csv', index_label=info_dct['epoch']) 396 | 397 | if args.dataset == 'cifar': 398 | if test_acc_1 > best_acc: 399 | best_acc = test_acc_1 400 | torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre)) 401 | -------------------------------------------------------------------------------- /Omniglot/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models.resnet import resnet18, resnet34, resnet50 5 | 6 | 7 | class Model(nn.Module): 8 | def __init__(self, feature_dim=128, resnet_depth=18): 9 | super(Model, self).__init__() 10 | 11 | self.f = [] 12 | if resnet_depth == 18: 13 | my_resnet = resnet18() 14 | resnet_output_dim = 512 15 | elif resnet_depth == 34: 16 | my_resnet = resnet34() 17 | resnet_output_dim = 512 18 | elif resnet_depth == 50: 19 | my_resnet = resnet50() 20 | resnet_output_dim = 2048 21 | 22 | for name, module in my_resnet.named_children(): 23 | if name == 'conv1': 24 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 25 | if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d): 26 | self.f.append(module) 27 | # encoder 28 | self.f = nn.Sequential(*self.f) 29 | # projection head 30 | self.g = nn.Sequential(nn.Linear(resnet_output_dim, 512, bias=False), nn.BatchNorm1d(512), 31 | nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True)) 32 | 33 | def forward(self, x): 34 | x = self.f(x) 35 | feature = torch.flatten(x, start_dim=1) 36 | out = self.g(feature) 37 | return 
F.normalize(feature, dim=-1), F.normalize(out, dim=-1) 38 | 39 | 40 | # for 105x105 size 41 | ''' 42 | class Omniglot_Model(nn.Module): 43 | def __init__(self): 44 | super(Omniglot_Model, self).__init__() 45 | # encoder 46 | self.f = nn.Sequential( 47 | nn.Conv2d(in_channels=1, out_channels=64, kernel_size=10, stride=1, padding=0), # out: 96 48 | nn.BatchNorm2d(num_features=64), 49 | nn.ReLU(inplace=True), 50 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 48 51 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=0), # out: 42 52 | nn.BatchNorm2d(num_features=128), 53 | nn.ReLU(inplace=True), 54 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 21 55 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=1, padding=0), # out: 18 56 | nn.BatchNorm2d(num_features=128), 57 | nn.ReLU(inplace=True), 58 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 9 59 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=1, padding=0), # out: 6 60 | nn.BatchNorm2d(num_features=128), 61 | nn.ReLU(inplace=True), 62 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3 63 | nn.Flatten(), 64 | nn.Linear(9*128, 1024), 65 | ) 66 | # projection head 67 | self.g = nn.Identity() 68 | 69 | def forward(self, x): 70 | feature = self.f(x) 71 | out = self.g(feature) 72 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1) 73 | ''' 74 | 75 | 76 | # for 28x28 size (using max_pool) 77 | class Omniglot_Model(nn.Module): 78 | def __init__(self): 79 | super(Omniglot_Model, self).__init__() 80 | # encoder 81 | self.f = nn.Sequential( 82 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28 83 | nn.BatchNorm2d(num_features=128), 84 | nn.ReLU(inplace=True), 85 | #nn.MaxPool2d(kernel_size=2, stride=2), 86 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28 87 | nn.BatchNorm2d(num_features=128), 88 | nn.ReLU(inplace=True), 89 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 14 90 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 14 91 | nn.BatchNorm2d(num_features=128), 92 | nn.ReLU(inplace=True), 93 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 7 94 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 7 95 | nn.BatchNorm2d(num_features=128), 96 | nn.ReLU(inplace=True), 97 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3 98 | nn.Flatten(), 99 | nn.Linear(9*128, 1024), 100 | ) 101 | # projection head 102 | self.g = nn.Identity() 103 | 104 | def forward(self, x, norm=True): 105 | feature = self.f(x) 106 | out = self.g(feature) 107 | if norm: 108 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1) 109 | else: 110 | return F.normalize(feature, dim=-1), out 111 | 112 | 113 | # for 28x28 size (not using maxpool) 114 | ''' 115 | class Omniglot_Model(nn.Module): 116 | def __init__(self): 117 | super(Omniglot_Model, self).__init__() 118 | # encoder 119 | self.f = nn.Sequential( 120 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28 121 | nn.BatchNorm2d(num_features=128), 122 | nn.ReLU(inplace=True), 123 | #nn.MaxPool2d(kernel_size=2, stride=2), 124 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1), # out: 14 125 | nn.BatchNorm2d(num_features=128), 126 | nn.ReLU(inplace=True), 127 | #nn.MaxPool2d(kernel_size=2, stride=2), # out: 14 128 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1), # out: 7 129 | 
nn.BatchNorm2d(num_features=128), 130 | nn.ReLU(inplace=True), 131 | #nn.MaxPool2d(kernel_size=2, stride=2), # out: 7 132 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=0), # out: 3 133 | nn.BatchNorm2d(num_features=128), 134 | nn.ReLU(inplace=True), 135 | #nn.MaxPool2d(kernel_size=2, stride=2), # out: 3 136 | nn.Flatten(), 137 | nn.Linear(9*128, 1024), 138 | ) 139 | # projection head 140 | self.g = nn.Identity() 141 | 142 | def forward(self, x): 143 | feature = self.f(x) 144 | out = self.g(feature) 145 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1) 146 | ''' 147 | 148 | 149 | # + 150 | # for 28x28 size 151 | class Lambda(nn.Module): 152 | def __init__(self, func): 153 | super(Lambda, self).__init__() 154 | self.func = func 155 | 156 | def forward(self, x): 157 | return self.func(x) 158 | 159 | class Recon_Omniglot_Model(nn.Module): 160 | def __init__(self): 161 | super(Recon_Omniglot_Model, self).__init__() 162 | # reconstructer (approximately the inverse of the encoder) 163 | self.f = nn.Sequential( 164 | nn.Linear(1024, 9*128, bias=False), 165 | nn.BatchNorm1d(num_features=9*128), 166 | nn.ReLU(inplace=True), # (9*128 -> 3*3*128) 167 | Lambda(lambda x: x.view(-1, 128, 3, 3)), 168 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=0, 169 | output_padding=0, bias=False), # out: 7 170 | nn.BatchNorm2d(num_features=128), 171 | nn.ReLU(inplace=True), 172 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, 173 | output_padding=1, bias=False), # out: 14 174 | nn.BatchNorm2d(num_features=128), 175 | nn.ReLU(inplace=True), 176 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, 177 | output_padding=1, bias=False), # out: 28 178 | nn.BatchNorm2d(num_features=128), 179 | nn.ReLU(inplace=True), 180 | nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1, 181 | output_padding=0, bias=True), # out: 28 182 | #nn.Sigmoid(), 183 | ) 184 | 185 | def forward(self, x): 186 | recon = self.f(x) 187 | return recon 188 | 189 | 190 | # - 191 | 192 | # for 56x56 size 193 | ''' 194 | class Omniglot_Model(nn.Module): 195 | def __init__(self): 196 | super(Omniglot_Model, self).__init__() 197 | # encoder 198 | self.f = nn.Sequential( 199 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 56 200 | nn.BatchNorm2d(num_features=128), 201 | nn.ReLU(inplace=True), 202 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 28 203 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28 204 | nn.BatchNorm2d(num_features=128), 205 | nn.ReLU(inplace=True), 206 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 14 207 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 14 208 | nn.BatchNorm2d(num_features=128), 209 | nn.ReLU(inplace=True), 210 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 7 211 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 7 212 | nn.BatchNorm2d(num_features=128), 213 | nn.ReLU(inplace=True), 214 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3 215 | nn.Flatten(), 216 | nn.Linear(9*128, 1024), 217 | ) 218 | # projection head 219 | self.g = nn.Identity() 220 | 221 | def forward(self, x): 222 | feature = self.f(x) 223 | out = self.g(feature) 224 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1) 225 | ''' 226 | 
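A quick way to see that Recon_Omniglot_Model is dimensioned as the approximate inverse of the 28x28 encoder is to round-trip a dummy batch through both networks. The following is only a minimal sketch under assumptions, not part of the repo: it assumes it is run from the Omniglot directory so that model.py is importable as `model`, and it checks tensor shapes only.

```python
# Minimal shape-check sketch (assumes model.py is importable as `model`;
# illustrative only, it just verifies tensor shapes).
import torch
from model import Omniglot_Model, Recon_Omniglot_Model

encoder = Omniglot_Model().eval()          # 1x28x28 image -> 1024-d feature
decoder = Recon_Omniglot_Model().eval()    # 1024-d feature -> 1x28x28 reconstruction

with torch.no_grad():
    x = torch.randn(2, 1, 28, 28)          # dummy batch of two Omniglot-sized images
    feature, out = encoder(x, norm=False)  # `out` is the un-normalized projection (g is Identity)
    recon = decoder(out)

print(feature.shape, out.shape, recon.shape)
# torch.Size([2, 1024]) torch.Size([2, 1024]) torch.Size([2, 1, 28, 28])
```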
-------------------------------------------------------------------------------- /Omniglot/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | from torchvision.datasets import CIFAR10, Omniglot 4 | 5 | # + 6 | import cv2 7 | import numpy as np 8 | 9 | from torchvision.datasets.utils import check_integrity, list_dir, list_files 10 | from os.path import join 11 | 12 | 13 | # - 14 | 15 | # np.random.seed(0) 16 | 17 | class GaussianBlur(object): 18 | # Implements Gaussian blur as described in the SimCLR paper 19 | def __init__(self, kernel_size, min=0.1, max=2.0): 20 | self.min = min 21 | self.max = max 22 | # kernel size is set to be 10% of the image height/width 23 | self.kernel_size = kernel_size 24 | 25 | def __call__(self, sample): 26 | sample = np.array(sample) 27 | 28 | # blur the image with a 50% chance 29 | prob = np.random.random_sample() 30 | 31 | if prob < 0.5: 32 | sigma = (self.max - self.min) * np.random.random_sample() + self.min 33 | sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) 34 | 35 | return sample 36 | 37 | 38 | class CIFAR10Pair(CIFAR10): 39 | """CIFAR10 Dataset. 40 | """ 41 | 42 | def __getitem__(self, index): 43 | img, target = self.data[index], self.targets[index] 44 | img = Image.fromarray(img) 45 | 46 | if self.transform is not None: 47 | pos_1 = self.transform(img) 48 | pos_2 = self.transform(img) 49 | 50 | if self.target_transform is not None: 51 | target = self.target_transform(target) 52 | 53 | return pos_1, pos_2, target 54 | 55 | class Our_Omniglot(Omniglot): 56 | ''' 57 | The code is adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/omniglot.py 58 | [Usage] 59 | contrastive_training_data = Our_Omniglot(root='data', background=True, transform=None, 60 | character_target_transform=None, alphabet_target_transform=None, download=True, 61 | contrast_training=True) 62 | classifier_train_data = Our_Omniglot(root='data', background=False, transform=None, 63 | character_target_transform=None, alphabet_target_transform=None, download=True, 64 | eval_split_train=True, out_character=False, contrast_training=False) 65 | classifier_test_data = Our_Omniglot(root='data', background=False, transform=None, 66 | character_target_transform=None, alphabet_target_transform=None, download=True, 67 | eval_split_train=False, out_character=False, contrast_training=False) 68 | ''' 69 | def __init__(self, root, background=True, transform=None, character_target_transform=None, 70 | alphabet_target_transform=None, download=False, eval_split_train=True, out_character=False, 71 | contrast_training=True): 72 | super(Omniglot, self).__init__(join(root, self.folder), transform=transform, 73 | target_transform=character_target_transform) 74 | self.background = background 75 | 76 | if download: 77 | self.download() 78 | 79 | if not self._check_integrity(): 80 | raise RuntimeError('Dataset not found or corrupted.' 
+ 81 | ' You can use download=True to download it') 82 | 83 | self.character_target_transform = character_target_transform 84 | self.alphabet_target_transform = alphabet_target_transform 85 | 86 | 87 | self.target_folder = join(self.root, self._get_target_folder()) 88 | self._alphabets = list_dir(self.target_folder) 89 | self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] 90 | for a in self._alphabets], []) 91 | self._character_images = [[(image, idx, self._alphabets.index(character.split('/')[0])) 92 | for image in list_files(join(self.target_folder, character), '.png')] 93 | for idx, character in enumerate(self._characters)] 94 | self._flat_character_images = sum(self._character_images, []) 95 | 96 | self.contrast_training = contrast_training 97 | 98 | # we adopt contrastive training in the background split 99 | if self.contrast_training: 100 | # 20 samples per character 101 | self._flat_character_images = np.array(self._flat_character_images).reshape(-1,20,3) 102 | self.out_character = out_character 103 | # we adopt standard classification training in the evaluation split 104 | else: 105 | # 20 samples per character 106 | self._flat_character_images = np.array(self._flat_character_images).reshape(-1,20,3) 107 | if eval_split_train: 108 | self._flat_character_images = self._flat_character_images[:,:5,:] 109 | else: 110 | self._flat_character_images = self._flat_character_images[:,5:,:] 111 | self._flat_character_images = self._flat_character_images.reshape(-1,3) 112 | self.out_character = out_character 113 | if self.out_character: 114 | self.targets = self._flat_character_images[:,1].astype(np.int64) 115 | else: 116 | self.targets = self._flat_character_images[:,2].astype(np.int64) 117 | 118 | def __getitem__(self, index): 119 | """ 120 | Args: 121 | index (int): Index 122 | Returns: 123 | when contrastive training: 124 | tuple: (image0, image1) 125 | image0 and image1 belong to the same class 126 | when not contrastive training: 127 | tuple: (image, character_target, alphabet_target) 128 | where character_target is the index of the target character class 129 | and alphabet_target is the index of the target alphabet class. 
130 | """ 131 | if self.contrast_training: 132 | random_idx = np.random.randint(20, size=2) 133 | image_name_0, character_class_0, alphabet_class_0 = self._flat_character_images[index,random_idx[0]] 134 | character_class_0, alphabet_class_0 = int(character_class_0), int(alphabet_class_0) 135 | image_name_1, character_class_1, alphabet_class_1 = self._flat_character_images[index,random_idx[1]] 136 | character_class_1, alphabet_class_1 = int(character_class_1), int(alphabet_class_1) 137 | image_path_0 = join(self.target_folder, self._characters[character_class_0], image_name_0) 138 | image_0 = Image.open(image_path_0, mode='r').convert('L') 139 | image_path_1 = join(self.target_folder, self._characters[character_class_1], image_name_1) 140 | image_1 = Image.open(image_path_1, mode='r').convert('L') 141 | 142 | if self.transform: 143 | image_0 = self.transform(image_0) 144 | image_1 = self.transform(image_1) 145 | 146 | if self.character_target_transform: 147 | character_class_0 = self.character_target_transform(character_class_0) 148 | # character_class_1 = self.character_target_transform(character_class_1) 149 | if self.alphabet_target_transform: 150 | alphabet_class_0 = self.alphabet_target_transform(alphabet_class_0) 151 | # alphabet_class_1 = self.alphabet_target_transform(alphabet_class_1) 152 | 153 | if self.out_character: 154 | return image_0, image_1, character_class_0#, character_class_1, alphabet_class_0, alphabet_class_1 155 | else: 156 | return image_0, image_1, alphabet_class_0#, character_class_1, alphabet_class_0, alphabet_class_1 157 | else: 158 | image_name, character_class, alphabet_class = self._flat_character_images[index] 159 | character_class, alphabet_class = int(character_class), int(alphabet_class) 160 | image_path = join(self.target_folder, self._characters[character_class], image_name) 161 | image = Image.open(image_path, mode='r').convert('L') 162 | 163 | if self.transform: 164 | image = self.transform(image) 165 | 166 | if self.character_target_transform: 167 | character_class = self.character_target_transform(character_class) 168 | if self.alphabet_target_transform: 169 | alphabet_class = self.alphabet_target_transform(alphabet_class) 170 | 171 | if self.out_character: 172 | return image, character_class 173 | else: 174 | return image, alphabet_class 175 | 176 | 177 | class Our_Omniglot_v2(Omniglot): 178 | ''' 179 | The code is adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/omniglot.py 180 | [Usage] 181 | contrastive_training_data = Our_Omniglot(root='data', background=True, transform=None, 182 | character_target_transform=None, alphabet_target_transform=None, download=True, 183 | contrast_training=True) 184 | classifier_train_data = Our_Omniglot(root='data', background=False, transform=None, 185 | character_target_transform=None, alphabet_target_transform=None, download=True, 186 | eval_split_train=True, out_character=False, contrast_training=False) 187 | classifier_test_data = Our_Omniglot(root='data', background=False, transform=None, 188 | character_target_transform=None, alphabet_target_transform=None, download=True, 189 | eval_split_train=False, out_character=False, contrast_training=False) 190 | ''' 191 | def __init__(self, root, background=True, transform=None, character_target_transform=None, 192 | alphabet_target_transform=None, download=False, eval_split_train=True, out_character=True, 193 | contrast_training=True): 194 | super(Omniglot, self).__init__(join(root, self.folder), transform=transform, 195 | 
target_transform=character_target_transform) 196 | self.background = background 197 | 198 | if download: 199 | self.download() 200 | 201 | if not self._check_integrity(): 202 | raise RuntimeError('Dataset not found or corrupted.' + 203 | ' You can use download=True to download it') 204 | 205 | self.character_target_transform = character_target_transform 206 | self.alphabet_target_transform = alphabet_target_transform 207 | 208 | 209 | self.target_folder = join(self.root, self._get_target_folder()) 210 | self._alphabets = list_dir(self.target_folder) 211 | self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] 212 | for a in self._alphabets], []) 213 | self._character_images = [[(image, idx, self._alphabets.index(character.split('/')[0])) 214 | for image in list_files(join(self.target_folder, character), '.png')] 215 | for idx, character in enumerate(self._characters)] 216 | self._flat_character_images = sum(self._character_images, []) 217 | 218 | self.contrast_training = contrast_training 219 | 220 | # 20 samples per character 221 | self._flat_character_images = np.array(self._flat_character_images).reshape(-1,20,3) 222 | if eval_split_train: 223 | self._flat_character_images = self._flat_character_images[:,:5,:] 224 | else: 225 | self._flat_character_images = self._flat_character_images[:,5:,:] 226 | self._flat_character_images = self._flat_character_images.reshape(-1,3) 227 | self.out_character = out_character 228 | if self.out_character: 229 | self.targets = self._flat_character_images[:,1].astype(np.int64) 230 | else: 231 | self.targets = self._flat_character_images[:,2].astype(np.int64) 232 | 233 | def __getitem__(self, index): 234 | """ 235 | Args: 236 | index (int): Index 237 | Returns: 238 | when contrastive training: 239 | tuple: (image0, image1) 240 | image0 and image1 are the same image with different image augmentations 241 | when not contrastive training: 242 | tuple: (image, character_target, alphabet_target) 243 | where character_target is index of the target character class 244 | and alphabet_target is index of the target alphabet class. 
245 | """ 246 | image_name, character_class, alphabet_class = self._flat_character_images[index] 247 | character_class, alphabet_class = int(character_class), int(alphabet_class) 248 | image_path = join(self.target_folder, self._characters[character_class], image_name) 249 | image = Image.open(image_path, mode='r').convert('L') 250 | 251 | if self.character_target_transform: 252 | character_class = self.character_target_transform(character_class) 253 | if self.alphabet_target_transform: 254 | alphabet_class = self.alphabet_target_transform(alphabet_class) 255 | 256 | if self.contrast_training: 257 | if self.transform: 258 | image_0 = self.transform(image) 259 | image_1 = self.transform(image) 260 | 261 | if self.out_character: 262 | return image_0, image_1, character_class 263 | else: 264 | return image_0, image_1, alphabet_class 265 | else: 266 | if self.transform: 267 | image = self.transform(image) 268 | 269 | if self.out_character: 270 | return image, character_class 271 | else: 272 | return image, alphabet_class 273 | 274 | # GausssianBlur is False for CIFAR10 275 | 276 | train_transform = transforms.Compose([ 277 | transforms.RandomResizedCrop(32), 278 | transforms.RandomHorizontalFlip(p=0.5), 279 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 280 | transforms.RandomGrayscale(p=0.2), 281 | #GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])), 282 | transforms.ToTensor(), 283 | #transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 284 | ]) 285 | 286 | omniglot_train_transform = transforms.Compose([ 287 | transforms.RandomAffine(degrees=10.0, translate=(0.1, 0.1)), 288 | #transforms.RandomResizedCrop(105, scale=(0.85, 1.0), ratio=(0.8, 1.25)), 289 | #transforms.RandomResizedCrop(56, scale=(0.85, 1.0), ratio=(0.8, 1.25)), 290 | transforms.RandomResizedCrop(28, scale=(0.85, 1.0), ratio=(0.8, 1.25)), 291 | transforms.ToTensor(), 292 | lambda x: 1. - x, 293 | ]) 294 | 295 | test_transform = transforms.Compose([ 296 | transforms.ToTensor(), 297 | #transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 298 | ]) 299 | 300 | omniglot_test_transform = transforms.Compose([ 301 | #transforms.Resize(105), 302 | #transforms.Resize(56), 303 | transforms.Resize(28), 304 | transforms.ToTensor(), 305 | lambda x: 1. - x, 306 | ]) 307 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo contains codes to reproduce the results in the paper. 2 | 3 | [PDF link](https://arxiv.org/pdf/2006.05576.pdf) 4 | 5 | I will make this repo much more informative in a later date. But welcome any questions regarding the paper or the experiments. 6 | --------------------------------------------------------------------------------