├── CIFAR10
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── README.md
│   ├── data.py
│   ├── data_util.py
│   ├── imagenet_subsets
│   │   ├── 10percent.txt
│   │   └── 1percent.txt
│   ├── lars_optimizer.py
│   ├── model.py
│   ├── model_util.py
│   ├── objective.py
│   ├── requirements.txt
│   ├── resnet.py
│   └── run.py
├── LICENSE
├── MSCOCO
│   ├── README.md
│   ├── ResNet_baseline.ipynb
│   └── main.ipynb
├── Omniglot
│   ├── README.md
│   ├── compute_MI_CondEntro.py
│   ├── linear.py
│   ├── main.py
│   ├── model.py
│   └── utils.py
└── README.md
/CIFAR10/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | SimCLR needs to maintain permanent compatibility with the pre-trained model
4 | files, so we do not plan to make any major changes to this library (other than
5 | what was promised in the README). However, we can accept small patches related
6 | to re-factoring and documentation. To submit contributions, there are just a
7 | few small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/CIFAR10/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/CIFAR10/README.md:
--------------------------------------------------------------------------------
1 | # CIFAR10 Experiments
2 |
3 | The code is adapted from [here](https://github.com/google-research/simclr).
4 |
5 |
6 | ## Usage
7 |
8 | ### Evaluating Self-supervised Representations
9 |
10 | Contrastive Learning Objective only
11 | ```
12 | python run.py --train_mode=pretrain --train_batch_size=512 --train_epochs=1000 \
13 | --learning_rate=1.0 --weight_decay=1e-6 --dataset=cifar10 --image_size=32 \
14 | --eval_split=test --resnet_depth=18 --use_blur=False --color_jitter_strength=0.5 \
15 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_1 --inv_pred_coeff=0.0 \
16 | --use_tpu=False --temperature=0.5 --hidden_norm=True
17 |
18 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=4 \
19 | --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)LARSOptimizer|head)' \
20 | --global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
21 | --train_epochs=100 --train_batch_size=512 --warmup_epochs=0 --dataset=cifar10 --image_size=32 \
22 | --eval_split=test --resnet_depth=18 --checkpoint=/root/data/githubs/simclr_models/c10_cpc_1 \
23 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_1/ft --inv_pred_coeff=0.0 --use_tpu=False \
24 | --temperature=0.5 --hidden_norm=True
25 | ```
26 |
27 | Contrastive Learning Objective + Inverse Predictive Learning Objective
28 | ```
29 | python run.py --train_mode=pretrain --train_batch_size=512 --train_epochs=1000 \
30 | --learning_rate=1.0 --weight_decay=1e-6 --dataset=cifar10 --image_size=32 \
31 | --eval_split=test --resnet_depth=18 --use_blur=False --color_jitter_strength=0.5 \
32 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_inv_1 --inv_pred_coeff=0.03 \
33 | --use_tpu=False --temperature=0.5 --hidden_norm=True
34 |
35 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=4 \
36 | --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)LARSOptimizer|head)' \
37 | --global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
38 | --train_epochs=100 --train_batch_size=512 --warmup_epochs=0 --dataset=cifar10 --image_size=32 \
39 | --eval_split=test --resnet_depth=18 --checkpoint=/root/data/githubs/simclr_models/c10_cpc_inv_1 \
40 | --model_dir=/root/data/githubs/simclr_models/c10_cpc_inv_1/ft --inv_pred_coeff=0.03 --use_tpu=False \
41 | --temperature=0.5 --hidden_norm=True
42 | ```
43 |
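44 | For reference, the contrastive term controlled by `--temperature` and
45 | `--hidden_norm` is SimCLR's NT-Xent loss. The snippet below is a minimal
46 | NumPy sketch of that term only (illustrative; the actual implementation,
47 | including the inverse predictive term weighted by `--inv_pred_coeff`, lives
48 | in `objective.py`):
49 |
50 | ```python
51 | import numpy as np
52 |
53 | def nt_xent(z1, z2, temperature=0.5, hidden_norm=True):
54 |     """z1, z2: (batch, dim) projections of the two augmented views."""
55 |     if hidden_norm:
56 |         z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
57 |         z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
58 |     z = np.concatenate([z1, z2], axis=0)           # (2B, dim)
59 |     sim = z @ z.T / temperature                    # pairwise similarities
60 |     np.fill_diagonal(sim, -1e9)                    # exclude self-pairs
61 |     batch = z1.shape[0]
62 |     pos = np.concatenate([np.arange(batch, 2 * batch), np.arange(batch)])
63 |     log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
64 |     return -log_prob[np.arange(2 * batch), pos].mean()
65 | ```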
--------------------------------------------------------------------------------
/CIFAR10/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data pipeline."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | from absl import flags
24 |
25 | import data_util as data_util
26 | import tensorflow.compat.v1 as tf
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | def pad_to_batch(dataset, batch_size):
32 | """Pad Tensors to specified batch size.
33 |
34 | Args:
35 | dataset: An instance of tf.data.Dataset.
36 | batch_size: The number of samples per batch of input requested.
37 |
38 | Returns:
39 | An instance of tf.data.Dataset that yields the same Tensors with the same
40 | structure as the original padded to batch_size along the leading
41 | dimension.
42 |
43 | Raises:
44 | ValueError: If the dataset does not comprise any tensors; if a tensor
45 | yielded by the dataset has an unknown number of dimensions or is a
46 | scalar; or if it can be statically determined that tensors comprising
47 | a single dataset element will have different leading dimensions.
48 | """
49 | def _pad_to_batch(*args):
50 | """Given Tensors yielded by a Dataset, pads all to the batch size."""
51 | flat_args = tf.nest.flatten(args)
52 |
53 | for tensor in flat_args:
54 | if tensor.shape.ndims is None:
55 | raise ValueError(
56 | 'Unknown number of dimensions for tensor %s.' % tensor.name)
57 | if tensor.shape.ndims == 0:
58 | raise ValueError('Tensor %s is a scalar.' % tensor.name)
59 |
60 | # This will throw if flat_args is empty. However, as of this writing,
61 | # tf.data.Dataset.map will throw first with an internal error, so we do
62 | # not check this case explicitly.
63 | first_tensor = flat_args[0]
64 | first_tensor_shape = tf.shape(first_tensor)
65 | first_tensor_batch_size = first_tensor_shape[0]
66 | difference = batch_size - first_tensor_batch_size
67 |
68 | for i, tensor in enumerate(flat_args):
69 | control_deps = []
70 | if i != 0:
71 |         # Check that the leading dimension of this tensor matches the first,
72 | # either statically or dynamically. (If the first dimensions of both
73 |         # tensors are statically known, then we have to check the static
74 | # shapes at graph construction time or else we will never get to the
75 | # dynamic assertion.)
76 | if (first_tensor.shape[:1].is_fully_defined() and
77 | tensor.shape[:1].is_fully_defined()):
78 | if first_tensor.shape[0] != tensor.shape[0]:
79 | raise ValueError(
80 | 'Batch size of dataset tensors does not match. %s '
81 | 'has shape %s, but %s has shape %s' % (
82 | first_tensor.name, first_tensor.shape,
83 | tensor.name, tensor.shape))
84 | else:
85 | curr_shape = tf.shape(tensor)
86 | control_deps = [tf.Assert(
87 | tf.equal(curr_shape[0], first_tensor_batch_size),
88 | ['Batch size of dataset tensors %s and %s do not match. '
89 | 'Shapes are' % (tensor.name, first_tensor.name), curr_shape,
90 | first_tensor_shape])]
91 |
92 | with tf.control_dependencies(control_deps):
93 | # Pad to batch_size along leading dimension.
94 | flat_args[i] = tf.pad(
95 | tensor, [[0, difference]] + [[0, 0]] * (tensor.shape.ndims - 1))
96 | flat_args[i].set_shape([batch_size] + tensor.shape.as_list()[1:])
97 |
98 | return tf.nest.pack_sequence_as(args, flat_args)
99 |
100 | return dataset.map(_pad_to_batch)
101 |
102 |
103 | def build_input_fn(builder, is_training):
104 | """Build input function.
105 |
106 | Args:
107 | builder: TFDS builder for specified dataset.
108 | is_training: Whether to build in training mode.
109 |
110 | Returns:
111 | A function that accepts a dict of params and returns a tuple of images and
112 | features, to be used as the input_fn in TPUEstimator.
113 | """
114 | def _input_fn(params):
115 | """Inner input function."""
116 | preprocess_fn_pretrain = get_preprocess_fn(is_training, is_pretrain=True)
117 | preprocess_fn_finetune = get_preprocess_fn(is_training, is_pretrain=False)
118 | num_classes = builder.info.features['label'].num_classes
119 |
120 | def map_fn(image, label):
121 | """Produces multiple transformations of the same batch."""
122 | if FLAGS.train_mode == 'pretrain':
123 | xs = []
124 | for _ in range(2): # Two transformations
125 | xs.append(preprocess_fn_pretrain(image))
126 |         image = tf.concat(xs, -1)  # two views stacked along channels; split again in model.py
127 | label = tf.zeros([num_classes])
128 | else:
129 | image = preprocess_fn_finetune(image)
130 | label = tf.one_hot(label, num_classes)
131 | return image, label, 1.0
132 |
133 | dataset = builder.as_dataset(
134 | split=FLAGS.train_split if is_training else FLAGS.eval_split,
135 | shuffle_files=is_training, as_supervised=True)
136 | if FLAGS.cache_dataset:
137 | dataset = dataset.cache()
138 | if is_training:
139 | buffer_multiplier = 50 if FLAGS.image_size <= 32 else 10
140 | dataset = dataset.shuffle(params['batch_size'] * buffer_multiplier)
141 | dataset = dataset.repeat(-1)
142 | dataset = dataset.map(map_fn,
143 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
144 | dataset = dataset.batch(params['batch_size'], drop_remainder=is_training)
145 | dataset = pad_to_batch(dataset, params['batch_size'])
146 | images, labels, mask = tf.data.make_one_shot_iterator(dataset).get_next()
147 |
148 | return images, {'labels': labels, 'mask': mask}
149 | return _input_fn
150 |
151 |
152 | def get_preprocess_fn(is_training, is_pretrain):
153 | """Get function that accepts an image and returns a preprocessed image."""
154 | # Disable test cropping for small images (e.g. CIFAR)
155 | if FLAGS.image_size <= 32:
156 | test_crop = False
157 | else:
158 | test_crop = True
159 | return functools.partial(
160 | data_util.preprocess_image,
161 | height=FLAGS.image_size,
162 | width=FLAGS.image_size,
163 | is_training=is_training,
164 | color_distort=is_pretrain,
165 | test_crop=test_crop)
166 |
--------------------------------------------------------------------------------
/CIFAR10/data_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data preprocessing and augmentation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | from absl import flags
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 | FLAGS = flags.FLAGS
28 |
29 | CROP_PROPORTION = 0.875 # Standard for ImageNet.
30 |
31 |
32 | def random_apply(func, p, x):
33 | """Randomly apply function func to x with probability p."""
34 | return tf.cond(
35 | tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32),
36 | tf.cast(p, tf.float32)),
37 | lambda: func(x),
38 | lambda: x)
39 |
40 |
41 | def to_grayscale(image, keep_channels=True):
42 | image = tf.image.rgb_to_grayscale(image)
43 | if keep_channels:
44 | image = tf.tile(image, [1, 1, 3])
45 | return image
46 |
47 |
48 | def color_jitter(image,
49 | strength,
50 | random_order=True):
51 | """Distorts the color of the image.
52 |
53 | Args:
54 | image: The input image tensor.
55 | strength: the floating number for the strength of the color augmentation.
56 | random_order: A bool, specifying whether to randomize the jittering order.
57 |
58 | Returns:
59 | The distorted image tensor.
60 | """
61 | brightness = 0.8 * strength
62 | contrast = 0.8 * strength
63 | saturation = 0.8 * strength
64 | hue = 0.2 * strength
65 | if random_order:
66 | return color_jitter_rand(image, brightness, contrast, saturation, hue)
67 | else:
68 | return color_jitter_nonrand(image, brightness, contrast, saturation, hue)
69 |
70 |
71 | def color_jitter_nonrand(image, brightness=0, contrast=0, saturation=0, hue=0):
72 | """Distorts the color of the image (jittering order is fixed).
73 |
74 | Args:
75 | image: The input image tensor.
76 | brightness: A float, specifying the brightness for color jitter.
77 | contrast: A float, specifying the contrast for color jitter.
78 | saturation: A float, specifying the saturation for color jitter.
79 | hue: A float, specifying the hue for color jitter.
80 |
81 | Returns:
82 | The distorted image tensor.
83 | """
84 | with tf.name_scope('distort_color'):
85 | def apply_transform(i, x, brightness, contrast, saturation, hue):
86 | """Apply the i-th transformation."""
87 | if brightness != 0 and i == 0:
88 | x = tf.image.random_brightness(x, max_delta=brightness)
89 | elif contrast != 0 and i == 1:
90 | x = tf.image.random_contrast(
91 | x, lower=1-contrast, upper=1+contrast)
92 | elif saturation != 0 and i == 2:
93 | x = tf.image.random_saturation(
94 | x, lower=1-saturation, upper=1+saturation)
95 | elif hue != 0:
96 | x = tf.image.random_hue(x, max_delta=hue)
97 | return x
98 |
99 | for i in range(4):
100 | image = apply_transform(i, image, brightness, contrast, saturation, hue)
101 | image = tf.clip_by_value(image, 0., 1.)
102 | return image
103 |
104 |
105 | def color_jitter_rand(image, brightness=0, contrast=0, saturation=0, hue=0):
106 | """Distorts the color of the image (jittering order is random).
107 |
108 | Args:
109 | image: The input image tensor.
110 | brightness: A float, specifying the brightness for color jitter.
111 | contrast: A float, specifying the contrast for color jitter.
112 | saturation: A float, specifying the saturation for color jitter.
113 | hue: A float, specifying the hue for color jitter.
114 |
115 | Returns:
116 | The distorted image tensor.
117 | """
118 | with tf.name_scope('distort_color'):
119 | def apply_transform(i, x):
120 | """Apply the i-th transformation."""
121 | def brightness_foo():
122 | if brightness == 0:
123 | return x
124 | else:
125 | return tf.image.random_brightness(x, max_delta=brightness)
126 | def contrast_foo():
127 | if contrast == 0:
128 | return x
129 | else:
130 | return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)
131 | def saturation_foo():
132 | if saturation == 0:
133 | return x
134 | else:
135 | return tf.image.random_saturation(
136 | x, lower=1-saturation, upper=1+saturation)
137 | def hue_foo():
138 | if hue == 0:
139 | return x
140 | else:
141 | return tf.image.random_hue(x, max_delta=hue)
142 | x = tf.cond(tf.less(i, 2),
143 | lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
144 | lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
145 | return x
146 |
147 | perm = tf.random_shuffle(tf.range(4))
148 | for i in range(4):
149 | image = apply_transform(perm[i], image)
150 | image = tf.clip_by_value(image, 0., 1.)
151 | return image
152 |
153 |
154 | def _compute_crop_shape(
155 | image_height, image_width, aspect_ratio, crop_proportion):
156 | """Compute aspect ratio-preserving shape for central crop.
157 |
158 | The resulting shape retains `crop_proportion` along one side and a proportion
159 | less than or equal to `crop_proportion` along the other side.
160 |
161 | Args:
162 | image_height: Height of image to be cropped.
163 | image_width: Width of image to be cropped.
164 | aspect_ratio: Desired aspect ratio (width / height) of output.
165 | crop_proportion: Proportion of image to retain along the less-cropped side.
166 |
167 | Returns:
168 | crop_height: Height of image after cropping.
169 | crop_width: Width of image after cropping.
170 | """
171 | image_width_float = tf.cast(image_width, tf.float32)
172 | image_height_float = tf.cast(image_height, tf.float32)
173 |
174 | def _requested_aspect_ratio_wider_than_image():
175 | crop_height = tf.cast(tf.rint(
176 | crop_proportion / aspect_ratio * image_width_float), tf.int32)
177 | crop_width = tf.cast(tf.rint(
178 | crop_proportion * image_width_float), tf.int32)
179 | return crop_height, crop_width
180 |
181 | def _image_wider_than_requested_aspect_ratio():
182 | crop_height = tf.cast(
183 | tf.rint(crop_proportion * image_height_float), tf.int32)
184 | crop_width = tf.cast(tf.rint(
185 | crop_proportion * aspect_ratio *
186 | image_height_float), tf.int32)
187 | return crop_height, crop_width
188 |
189 | return tf.cond(
190 | aspect_ratio > image_width_float / image_height_float,
191 | _requested_aspect_ratio_wider_than_image,
192 | _image_wider_than_requested_aspect_ratio)
193 |
194 |
195 | def center_crop(image, height, width, crop_proportion):
196 | """Crops to center of image and rescales to desired size.
197 |
198 | Args:
199 | image: Image Tensor to crop.
200 | height: Height of image to be cropped.
201 | width: Width of image to be cropped.
202 | crop_proportion: Proportion of image to retain along the less-cropped side.
203 |
204 | Returns:
205 | A `height` x `width` x channels Tensor holding a central crop of `image`.
206 | """
207 | shape = tf.shape(image)
208 | image_height = shape[0]
209 | image_width = shape[1]
210 | crop_height, crop_width = _compute_crop_shape(
211 | image_height, image_width, height / width, crop_proportion)
212 | offset_height = ((image_height - crop_height) + 1) // 2
213 | offset_width = ((image_width - crop_width) + 1) // 2
214 | image = tf.image.crop_to_bounding_box(
215 | image, offset_height, offset_width, crop_height, crop_width)
216 |
217 | image = tf.image.resize_bicubic([image], [height, width])[0]
218 |
219 | return image
220 |
221 |
222 | def distorted_bounding_box_crop(image,
223 | bbox,
224 | min_object_covered=0.1,
225 | aspect_ratio_range=(0.75, 1.33),
226 | area_range=(0.05, 1.0),
227 | max_attempts=100,
228 | scope=None):
229 | """Generates cropped_image using one of the bboxes randomly distorted.
230 |
231 | See `tf.image.sample_distorted_bounding_box` for more documentation.
232 |
233 | Args:
234 | image: `Tensor` of image data.
235 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
236 | where each coordinate is [0, 1) and the coordinates are arranged
237 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
238 | image.
239 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
240 | area of the image must contain at least this fraction of any bounding
241 | box supplied.
242 | aspect_ratio_range: An optional list of `float`s. The cropped area of the
243 | image must have an aspect ratio = width / height within this range.
244 | area_range: An optional list of `float`s. The cropped area of the image
245 |       must contain a fraction of the supplied image within this range.
246 | max_attempts: An optional `int`. Number of attempts at generating a cropped
247 | region of the image of the specified constraints. After `max_attempts`
248 | failures, return the entire image.
249 | scope: Optional `str` for name scope.
250 | Returns:
251 | (cropped image `Tensor`, distorted bbox `Tensor`).
252 | """
253 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
254 | shape = tf.shape(image)
255 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
256 | shape,
257 | bounding_boxes=bbox,
258 | min_object_covered=min_object_covered,
259 | aspect_ratio_range=aspect_ratio_range,
260 | area_range=area_range,
261 | max_attempts=max_attempts,
262 | use_image_if_no_bounding_boxes=True)
263 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box
264 |
265 | # Crop the image to the specified bounding box.
266 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
267 | target_height, target_width, _ = tf.unstack(bbox_size)
268 | image = tf.image.crop_to_bounding_box(
269 | image, offset_y, offset_x, target_height, target_width)
270 |
271 | return image
272 |
273 |
274 | def crop_and_resize(image, height, width):
275 | """Make a random crop and resize it to height `height` and width `width`.
276 |
277 | Args:
278 | image: Tensor representing the image.
279 | height: Desired image height.
280 | width: Desired image width.
281 |
282 | Returns:
283 | A `height` x `width` x channels Tensor holding a random crop of `image`.
284 | """
285 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
286 | aspect_ratio = width / height
287 | image = distorted_bounding_box_crop(
288 | image,
289 | bbox,
290 | min_object_covered=0.1,
291 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
292 | area_range=(0.08, 1.0),
293 | max_attempts=100,
294 | scope=None)
295 | return tf.image.resize_bicubic([image], [height, width])[0]
296 |
297 |
298 | def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
299 | """Blurs the given image with separable convolution.
300 |
301 |
302 | Args:
303 | image: Tensor of shape [height, width, channels] and dtype float to blur.
304 |     kernel_size: Integer Tensor for the size of the blur kernel. This should
305 | be an odd number. If it is an even number, the actual kernel size will be
306 | size + 1.
307 | sigma: Sigma value for gaussian operator.
308 | padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
309 |
310 | Returns:
311 | A Tensor representing the blurred image.
312 | """
313 | radius = tf.to_int32(kernel_size / 2)
314 | kernel_size = radius * 2 + 1
315 | x = tf.to_float(tf.range(-radius, radius + 1))
316 | blur_filter = tf.exp(
317 | -tf.pow(x, 2.0) / (2.0 * tf.pow(tf.to_float(sigma), 2.0)))
318 | blur_filter /= tf.reduce_sum(blur_filter)
319 | # One vertical and one horizontal filter.
320 | blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
321 | blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
322 | num_channels = tf.shape(image)[-1]
323 | blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
324 | blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
325 | expand_batch_dim = image.shape.ndims == 3
326 | if expand_batch_dim:
327 | # Tensorflow requires batched input to convolutions, which we can fake with
328 | # an extra dimension.
329 | image = tf.expand_dims(image, axis=0)
330 | blurred = tf.nn.depthwise_conv2d(
331 | image, blur_h, strides=[1, 1, 1, 1], padding=padding)
332 | blurred = tf.nn.depthwise_conv2d(
333 | blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
334 | if expand_batch_dim:
335 | blurred = tf.squeeze(blurred, axis=0)
336 | return blurred
337 |
338 |
339 | def random_crop_with_resize(image, height, width, p=1.0):
340 | """Randomly crop and resize an image.
341 |
342 | Args:
343 | image: `Tensor` representing an image of arbitrary size.
344 | height: Height of output image.
345 | width: Width of output image.
346 | p: Probability of applying this transformation.
347 |
348 | Returns:
349 | A preprocessed image `Tensor`.
350 | """
351 | def _transform(image): # pylint: disable=missing-docstring
352 | image = crop_and_resize(image, height, width)
353 | return image
354 | return random_apply(_transform, p=p, x=image)
355 |
356 |
357 | def random_color_jitter(image, p=1.0):
358 | def _transform(image):
359 | color_jitter_t = functools.partial(
360 | color_jitter, strength=FLAGS.color_jitter_strength)
361 | image = random_apply(color_jitter_t, p=0.8, x=image)
362 | return random_apply(to_grayscale, p=0.2, x=image)
363 | return random_apply(_transform, p=p, x=image)
364 |
365 |
366 | def random_blur(image, height, width, p=1.0):
367 | """Randomly blur an image.
368 |
369 | Args:
370 | image: `Tensor` representing an image of arbitrary size.
371 | height: Height of output image.
372 | width: Width of output image.
373 | p: probability of applying this transformation.
374 |
375 | Returns:
376 | A preprocessed image `Tensor`.
377 | """
378 | del width
379 | def _transform(image):
380 | sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
381 | return gaussian_blur(
382 | image, kernel_size=height//10, sigma=sigma, padding='SAME')
383 | return random_apply(_transform, p=p, x=image)
384 |
385 |
386 | def batch_random_blur(images_list, height, width, blur_probability=0.5):
387 | """Apply efficient batch data transformations.
388 |
389 | Args:
390 | images_list: a list of image tensors.
391 | height: the height of image.
392 | width: the width of image.
393 |     blur_probability: the probability of applying the blur operator.
394 |
395 | Returns:
396 | Preprocessed feature list.
397 | """
398 | def generate_selector(p, bsz):
399 | shape = [bsz, 1, 1, 1]
400 | selector = tf.cast(
401 | tf.less(tf.random_uniform(shape, 0, 1, dtype=tf.float32), p),
402 | tf.float32)
403 | return selector
404 |
405 | new_images_list = []
406 | for images in images_list:
407 | images_new = random_blur(images, height, width, p=1.)
408 | selector = generate_selector(blur_probability, tf.shape(images)[0])
409 | images = images_new * selector + images * (1 - selector)
410 | images = tf.clip_by_value(images, 0., 1.)
411 | new_images_list.append(images)
412 |
413 | return new_images_list
414 |
415 |
416 | def preprocess_for_train(image, height, width,
417 | color_distort=True, crop=True, flip=True):
418 | """Preprocesses the given image for training.
419 |
420 | Args:
421 | image: `Tensor` representing an image of arbitrary size.
422 | height: Height of output image.
423 | width: Width of output image.
424 | color_distort: Whether to apply the color distortion.
425 | crop: Whether to crop the image.
426 | flip: Whether or not to flip left and right of an image.
427 |
428 | Returns:
429 | A preprocessed image `Tensor`.
430 | """
431 | if crop:
432 | image = random_crop_with_resize(image, height, width)
433 | if flip:
434 | image = tf.image.random_flip_left_right(image)
435 | if color_distort:
436 | image = random_color_jitter(image)
437 | image = tf.reshape(image, [height, width, 3])
438 | image = tf.clip_by_value(image, 0., 1.)
439 | return image
440 |
441 |
442 | def preprocess_for_eval(image, height, width, crop=True):
443 | """Preprocesses the given image for evaluation.
444 |
445 | Args:
446 | image: `Tensor` representing an image of arbitrary size.
447 | height: Height of output image.
448 | width: Width of output image.
449 | crop: Whether or not to (center) crop the test images.
450 |
451 | Returns:
452 | A preprocessed image `Tensor`.
453 | """
454 | if crop:
455 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION)
456 | image = tf.reshape(image, [height, width, 3])
457 | image = tf.clip_by_value(image, 0., 1.)
458 | return image
459 |
460 |
461 | def preprocess_image(image, height, width, is_training=False,
462 | color_distort=True, test_crop=True):
463 | """Preprocesses the given image.
464 |
465 | Args:
466 | image: `Tensor` representing an image of arbitrary size.
467 | height: Height of output image.
468 | width: Width of output image.
469 | is_training: `bool` for whether the preprocessing is for training.
470 | color_distort: whether to apply the color distortion.
471 | test_crop: whether or not to extract a central crop of the images
472 | (as for standard ImageNet evaluation) during the evaluation.
473 |
474 | Returns:
475 | A preprocessed image `Tensor` of range [0, 1].
476 | """
477 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
478 | if is_training:
479 | return preprocess_for_train(image, height, width, color_distort)
480 | else:
481 | return preprocess_for_eval(image, height, width, test_crop)
482 |
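483 | # Illustrative sketch (assumption: not part of the original repo, and only
484 | # run when this module is executed directly): the normalized 1-D Gaussian
485 | # kernel that gaussian_blur applies separably, first as a row and then as a
486 | # column filter.
487 | if __name__ == '__main__':
488 |   import numpy as np
489 |   radius, sigma = 1, 0.8
490 |   x = np.arange(-radius, radius + 1, dtype=np.float32)
491 |   kernel = np.exp(-x ** 2 / (2.0 * sigma ** 2))
492 |   kernel /= kernel.sum()
493 |   print(kernel)  # ~[0.24, 0.52, 0.24]; the weights always sum to 1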
--------------------------------------------------------------------------------
/CIFAR10/lars_optimizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Functions and classes related to optimization (weight updates)."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import re
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | EETA_DEFAULT = 0.001
27 |
28 |
29 | class LARSOptimizer(tf.train.Optimizer):
30 | """Layer-wise Adaptive Rate Scaling for large batch training.
31 |
32 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
33 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
34 | """
35 |
36 | def __init__(self,
37 | learning_rate,
38 | momentum=0.9,
39 | use_nesterov=False,
40 | weight_decay=0.0,
41 | exclude_from_weight_decay=None,
42 | exclude_from_layer_adaptation=None,
43 | classic_momentum=True,
44 | eeta=EETA_DEFAULT,
45 | name="LARSOptimizer"):
46 | """Constructs a LARSOptimizer.
47 |
48 | Args:
49 | learning_rate: A `float` for learning rate.
50 | momentum: A `float` for momentum.
51 | use_nesterov: A 'Boolean' for whether to use nesterov momentum.
52 | weight_decay: A `float` for weight decay.
53 | exclude_from_weight_decay: A list of `string` for variable screening, if
54 | any of the string appears in a variable's name, the variable will be
55 | excluded for computing weight decay. For example, one could specify
56 | the list like ['batch_normalization', 'bias'] to exclude BN and bias
57 | from weight decay.
58 | exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
59 |         for layer adaptation. If it is None, it defaults to the same value as
60 |         exclude_from_weight_decay.
61 | classic_momentum: A `boolean` for whether to use classic (or popular)
62 |         momentum. The learning rate is applied during the momentum update in
63 |         classic momentum, but after the momentum update for popular momentum.
64 | eeta: A `float` for scaling of learning rate when computing trust ratio.
65 | name: The name for the scope.
66 | """
67 | super(LARSOptimizer, self).__init__(False, name)
68 |
69 | self.learning_rate = learning_rate
70 | self.momentum = momentum
71 | self.weight_decay = weight_decay
72 | self.use_nesterov = use_nesterov
73 | self.classic_momentum = classic_momentum
74 | self.eeta = eeta
75 | self.exclude_from_weight_decay = exclude_from_weight_decay
76 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
77 | # arg is None.
78 | if exclude_from_layer_adaptation:
79 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
80 | else:
81 | self.exclude_from_layer_adaptation = exclude_from_weight_decay
82 |
83 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
84 | if global_step is None:
85 | global_step = tf.train.get_or_create_global_step()
86 | new_global_step = global_step + 1
87 |
88 | assignments = []
89 | for (grad, param) in grads_and_vars:
90 | if grad is None or param is None:
91 | continue
92 |
93 | param_name = param.op.name
94 |
95 | v = tf.get_variable(
96 | name=param_name + "/Momentum",
97 | shape=param.shape.as_list(),
98 | dtype=tf.float32,
99 | trainable=False,
100 | initializer=tf.zeros_initializer())
101 |
102 | if self._use_weight_decay(param_name):
103 | grad += self.weight_decay * param
104 |
105 | if self.classic_momentum:
106 | trust_ratio = 1.0
107 | if self._do_layer_adaptation(param_name):
108 | w_norm = tf.norm(param, ord=2)
109 | g_norm = tf.norm(grad, ord=2)
110 | trust_ratio = tf.where(
111 | tf.greater(w_norm, 0), tf.where(
112 | tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm),
113 | 1.0),
114 | 1.0)
115 | scaled_lr = self.learning_rate * trust_ratio
116 |
117 | next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
118 | if self.use_nesterov:
119 | update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
120 | else:
121 | update = next_v
122 | next_param = param - update
123 | else:
124 | next_v = tf.multiply(self.momentum, v) + grad
125 | if self.use_nesterov:
126 | update = tf.multiply(self.momentum, next_v) + grad
127 | else:
128 | update = next_v
129 |
130 | trust_ratio = 1.0
131 | if self._do_layer_adaptation(param_name):
132 | w_norm = tf.norm(param, ord=2)
133 | v_norm = tf.norm(update, ord=2)
134 | trust_ratio = tf.where(
135 | tf.greater(w_norm, 0), tf.where(
136 | tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm),
137 | 1.0),
138 | 1.0)
139 | scaled_lr = trust_ratio * self.learning_rate
140 | next_param = param - scaled_lr * update
141 |
142 | assignments.extend(
143 | [param.assign(next_param),
144 | v.assign(next_v),
145 | global_step.assign(new_global_step)])
146 | return tf.group(*assignments, name=name)
147 |
148 | def _use_weight_decay(self, param_name):
149 | """Whether to use L2 weight decay for `param_name`."""
150 | if not self.weight_decay:
151 | return False
152 | if self.exclude_from_weight_decay:
153 | for r in self.exclude_from_weight_decay:
154 | if re.search(r, param_name) is not None:
155 | return False
156 | return True
157 |
158 | def _do_layer_adaptation(self, param_name):
159 | """Whether to do layer-wise learning rate adaptation for `param_name`."""
160 | if self.exclude_from_layer_adaptation:
161 | for r in self.exclude_from_layer_adaptation:
162 | if re.search(r, param_name) is not None:
163 | return False
164 | return True
165 |
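166 | # Illustrative sketch (assumption: not part of the original repo, and only
167 | # run when this module is executed directly): the layer-wise trust ratio that
168 | # LARS applies, eeta * ||w|| / ||g||, computed with NumPy for one parameter.
169 | if __name__ == "__main__":
170 |   import numpy as np
171 |
172 |   def lars_trust_ratio(param, grad, eeta=EETA_DEFAULT):
173 |     """Returns eeta * ||w|| / ||g||, or 1.0 when either norm is zero."""
174 |     w_norm = np.linalg.norm(param)
175 |     g_norm = np.linalg.norm(grad)
176 |     return eeta * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
177 |
178 |   # A layer with large weights and small gradients takes a larger effective
179 |   # step (trust_ratio * learning_rate) than the base learning rate alone.
180 |   print(lars_trust_ratio(np.ones(100), 0.01 * np.ones(100)))  # -> 0.1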
--------------------------------------------------------------------------------
/CIFAR10/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Model specification for SimCLR."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 |
24 | import data_util as data_util
25 | from lars_optimizer import LARSOptimizer
26 | import model_util as model_util
27 | import objective as obj_lib
28 |
29 | import tensorflow.compat.v1 as tf
30 | import tensorflow.compat.v2 as tf2
31 |
32 | FLAGS = flags.FLAGS
33 |
34 |
35 | def build_model_fn(model, num_classes, num_train_examples):
36 | """Build model function."""
37 | def model_fn(features, labels, mode, params=None):
38 | """Build model and optimizer."""
39 | is_training = mode == tf.estimator.ModeKeys.TRAIN
40 |
41 | # Check training mode.
42 | if FLAGS.train_mode == 'pretrain':
43 | num_transforms = 2
44 | if FLAGS.fine_tune_after_block > -1:
45 | raise ValueError('Does not support layer freezing during pretraining,'
46 |                          ' should set fine_tune_after_block<=-1 for safety.')
47 | elif FLAGS.train_mode == 'finetune':
48 | num_transforms = 1
49 | else:
50 | raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))
51 |
52 | # Split channels, and optionally apply extra batched augmentation.
53 | features_list = tf.split(
54 | features, num_or_size_splits=num_transforms, axis=-1)
55 | if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
56 | features_list = data_util.batch_random_blur(
57 | features_list, FLAGS.image_size, FLAGS.image_size)
58 | features = tf.concat(features_list, 0) # (num_transforms * bsz, h, w, c)
59 |
60 | # Base network forward pass.
61 | with tf.variable_scope('base_model'):
62 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
63 |         # Fine-tuning just the supervised (linear) head does not update BN stats.
64 | model_train_mode = False
65 | else:
66 |         # Pretraining, or fine-tuning more than the head, updates BN stats.
67 | model_train_mode = is_training
68 | hiddens = model(features, is_training=model_train_mode)
69 |
70 | # Add head and loss.
71 | if FLAGS.train_mode == 'pretrain':
72 | tpu_context = params['context'] if 'context' in params else None
73 | hiddens_proj = model_util.projection_head(hiddens, is_training)
74 | contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
75 | hiddens_proj,
76 | hidden_norm=FLAGS.hidden_norm,
77 | temperature=FLAGS.temperature,
78 | tpu_context=tpu_context if is_training else None)
79 |       # Note: inv_pred_loss is computed on the representation (hiddens), not on the projection head output.
80 | inv_pred_loss = obj_lib.add_inv_pred_loss(
81 | hiddens,
82 | inv_pred_coeff=FLAGS.inv_pred_coeff)
83 | logits_sup = tf.zeros([params['batch_size'], num_classes])
84 | else:
85 | contrast_loss = tf.zeros([])
86 | inv_pred_loss = tf.zeros([])
87 | logits_con = tf.zeros([params['batch_size'], 10])
88 | labels_con = tf.zeros([params['batch_size'], 10])
89 | logits_sup = model_util.supervised_head(
90 | hiddens, num_classes, is_training)
91 | obj_lib.add_supervised_loss(
92 | labels=labels['labels'],
93 | logits=logits_sup,
94 | weights=labels['mask'])
95 |
96 | # Add weight decay to loss, for non-LARS optimizers.
97 | model_util.add_weight_decay(adjust_per_optimizer=True)
98 | loss = tf.losses.get_total_loss()
99 |
100 | if FLAGS.train_mode == 'pretrain':
101 | variables_to_train = tf.trainable_variables()
102 | else:
103 | collection_prefix = 'trainable_variables_inblock_'
104 | variables_to_train = []
105 | for j in range(FLAGS.fine_tune_after_block + 1, 6):
106 | variables_to_train += tf.get_collection(collection_prefix + str(j))
107 | assert variables_to_train, 'variables_to_train shouldn\'t be empty!'
108 |
109 | tf.logging.info('===============Variables to train (begin)===============')
110 | tf.logging.info(variables_to_train)
111 | tf.logging.info('================Variables to train (end)================')
112 |
113 | learning_rate = model_util.learning_rate_schedule(
114 | FLAGS.learning_rate, num_train_examples)
115 |
116 | if is_training:
117 | if FLAGS.train_summary_steps > 0:
118 | # Compute stats for the summary.
119 | prob_con = tf.nn.softmax(logits_con)
120 | entropy_con = - tf.reduce_mean(
121 | tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))
122 |
123 | summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
124 | # TODO(iamtingchen): remove this control_dependencies in the future.
125 | with tf.control_dependencies([summary_writer.init()]):
126 | with summary_writer.as_default():
127 | should_record = tf.math.equal(
128 | tf.math.floormod(tf.train.get_global_step(),
129 | FLAGS.train_summary_steps), 0)
130 | with tf2.summary.record_if(should_record):
131 | contrast_acc = tf.equal(
132 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
133 | contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
134 | label_acc = tf.equal(
135 | tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
136 | label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
137 | tf2.summary.scalar(
138 | 'train_contrast_loss',
139 | contrast_loss,
140 | step=tf.train.get_global_step())
141 | tf2.summary.scalar(
142 | 'train_inv_pred_loss',
143 | inv_pred_loss,
144 | step=tf.train.get_global_step())
145 | tf2.summary.scalar(
146 | 'train_contrast_acc',
147 | contrast_acc,
148 | step=tf.train.get_global_step())
149 | tf2.summary.scalar(
150 | 'train_label_accuracy',
151 | label_acc,
152 | step=tf.train.get_global_step())
153 | tf2.summary.scalar(
154 | 'contrast_entropy',
155 | entropy_con,
156 | step=tf.train.get_global_step())
157 | tf2.summary.scalar(
158 | 'learning_rate', learning_rate,
159 | step=tf.train.get_global_step())
160 | tf2.summary.scalar(
161 | 'input_mean',
162 | tf.reduce_mean(features),
163 | step=tf.train.get_global_step())
164 | tf2.summary.scalar(
165 | 'input_max',
166 | tf.reduce_max(features),
167 | step=tf.train.get_global_step())
168 | tf2.summary.scalar(
169 | 'input_min',
170 | tf.reduce_min(features),
171 | step=tf.train.get_global_step())
172 | tf2.summary.scalar(
173 | 'num_labels',
174 | tf.reduce_mean(tf.reduce_sum(labels['labels'], -1)),
175 | step=tf.train.get_global_step())
176 |
177 | if FLAGS.optimizer == 'momentum':
178 | optimizer = tf.train.MomentumOptimizer(
179 | learning_rate, FLAGS.momentum, use_nesterov=True)
180 | elif FLAGS.optimizer == 'adam':
181 | optimizer = tf.train.AdamOptimizer(
182 | learning_rate)
183 | elif FLAGS.optimizer == 'lars':
184 | optimizer = LARSOptimizer(
185 | learning_rate,
186 | momentum=FLAGS.momentum,
187 | weight_decay=FLAGS.weight_decay,
188 | exclude_from_weight_decay=['batch_normalization', 'bias'])
189 | else:
190 | raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer))
191 |
192 | if FLAGS.use_tpu:
193 | optimizer = tf.tpu.CrossShardOptimizer(optimizer)
194 |
195 | control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
196 | if FLAGS.train_summary_steps > 0:
197 | control_deps.extend(tf.summary.all_v2_summary_ops())
198 | with tf.control_dependencies(control_deps):
199 | train_op = optimizer.minimize(
200 | loss, global_step=tf.train.get_or_create_global_step(),
201 | var_list=variables_to_train)
202 |
203 | if FLAGS.checkpoint:
204 | def scaffold_fn():
205 | """Scaffold function to restore non-logits vars from checkpoint."""
206 | tf.train.init_from_checkpoint(
207 | FLAGS.checkpoint,
208 | {v.op.name: v.op.name
209 | for v in tf.global_variables(FLAGS.variable_schema)})
210 |
211 | if FLAGS.zero_init_logits_layer:
212 | # Init op that initializes output layer parameters to zeros.
213 | output_layer_parameters = [
214 | var for var in tf.trainable_variables() if var.name.startswith(
215 | 'head_supervised')]
216 | tf.logging.info('Initializing output layer parameters %s to zero',
217 | [x.op.name for x in output_layer_parameters])
218 | with tf.control_dependencies([tf.global_variables_initializer()]):
219 | init_op = tf.group([
220 | tf.assign(x, tf.zeros_like(x))
221 | for x in output_layer_parameters])
222 | return tf.train.Scaffold(init_op=init_op)
223 | else:
224 | return tf.train.Scaffold()
225 | else:
226 | scaffold_fn = None
227 |
228 | return tf.estimator.tpu.TPUEstimatorSpec(
229 | mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
230 | else:
231 |
232 | def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
233 | **kws):
234 | """Inner metric function."""
235 | metrics = {k: tf.metrics.mean(v, weights=mask)
236 | for k, v in kws.items()}
237 | metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
238 | tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
239 | weights=mask)
240 | metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
241 | tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
242 | metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
243 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1),
244 | weights=mask)
245 | metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
246 | tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
247 | return metrics
248 |
249 | metrics = {
250 | 'logits_sup': logits_sup,
251 | 'labels_sup': labels['labels'],
252 | 'logits_con': logits_con,
253 | 'labels_con': labels_con,
254 | 'mask': labels['mask'],
255 | 'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
256 | 'inv_pred_loss': tf.fill((params['batch_size'],), inv_pred_loss),
257 | 'regularization_loss': tf.fill((params['batch_size'],),
258 | tf.losses.get_regularization_loss()),
259 | }
260 |
261 | return tf.estimator.tpu.TPUEstimatorSpec(
262 | mode=mode,
263 | loss=loss,
264 | eval_metrics=(metric_fn, metrics),
265 | scaffold_fn=None)
266 |
267 | return model_fn
268 |
--------------------------------------------------------------------------------
/CIFAR10/model_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Network architectures related functions used in SimCLR."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 |
24 | import resnet
25 |
26 | import tensorflow.compat.v1 as tf
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | def add_weight_decay(adjust_per_optimizer=True):
32 | """Compute weight decay from flags."""
33 | if adjust_per_optimizer and 'lars' in FLAGS.optimizer:
34 |     # Weight decay is taken care of by the optimizer in these cases.
35 | return
36 |
37 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
38 | if 'batch_normalization' not in v.name]
39 | tf.losses.add_loss(
40 | FLAGS.weight_decay * tf.add_n(l2_losses),
41 | tf.GraphKeys.REGULARIZATION_LOSSES)
42 |
43 |
44 | def get_train_steps(num_examples):
45 | """Determine the number of training steps."""
46 | return FLAGS.train_steps or (
47 | num_examples * FLAGS.train_epochs // FLAGS.train_batch_size + 1)
48 |
49 |
50 | def learning_rate_schedule(base_learning_rate, num_examples):
51 | """Build learning rate schedule."""
52 | global_step = tf.train.get_or_create_global_step()
53 | warmup_steps = int(round(
54 | FLAGS.warmup_epochs * num_examples // FLAGS.train_batch_size))
55 | scaled_lr = base_learning_rate * FLAGS.train_batch_size / 256.
56 | learning_rate = (tf.to_float(global_step) / int(warmup_steps) * scaled_lr
57 | if warmup_steps else scaled_lr)
58 |
59 | # Cosine decay learning rate schedule
60 | total_steps = get_train_steps(num_examples)
61 | learning_rate = tf.where(
62 | global_step < warmup_steps, learning_rate,
63 | tf.train.cosine_decay(
64 | scaled_lr,
65 | global_step - warmup_steps,
66 | total_steps - warmup_steps))
67 |
68 | return learning_rate
69 |
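As a quick numeric sanity check of the schedule above, here is a minimal NumPy sketch (illustrative values only: base LR 0.3, batch size 512, 50,000 training examples, 10 warmup epochs, 100 training epochs; the real pipeline reads these from FLAGS):

```python
import numpy as np

# Illustrative values; the actual run takes these from FLAGS.
base_learning_rate = 0.3
train_batch_size = 512
num_examples = 50000          # e.g. a CIFAR-10-sized train split
warmup_epochs, train_epochs = 10, 100

scaled_lr = base_learning_rate * train_batch_size / 256.            # 0.6
warmup_steps = int(round(warmup_epochs * num_examples // train_batch_size))
total_steps = num_examples * train_epochs // train_batch_size + 1   # as in get_train_steps()

def lr_at(step):
    if warmup_steps and step < warmup_steps:
        return step / warmup_steps * scaled_lr                       # linear warmup
    # tf.train.cosine_decay: scaled_lr * 0.5 * (1 + cos(pi * progress))
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return scaled_lr * 0.5 * (1 + np.cos(np.pi * progress))

print(lr_at(0), lr_at(warmup_steps), lr_at(total_steps))             # 0.0, 0.6, ~0.0
```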
70 |
71 | def linear_layer(x,
72 | is_training,
73 | num_classes,
74 | use_bias=True,
75 | use_bn=False,
76 | name='linear_layer'):
77 | """Linear head for linear evaluation.
78 |
79 | Args:
80 | x: hidden state tensor of shape (bsz, dim).
81 | is_training: boolean indicator for training or test.
82 | num_classes: number of classes.
83 | use_bias: whether or not to use bias.
84 | use_bn: whether or not to use BN for output units.
85 | name: the name for variable scope.
86 |
87 | Returns:
88 | logits of shape (bsz, num_classes)
89 | """
90 | assert x.shape.ndims == 2, x.shape
91 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
92 | x = tf.layers.dense(
93 | inputs=x,
94 | units=num_classes,
95 | use_bias=use_bias and not use_bn,
96 | kernel_initializer=tf.random_normal_initializer(stddev=.01))
97 | if use_bn:
98 | x = resnet.batch_norm_relu(x, is_training, relu=False, center=use_bias)
99 | x = tf.identity(x, '%s_out' % name)
100 | return x
101 |
102 |
103 | def projection_head(hiddens, is_training, name='head_contrastive'):
104 |   """Head for projecting hiddens for contrastive loss."""
105 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
106 | if FLAGS.head_proj_mode == 'none':
107 | pass # directly use the output hiddens as hiddens
108 | elif FLAGS.head_proj_mode == 'linear':
109 | hiddens = linear_layer(
110 | hiddens, is_training, FLAGS.head_proj_dim,
111 | use_bias=False, use_bn=True, name='l_0')
112 | elif FLAGS.head_proj_mode == 'nonlinear':
113 | hiddens = linear_layer(
114 | hiddens, is_training, hiddens.shape[-1],
115 | use_bias=True, use_bn=True, name='nl_0')
116 | for j in range(1, FLAGS.num_nlh_layers + 1):
117 | hiddens = tf.nn.relu(hiddens)
118 | hiddens = linear_layer(
119 | hiddens, is_training, FLAGS.head_proj_dim,
120 | use_bias=False, use_bn=True, name='nl_%d'%j)
121 | else:
122 | raise ValueError('Unknown head projection mode {}'.format(
123 | FLAGS.head_proj_mode))
124 | return hiddens
125 |
126 |
127 | def supervised_head(hiddens, num_classes, is_training, name='head_supervised'):
128 | """Add supervised head & also add its variables to inblock collection."""
129 | with tf.variable_scope(name):
130 | logits = linear_layer(hiddens, is_training, num_classes)
131 | for var in tf.trainable_variables():
132 | if var.name.startswith(name):
133 | tf.add_to_collection('trainable_variables_inblock_5', var)
134 | return logits
135 |
--------------------------------------------------------------------------------
/CIFAR10/objective.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Contrastive loss functions."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import
27 |
28 | FLAGS = flags.FLAGS
29 |
30 | LARGE_NUM = 1e9
31 |
32 |
33 | def add_supervised_loss(labels, logits, weights, **kwargs):
34 | """Compute loss for model and add it to loss collection."""
35 | return tf.losses.softmax_cross_entropy(labels, logits, weights, **kwargs)
36 |
37 | def add_inv_pred_loss(repres, inv_pred_coeff=0.0, weights=1.0):
38 |   """Compute the inverse predictive loss for the model (TPU compatibility untested).
39 |
40 | Args:
41 | repres: representation vector (`Tensor`) of shape (bsz, dim).
42 |     inv_pred_coeff: coefficient of the inverse predictive loss.
43 |
44 | Returns:
45 | A loss scalar.
46 | """
47 | # Get representation1 and representation2.
48 | representation1, representation2 = tf.split(repres, 2, 0)
49 | inv_pred_loss = tf.losses.mean_squared_error(representation1, representation2, weights=weights*inv_pred_coeff)
50 |
51 | return inv_pred_loss
52 |
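For intuition, this reduces to a coefficient-scaled mean squared error between the two stacked views. A minimal NumPy sketch with toy shapes (not part of the pipeline) shows the equivalent computation, up to the reduction used by `tf.losses.mean_squared_error`:

```python
import numpy as np

bsz, dim = 4, 8                                        # toy sizes
rep1 = np.random.randn(bsz, dim).astype(np.float32)    # representations of view 1
rep2 = np.random.randn(bsz, dim).astype(np.float32)    # representations of view 2
repres = np.concatenate([rep1, rep2], axis=0)          # shape (2*bsz, dim), the expected input layout

inv_pred_coeff = 0.05                                  # default of the inv_pred_coeff flag
# With a nonzero scalar weight w, tf.losses.mean_squared_error(rep1, rep2, weights=w)
# reduces to w * mean((rep1 - rep2) ** 2):
inv_pred_loss = inv_pred_coeff * np.mean((rep1 - rep2) ** 2)
```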
53 | def tpu_cross_replica_concat(tensor, tpu_context=None):
54 | """Reduce a concatenation of the `tensor` across TPU cores.
55 |
56 | Args:
57 | tensor: tensor to concatenate.
58 | tpu_context: A `TPUContext`. If not set, CPU execution is assumed.
59 |
60 | Returns:
61 | Tensor of the same rank as `tensor` with first dimension `num_replicas`
62 | times larger.
63 | """
64 | if tpu_context is None or tpu_context.num_replicas <= 1:
65 | return tensor
66 |
67 | num_replicas = tpu_context.num_replicas
68 |
69 | with tf.name_scope('tpu_cross_replica_concat'):
70 | # This creates a tensor that is like the input tensor but has an added
71 | # replica dimension as the outermost dimension. On each replica it will
72 | # contain the local values and zeros for all other values that need to be
73 | # fetched from other replicas.
74 | ext_tensor = tf.scatter_nd(
75 | indices=[[xla.replica_id()]],
76 | updates=[tensor],
77 | shape=[num_replicas] + tensor.shape.as_list())
78 |
79 | # As every value is only present on one replica and 0 in all others, adding
80 | # them all together will result in the full tensor on all replicas.
81 | ext_tensor = tf.tpu.cross_replica_sum(ext_tensor)
82 |
83 | # Flatten the replica dimension.
84 | # The first dimension size will be: tensor.shape[0] * num_replicas
85 | # Using [-1] trick to support also scalar input.
86 | return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
87 |
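The scatter-then-sum trick above can be illustrated with plain NumPy (a sketch with two hypothetical replicas; the real code uses `xla.replica_id()` and `tf.tpu.cross_replica_sum`):

```python
import numpy as np

num_replicas = 2
local = [np.array([[1., 2.], [3., 4.]]),    # tensor held by replica 0
         np.array([[5., 6.], [7., 8.]])]    # tensor held by replica 1

# Step 1: each replica scatters its local tensor into its own slot of a zero tensor.
ext = []
for r in range(num_replicas):
    t = np.zeros((num_replicas,) + local[r].shape)
    t[r] = local[r]
    ext.append(t)

# Step 2: cross_replica_sum -- summing the per-replica tensors reconstructs the full
# tensor on every replica, because each slot is nonzero on exactly one replica.
full = sum(ext)                                 # shape (2, 2, 2)

# Step 3: flatten the replica dimension to obtain the cross-replica concatenation.
concat = full.reshape((-1,) + full.shape[2:])   # shape (4, 2)
print(concat)
```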
88 |
89 | def add_contrastive_loss(hidden,
90 | hidden_norm=True,
91 | temperature=1.0,
92 | tpu_context=None,
93 | weights=1.0):
94 | """Compute loss for model.
95 |
96 | Args:
97 | hidden: hidden vector (`Tensor`) of shape (bsz, dim).
98 | hidden_norm: whether or not to use normalization on the hidden vector.
99 |     temperature: a floating-point number for temperature scaling.
100 | tpu_context: context information for tpu.
101 | weights: a weighting number or vector.
102 |
103 | Returns:
104 | A loss scalar.
105 | The logits for contrastive prediction task.
106 | The labels for contrastive prediction task.
107 | """
108 | # Get (normalized) hidden1 and hidden2.
109 | if hidden_norm:
110 | hidden = tf.math.l2_normalize(hidden, -1)
111 | hidden1, hidden2 = tf.split(hidden, 2, 0)
112 | batch_size = tf.shape(hidden1)[0]
113 |
114 | # Gather hidden1/hidden2 across replicas and create local labels.
115 | if tpu_context is not None:
116 | hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
117 | hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
118 | enlarged_batch_size = tf.shape(hidden1_large)[0]
119 | # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
120 | replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
121 | labels_idx = tf.range(batch_size) + replica_id * batch_size
122 | labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
123 | masks = tf.one_hot(labels_idx, enlarged_batch_size)
124 | else:
125 | hidden1_large = hidden1
126 | hidden2_large = hidden2
127 | labels = tf.one_hot(tf.range(batch_size), batch_size * 2) # (bz, bz*2)
128 | masks = tf.one_hot(tf.range(batch_size), batch_size) # (bz, bz)
129 |
130 | logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
131 | logits_aa = logits_aa - masks * LARGE_NUM # (bz, bz)
132 | logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
133 | logits_bb = logits_bb - masks * LARGE_NUM # (bz, bz)
134 | logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature # (bz, bz)
135 | logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature # (bz, bz)
136 |
137 | loss_a = tf.losses.softmax_cross_entropy(
138 | labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
139 | loss_b = tf.losses.softmax_cross_entropy(
140 | labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
141 | loss = loss_a + loss_b
142 |
143 | return loss, logits_ab, labels
144 |
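The function expects the projections of the two augmented views stacked along the batch dimension. A minimal CPU sketch of a call (toy sizes, no TPU context, assuming this module is importable as `objective`):

```python
import numpy as np
import tensorflow.compat.v1 as tf

from objective import add_contrastive_loss

tf.disable_eager_execution()

bsz, dim = 4, 16
h1 = np.random.randn(bsz, dim).astype(np.float32)   # projections of view 1
h2 = np.random.randn(bsz, dim).astype(np.float32)   # projections of view 2
hidden = tf.constant(np.concatenate([h1, h2], 0))   # shape (2*bsz, dim)

loss, logits_ab, labels = add_contrastive_loss(
    hidden, hidden_norm=True, temperature=0.1, tpu_context=None)

with tf.Session() as sess:
    print(sess.run(loss))                            # NT-Xent-style loss on the toy batch
```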
--------------------------------------------------------------------------------
/CIFAR10/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow-datasets==2.1.0
3 | tensorflow-hub==0.7.0
4 |
--------------------------------------------------------------------------------
/CIFAR10/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The main training pipeline."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import json
23 | import math
24 | import os
25 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
26 | from absl import app
27 | from absl import flags
28 |
29 | import resnet
30 | import data as data_lib
31 | import model as model_lib
32 | import model_util as model_util
33 |
34 | import tensorflow.compat.v1 as tf
35 | import tensorflow_datasets as tfds
36 | import tensorflow_hub as hub
37 |
38 | tf.logging.set_verbosity(tf.logging.ERROR)
39 |
40 | FLAGS = flags.FLAGS
41 |
42 |
43 | flags.DEFINE_float(
44 | 'learning_rate', 0.3,
45 | 'Initial learning rate per batch size of 256.')
46 |
47 | flags.DEFINE_float(
48 | 'warmup_epochs', 10,
49 | 'Number of epochs of warmup.')
50 |
51 | flags.DEFINE_float(
52 | 'weight_decay', 1e-6,
53 | 'Amount of weight decay to use.')
54 |
55 | flags.DEFINE_float(
56 | 'batch_norm_decay', 0.9,
57 | 'Batch norm decay parameter.')
58 |
59 | flags.DEFINE_integer(
60 | 'train_batch_size', 512,
61 | 'Batch size for training.')
62 |
63 | flags.DEFINE_string(
64 | 'train_split', 'train',
65 | 'Split for training.')
66 |
67 | flags.DEFINE_integer(
68 | 'train_epochs', 100,
69 | 'Number of epochs to train for.')
70 |
71 | flags.DEFINE_integer(
72 | 'train_steps', 0,
73 | 'Number of steps to train for. If provided, overrides train_epochs.')
74 |
75 | flags.DEFINE_integer(
76 | 'eval_batch_size', 256,
77 | 'Batch size for eval.')
78 |
79 | flags.DEFINE_integer(
80 | 'train_summary_steps', 100,
81 |     'Number of steps between saving training summaries. If 0, summaries are not saved.')
82 |
83 | flags.DEFINE_integer(
84 | 'checkpoint_epochs', 1,
85 | 'Number of epochs between checkpoints/summaries.')
86 |
87 | flags.DEFINE_integer(
88 | 'checkpoint_steps', 0,
89 | 'Number of steps between checkpoints/summaries. If provided, overrides '
90 | 'checkpoint_epochs.')
91 |
92 | flags.DEFINE_string(
93 | 'eval_split', 'validation',
94 | 'Split for evaluation.')
95 |
96 | flags.DEFINE_string(
97 | 'dataset', 'imagenet2012',
98 | 'Name of a dataset.')
99 |
100 | flags.DEFINE_bool(
101 | 'cache_dataset', False,
102 | 'Whether to cache the entire dataset in memory. If the dataset is '
103 | 'ImageNet, this is a very bad idea, but for smaller datasets it can '
104 | 'improve performance.')
105 |
106 | flags.DEFINE_enum(
107 | 'mode', 'train', ['train', 'eval', 'train_then_eval'],
108 | 'Whether to perform training or evaluation.')
109 |
110 | flags.DEFINE_enum(
111 | 'train_mode', 'pretrain', ['pretrain', 'finetune'],
112 | 'The train mode controls different objectives and trainable components.')
113 |
114 | flags.DEFINE_string(
115 | 'checkpoint', None,
116 |     'Checkpoint to load from for continued training or fine-tuning.')
117 |
118 | flags.DEFINE_string(
119 | 'variable_schema', '?!global_step',
120 |     'A pattern that selects which variables are loaded from the checkpoint.')
121 |
122 | flags.DEFINE_bool(
123 | 'zero_init_logits_layer', False,
124 | 'If True, zero initialize layers after avg_pool for supervised learning.')
125 |
126 | flags.DEFINE_integer(
127 | 'fine_tune_after_block', -1,
128 |     'The block after which layers will be fine-tuned. -1 means fine-tuning '
129 |     'everything. 0 means fine-tuning after the stem block. 4 means fine-tuning '
130 |     'just the linear head.')
131 |
132 | flags.DEFINE_string(
133 | 'master', None,
134 | 'Address/name of the TensorFlow master to use. By default, use an '
135 | 'in-process master.')
136 |
137 | flags.DEFINE_string(
138 | 'model_dir', None,
139 | 'Model directory for training.')
140 |
141 | flags.DEFINE_string(
142 | 'data_dir', None,
143 | 'Directory where dataset is stored.')
144 |
145 | flags.DEFINE_bool(
146 | 'use_tpu', True,
147 | 'Whether to run on TPU.')
148 |
149 | tf.flags.DEFINE_string(
150 | 'tpu_name', None,
151 | 'The Cloud TPU to use for training. This should be either the name '
152 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
153 | 'url.')
154 |
155 | tf.flags.DEFINE_string(
156 | 'tpu_zone', None,
157 |     '[Optional] GCE zone where the Cloud TPU is located. If not '
158 |     'specified, we will attempt to automatically detect the zone from '
159 |     'metadata.')
160 |
161 | tf.flags.DEFINE_string(
162 | 'gcp_project', None,
163 | '[Optional] Project name for the Cloud TPU-enabled project. If not '
164 | 'specified, we will attempt to automatically detect the GCE project from '
165 | 'metadata.')
166 |
167 | flags.DEFINE_enum(
168 | 'optimizer', 'lars', ['momentum', 'adam', 'lars'],
169 | 'Optimizer to use.')
170 |
171 | flags.DEFINE_float(
172 | 'momentum', 0.9,
173 | 'Momentum parameter.')
174 |
175 | flags.DEFINE_string(
176 | 'eval_name', None,
177 | 'Name for eval.')
178 |
179 | flags.DEFINE_integer(
180 | 'keep_checkpoint_max', 5,
181 | 'Maximum number of checkpoints to keep.')
182 |
183 | flags.DEFINE_integer(
184 | 'keep_hub_module_max', 1,
185 | 'Maximum number of Hub modules to keep.')
186 |
187 | flags.DEFINE_float(
188 | 'temperature', 0.1,
189 | 'Temperature parameter for contrastive loss.')
190 |
191 | flags.DEFINE_boolean(
192 | 'hidden_norm', True,
193 |     'Whether to normalize the hidden representation in the contrastive loss.')
194 |
195 | flags.DEFINE_enum(
196 | 'head_proj_mode', 'nonlinear', ['none', 'linear', 'nonlinear'],
197 | 'How the head projection is done.')
198 |
199 | flags.DEFINE_integer(
200 | 'head_proj_dim', 128,
201 |     'Dimensionality of the head projection.')
202 |
203 | flags.DEFINE_integer(
204 | 'num_nlh_layers', 1,
205 | 'Number of non-linear head layers.')
206 |
207 | flags.DEFINE_boolean(
208 | 'global_bn', True,
209 | 'Whether to aggregate BN statistics across distributed cores.')
210 |
211 | flags.DEFINE_integer(
212 | 'width_multiplier', 1,
213 | 'Multiplier to change width of network.')
214 |
215 | flags.DEFINE_integer(
216 | 'resnet_depth', 50,
217 | 'Depth of ResNet.')
218 |
219 | flags.DEFINE_integer(
220 | 'image_size', 224,
221 | 'Input image size.')
222 |
223 | flags.DEFINE_float(
224 | 'color_jitter_strength', 1.0,
225 | 'The strength of color jittering.')
226 |
227 | flags.DEFINE_float(
228 | 'inv_pred_coeff', 0.05,
229 |     'Coefficient for inverse predictive coding.')
230 |
231 | flags.DEFINE_boolean(
232 | 'use_blur', True,
233 | 'Whether or not to use Gaussian blur for augmentation during pretraining.')
234 |
235 |
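Taken together, these flags support invocations such as the following illustrative local (non-TPU) pretraining run on CIFAR-10. It is a sketch, not a command taken from this repository; unspecified flags fall back to their defaults.

```
python run.py \
  --train_mode=pretrain \
  --dataset=cifar10 \
  --eval_split=test \
  --image_size=32 \
  --train_batch_size=512 \
  --train_epochs=100 \
  --learning_rate=0.3 \
  --temperature=0.1 \
  --use_tpu=False \
  --model_dir=/tmp/simclr_cifar10
```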
236 | def build_hub_module(model, num_classes, global_step, checkpoint_path):
237 | """Create TF-Hub module."""
238 |
239 | tags_and_args = [
240 | # The default graph is built with batch_norm, dropout etc. in inference
241 | # mode. This graph version is good for inference, not training.
242 | ([], {'is_training': False}),
243 | # A separate "train" graph builds batch_norm, dropout etc. in training
244 | # mode.
245 | (['train'], {'is_training': True}),
246 | ]
247 |
248 | def module_fn(is_training):
249 | """Function that builds TF-Hub module."""
250 | endpoints = {}
251 | inputs = tf.placeholder(
252 | tf.float32, [None, FLAGS.image_size, FLAGS.image_size, 3])
253 | with tf.variable_scope('base_model', reuse=tf.AUTO_REUSE):
254 | hiddens = model(inputs, is_training)
255 | for v in ['initial_conv', 'initial_max_pool', 'block_group1',
256 | 'block_group2', 'block_group3', 'block_group4',
257 | 'final_avg_pool']:
258 | endpoints[v] = tf.get_default_graph().get_tensor_by_name(
259 | 'base_model/{}:0'.format(v))
260 | if FLAGS.train_mode == 'pretrain':
261 | hiddens_proj = model_util.projection_head(hiddens, is_training)
262 | endpoints['proj_head_input'] = hiddens
263 | endpoints['proj_head_output'] = hiddens_proj
264 | else:
265 | logits_sup = model_util.supervised_head(
266 | hiddens, num_classes, is_training)
267 | endpoints['logits_sup'] = logits_sup
268 | hub.add_signature(inputs=dict(images=inputs),
269 | outputs=dict(endpoints, default=hiddens))
270 |
271 | # Drop the non-supported non-standard graph collection.
272 | drop_collections = ['trainable_variables_inblock_%d'%d for d in range(6)]
273 | spec = hub.create_module_spec(module_fn, tags_and_args, drop_collections)
274 | hub_export_dir = os.path.join(FLAGS.model_dir, 'hub')
275 | checkpoint_export_dir = os.path.join(hub_export_dir, str(global_step))
276 | if tf.io.gfile.exists(checkpoint_export_dir):
277 | # Do not save if checkpoint already saved.
278 | tf.io.gfile.rmtree(checkpoint_export_dir)
279 | spec.export(
280 | checkpoint_export_dir,
281 | checkpoint_path=checkpoint_path,
282 | name_transform_fn=None)
283 |
284 | if FLAGS.keep_hub_module_max > 0:
285 | # Delete old exported Hub modules.
286 | exported_steps = []
287 | for subdir in tf.io.gfile.listdir(hub_export_dir):
288 | if not subdir.isdigit():
289 | continue
290 | exported_steps.append(int(subdir))
291 | exported_steps.sort()
292 | for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]:
293 | tf.io.gfile.rmtree(os.path.join(hub_export_dir, str(step_to_delete)))
294 |
295 |
296 | def perform_evaluation(estimator, input_fn, eval_steps, model, num_classes,
297 | checkpoint_path=None):
298 | """Perform evaluation.
299 |
300 | Args:
301 | estimator: TPUEstimator instance.
302 | input_fn: Input function for estimator.
303 | eval_steps: Number of steps for evaluation.
304 | model: Instance of transfer_learning.models.Model.
305 | num_classes: Number of classes to build model for.
306 | checkpoint_path: Path of checkpoint to evaluate.
307 |
308 | Returns:
309 | result: A Dict of metrics and their values.
310 | """
311 | if not checkpoint_path:
312 | checkpoint_path = estimator.latest_checkpoint()
313 | result = estimator.evaluate(
314 | input_fn, eval_steps, checkpoint_path=checkpoint_path,
315 | name=FLAGS.eval_name)
316 |
317 | # Record results as JSON.
318 | result_json_path = os.path.join(FLAGS.model_dir, 'result.json')
319 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
320 | json.dump({k: float(v) for k, v in result.items()}, f)
321 | result_json_path = os.path.join(
322 | FLAGS.model_dir, 'result_%d.json'%result['global_step'])
323 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
324 | json.dump({k: float(v) for k, v in result.items()}, f)
325 | flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json')
326 | with tf.io.gfile.GFile(flag_json_path, 'w') as f:
327 | json.dump(FLAGS.flag_values_dict(), f)
328 |
329 | # Save Hub module.
330 | build_hub_module(model, num_classes,
331 | global_step=result['global_step'],
332 | checkpoint_path=checkpoint_path)
333 |
334 | return result
335 |
336 |
337 | def main(argv):
338 | if len(argv) > 1:
339 | raise app.UsageError('Too many command-line arguments.')
340 |
341 | # Enable training summary.
342 | if FLAGS.train_summary_steps > 0:
343 | tf.config.set_soft_device_placement(True)
344 |
345 |
346 | builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
347 | builder.download_and_prepare()
348 | num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
349 | num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
350 | num_classes = builder.info.features['label'].num_classes
351 |
352 | train_steps = model_util.get_train_steps(num_train_examples)
353 | eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
354 | epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))
355 |
356 | resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
357 | model = resnet.resnet_v1(
358 | resnet_depth=FLAGS.resnet_depth,
359 | width_multiplier=FLAGS.width_multiplier,
360 | cifar_stem=FLAGS.image_size <= 32)
361 |
362 | checkpoint_steps = (
363 | FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps))
364 |
365 | cluster = None
366 | if FLAGS.use_tpu and FLAGS.master is None:
367 | if FLAGS.tpu_name:
368 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
369 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
370 | else:
371 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
372 | tf.config.experimental_connect_to_cluster(cluster)
373 | tf.tpu.experimental.initialize_tpu_system(cluster)
374 |
375 |   sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1  # alternatively: SLICED
376 | run_config = tf.estimator.tpu.RunConfig(
377 | tpu_config=tf.estimator.tpu.TPUConfig(
378 | iterations_per_loop=checkpoint_steps,
379 | eval_training_input_configuration=sliced_eval_mode),
380 | model_dir=FLAGS.model_dir,
381 | save_summary_steps=checkpoint_steps,
382 | save_checkpoints_steps=checkpoint_steps,
383 | keep_checkpoint_max=FLAGS.keep_checkpoint_max,
384 | master=FLAGS.master,
385 | cluster=cluster)
386 | estimator = tf.estimator.tpu.TPUEstimator(
387 | model_lib.build_model_fn(model, num_classes, num_train_examples),
388 | config=run_config,
389 | train_batch_size=FLAGS.train_batch_size,
390 | eval_batch_size=FLAGS.eval_batch_size,
391 | use_tpu=FLAGS.use_tpu)
392 |
393 | if FLAGS.mode == 'eval':
394 | for ckpt in tf.train.checkpoints_iterator(
395 | run_config.model_dir, min_interval_secs=15):
396 | try:
397 | result = perform_evaluation(
398 | estimator=estimator,
399 | input_fn=data_lib.build_input_fn(builder, False),
400 | eval_steps=eval_steps,
401 | model=model,
402 | num_classes=num_classes,
403 | checkpoint_path=ckpt)
404 | except tf.errors.NotFoundError:
405 | continue
406 | if result['global_step'] >= train_steps:
407 | return
408 | else:
409 | estimator.train(
410 | data_lib.build_input_fn(builder, True), max_steps=train_steps)
411 | if FLAGS.mode == 'train_then_eval':
412 | perform_evaluation(
413 | estimator=estimator,
414 | input_fn=data_lib.build_input_fn(builder, False),
415 | eval_steps=eval_steps,
416 | model=model,
417 | num_classes=num_classes)
418 |
419 |
420 | if __name__ == '__main__':
421 | tf.disable_eager_execution() # Disable eager mode when running with TF2.
422 | app.run(main)
423 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yao-Hung Hubert Tsai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MSCOCO/README.md:
--------------------------------------------------------------------------------
1 | # MS COCO Experiments
2 |
3 | Download the MS COCO 2017 dataset from [here](http://cocodataset.org/).
4 |
5 | Run ```main.ipynb``` to reproduce the experiments.
6 |
7 | Run ```ResNet_baseline.ipynb``` to reproduce the baseline results.
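
Both notebooks read COCO from the `data_dir` set in their `CONF` dict (e.g. `/DataSet/COCO`) and assume the standard 2017 layout, roughly as sketched below (inferred from the paths used in the notebooks):

```
/DataSet/COCO
├── train2017/                      # training images
├── val2017/                        # validation images
└── annotations/
    ├── instances_train2017.json
    ├── instances_val2017.json
    └── captions_train2017.json     # used by main.ipynb for image-caption pairs
```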
--------------------------------------------------------------------------------
/MSCOCO/ResNet_baseline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os, time"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "TRIAL_NAME='ResNet_baseline'"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "CONF={\n",
28 | " 'niter':200,\n",
29 | " 'GPU':0,\n",
30 | " 'BS':128,\n",
31 | " 'test_BS':256,\n",
32 | " 'N_neg':3,\n",
33 | " 'name':TRIAL_NAME,\n",
34 | " 'tb_dir':os.path.join('./runs', TRIAL_NAME),\n",
35 | " 'nz':64,\n",
36 | " 'seed':10708,\n",
37 | " 'data_dir':'/DataSet/COCO',\n",
38 | " 'dataType':'train2017',\n",
39 | " 'valType':'val2017',\n",
40 | " 'LAMBDA':0.5,\n",
41 | " 'use_super':False,\n",
42 | " 'test_classes':80\n",
43 | "}"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=str(CONF['GPU'])"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "# Clear any logs from previous runs\n",
62 | "time.sleep(2)\n",
63 | "import shutil\n",
64 | "shutil.rmtree(CONF['tb_dir'], ignore_errors=True)\n",
65 | "time.sleep(5)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "import texar.torch as tx\n",
75 | "import random\n",
76 | "import torch\n",
77 | "import torch.nn as nn\n",
78 | "import torch.nn.parallel\n",
79 | "from torch.nn import functional as F\n",
80 | "import torch.backends.cudnn as cudnn\n",
81 | "import torch.optim as optim\n",
82 | "import torch.utils.data\n",
83 | "import torchvision\n",
84 | "from torch.utils.tensorboard import SummaryWriter\n",
85 | "import torchvision.datasets as dset\n",
86 | "import torchvision.transforms as transforms\n",
87 | "import torchvision.utils as vutils\n",
88 | "import numpy as np\n",
89 | "from torch import autograd\n",
90 | "import multiprocessing\n",
91 | "from PIL import Image\n",
92 | "from sklearn import metrics"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "device = torch.device(\"cuda:0\" if True else \"cpu\")"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "writer = SummaryWriter(log_dir=CONF['tb_dir'])"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "random.seed(CONF['seed'])\n",
120 | "torch.manual_seed(CONF['seed'])\n",
121 | "np.random.seed(CONF['seed'])\n",
122 | "cudnn.benchmark = True"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "batch_size = CONF['BS']"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "T = transforms.Compose([\n",
141 | " transforms.RandomResizedCrop((256,256), scale=(0.3, 1.0), ratio=(0.75, 1.3333333333333333)),\n",
142 | " transforms.ColorJitter(brightness=.1, contrast=.05, saturation=.05, hue=.05),\n",
143 | " transforms.RandomHorizontalFlip(),\n",
144 | " transforms.ToTensor()\n",
145 | "])\n",
146 | "T_test = transforms.Compose([\n",
147 | " transforms.Resize((256,256)),\n",
148 | " transforms.ToTensor()\n",
149 | "])"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "from torchvision.datasets.vision import VisionDataset\n",
159 | "class CocoClassification(VisionDataset):\n",
160 | " \"\"\"`MS Coco Detection `_ Dataset.\n",
161 | "\n",
162 | " Args:\n",
163 | " root (string): Root directory where images are downloaded to.\n",
164 | " annFile (string): Path to json annotation file.\n",
165 | " transform (callable, optional): A function/transform that takes in an PIL image\n",
166 | " and returns a transformed version. E.g, ``transforms.ToTensor``\n",
167 | " target_transform (callable, optional): A function/transform that takes in the\n",
168 | " target and transforms it.\n",
169 | " transforms (callable, optional): A function/transform that takes input sample and its target as entry\n",
170 | " and returns a transformed version.\n",
171 | " \"\"\"\n",
172 | "\n",
173 | " def sample_class(self, k):\n",
174 | " if CONF['use_super']:\n",
175 | " self.classes = np.array([\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"])#np.arange(12)+1\n",
176 | " self.class_description = [\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"]\n",
177 | " return\n",
178 | " class_list = self.coco.getCatIds()\n",
179 | " self.classes = np.sort(np.random.choice(class_list, size=k, replace=False))\n",
180 | " self.class_description = self.coco.loadCats(self.classes)\n",
181 | " arr = []\n",
182 | " for catId in self.classes:\n",
183 | " arr+=self.coco.getImgIds(catIds=[catId])\n",
184 | " self.ids = sorted(list(set(arr)))\n",
185 | " \n",
186 | " def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):\n",
187 | " super(CocoClassification, self).__init__(root, transforms, transform, target_transform)\n",
188 | " from pycocotools.coco import COCO\n",
189 | " self.coco = COCO(annFile)\n",
190 | " self.sample_class(len(self.coco.getCatIds()))\n",
191 | "\n",
192 | " def __getitem__(self, index):\n",
193 | " \"\"\"\n",
194 | " Args:\n",
195 | " index (int): Index\n",
196 | "\n",
197 | " Returns:\n",
198 | " tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.\n",
199 | " \"\"\"\n",
200 | " coco = self.coco\n",
201 | " img_id = self.ids[index]\n",
202 | " ann_ids = coco.getAnnIds(imgIds=img_id)\n",
203 | " cat_ids = [ann['category_id'] for ann in coco.loadAnns(ann_ids)]\n",
204 | " target = coco.loadCats(cat_ids)\n",
205 | " if CONF['use_super']:\n",
206 | " target = np.array([x['supercategory'] for x in target])\n",
207 | " else:\n",
208 | " target = np.array([x['id'] for x in target if x['id'] in self.classes])\n",
209 | " targets = torch.FloatTensor([1 if (c in target) else 0 for c in self.classes])\n",
210 | " path = coco.loadImgs(img_id)[0]['file_name']\n",
211 | " img = Image.open(os.path.join(self.root, path)).convert('RGB')\n",
212 | " if self.transforms is not None:\n",
213 | " img, targets = self.transforms(img, targets)\n",
214 | "\n",
215 | " return img, targets\n",
216 | "\n",
217 | "\n",
218 | " def __len__(self):\n",
219 | " return len(self.ids)\n"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "\n",
229 | "clas_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['dataType']),\n",
230 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['dataType']),\n",
231 | " transform=T)\n",
232 | "val_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['valType']),\n",
233 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['valType']),\n",
234 | " transform=T_test)"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "metadata": {},
241 | "outputs": [],
242 | "source": [
243 | "clas_set.sample_class(CONF['test_classes'])\n",
244 | "train_loader = torch.utils.data.DataLoader(clas_set, batch_size=CONF['BS'],shuffle=True, num_workers=8, pin_memory=True)"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=CONF['test_BS'],shuffle=False, num_workers=8, pin_memory=False)"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "class Flatten(nn.Module):\n",
263 | " def forward(self, x):\n",
264 | " x = x.view(x.size()[0], -1)\n",
265 | " return x"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "resnet18 = torchvision.models.resnet18(pretrained=True)\n",
275 | "modules=list(resnet18.children())[:-1]\n",
276 | "modules.append(Flatten())\n",
277 | "modules.append(nn.Linear(512, CONF['test_classes']))\n",
278 | "resnet = nn.Sequential(*modules)\n",
279 | "resnet.cuda()"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": null,
285 | "metadata": {},
286 | "outputs": [],
287 | "source": [
288 | "opt = optim.SGD(resnet.parameters(), lr=0.02, momentum=0.9)\n",
289 | "scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.2, patience=10)"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "metadata": {},
296 | "outputs": [],
297 | "source": [
298 | "c=0"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "criterion = nn.BCEWithLogitsLoss()"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": null,
313 | "metadata": {},
314 | "outputs": [],
315 | "source": [
316 | "from tqdm.notebook import tqdm, trange"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {},
323 | "outputs": [],
324 | "source": [
325 | "def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):\n",
326 | " '''\n",
327 | " Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case\n",
328 | " https://stackoverflow.com/q/32239577/395857\n",
329 | " '''\n",
330 | " acc_list = []\n",
331 | " for i in range(y_true.shape[0]):\n",
332 | " set_true = set( np.where(y_true[i])[0] )\n",
333 | " set_pred = set( np.where(y_pred[i])[0] )\n",
334 | " tmp_a = None\n",
335 | " if len(set_true) == 0 and len(set_pred) == 0:\n",
336 | " tmp_a = 1\n",
337 | " else:\n",
338 | " tmp_a = len(set_true.intersection(set_pred))/\\\n",
339 | " float( len(set_true.union(set_pred)) )\n",
340 | " acc_list.append(tmp_a)\n",
341 | " return np.mean(acc_list)"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "metadata": {},
348 | "outputs": [],
349 | "source": [
350 | "for it in trange(CONF['niter']):\n",
351 | "\n",
352 | " l=[]\n",
353 | " LBL=[]\n",
354 | " activation = []\n",
355 | " Y = []\n",
356 | " resnet.eval()\n",
357 | " with torch.no_grad():\n",
358 | " count = 0\n",
359 | " corrects = torch.zeros(CONF['test_classes'])\n",
360 | " for img,lbl in tqdm(val_loader, leave=False, desc=\"testing\"):\n",
361 | " LBL.append(lbl.numpy())\n",
362 | " lbl=lbl.cuda()\n",
363 | " count += img.shape[0]\n",
364 | " pred = resnet(img.cuda())\n",
365 | " l.append(criterion(pred, lbl).item())\n",
366 | " activation.append(pred.data.cpu().numpy())\n",
367 | " pred = torch.sigmoid(pred)>.5\n",
368 | " Y.append(pred.data.cpu().numpy())\n",
369 | " corrects += torch.sum(torch.eq(pred, lbl), dim=0).cpu()\n",
370 | " acc = (corrects/float(count))\n",
371 | " writer.add_scalar(\"sup_acc/val_avg\", torch.mean(acc).item(), global_step=it)\n",
372 | " writer.add_histogram('baseline/acc_val', acc.data.cpu().numpy(), global_step=it)\n",
373 | " writer.add_scalar(\"sup_loss/val_loss\", np.average(l), global_step=it)\n",
374 | " writer.flush()\n",
375 | "\n",
376 | " Y, LBL, activation = np.concatenate(Y,axis=0), np.concatenate(LBL,axis=0), np.concatenate(activation,axis=0)\n",
377 | " writer.add_scalar(\"sup_metrics/subset_acc\", metrics.accuracy_score(LBL, Y, normalize=True), global_step=it)\n",
378 | " writer.add_scalar(\"sup_metrics/hamming_loss\", metrics.hamming_loss(LBL, Y), global_step=it)\n",
379 | " writer.add_scalar(\"sup_metrics/hamming_score\", hamming_score(LBL, Y), global_step=it)\n",
380 | " writer.add_scalar(\"sup_metrics/micro_f1\", metrics.f1_score(LBL, Y, average='micro'), global_step=it)\n",
381 | " writer.add_scalar(\"sup_metrics/macro_f1\", metrics.f1_score(LBL, Y, average='macro'), global_step=it)\n",
382 | " writer.add_scalar(\"sup_metrics/micro_roc_auc\", metrics.roc_auc_score(LBL, Y, average='micro'), global_step=it)\n",
383 | " writer.add_scalar(\"sup_metrics/macro_roc_auc\", metrics.roc_auc_score(LBL, Y, average='macro'), global_step=it)\n",
384 | " writer.add_scalar(\"sup_metrics/micro_precision\", metrics.precision_score(LBL,Y,average='micro'), global_step=it)\n",
385 | " writer.add_scalar(\"sup_metrics/macro_precision\", metrics.precision_score(LBL,Y,average='macro'), global_step=it)\n",
386 | " writer.add_scalar(\"sup_metrics/micro_recall\", metrics.recall_score(LBL,Y,average='micro'), global_step=it)\n",
387 | " writer.add_scalar(\"sup_metrics/macro_recall\", metrics.recall_score(LBL,Y,average='macro'), global_step=it)\n",
388 | " writer.add_scalar(\"sup_metrics/avg_acc\", np.average(np.sum((Y==LBL), axis=0)/Y.shape[0]), global_step=it)\n",
389 | " scheduler.step(np.average(l))\n",
390 | " \n",
391 | " count = 0\n",
392 | " corrects = torch.zeros(CONF['test_classes'])\n",
393 | " resnet.train()\n",
394 | " for img,lbl in tqdm(train_loader, leave=False, desc=\"training\"):\n",
395 | " count += img.shape[0]\n",
396 | " img = img.cuda()\n",
397 | " pred = resnet(img)\n",
398 | " loss = criterion(pred,lbl.cuda())\n",
399 | " with torch.no_grad():\n",
400 | " corrects += torch.sum(torch.eq((torch.sigmoid(pred)>.5).cpu(), lbl), dim=0)\n",
401 | " opt.zero_grad()\n",
402 | " loss.backward()\n",
403 | " opt.step()\n",
404 | " writer.add_scalar(\"sup_loss/loss\", loss.item(), global_step=c)\n",
405 | " c+=1\n",
406 | " acc = (corrects/float(count))\n",
407 | " writer.add_scalar(\"sup_acc/train_avg\", torch.mean(acc).item(), global_step=it)\n",
408 | " writer.add_histogram('baseline/acc_train', acc.data.cpu().numpy(), global_step=it)\n",
409 | "\n",
410 | "writer.close()"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": null,
416 | "metadata": {},
417 | "outputs": [],
418 | "source": []
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": []
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": null,
430 | "metadata": {},
431 | "outputs": [],
432 | "source": []
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": null,
437 | "metadata": {},
438 | "outputs": [],
439 | "source": []
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": null,
444 | "metadata": {},
445 | "outputs": [],
446 | "source": []
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": null,
451 | "metadata": {},
452 | "outputs": [],
453 | "source": []
454 | }
455 | ],
456 | "metadata": {
457 | "kernelspec": {
458 | "display_name": "Python 3",
459 | "language": "python",
460 | "name": "python3"
461 | },
462 | "language_info": {
463 | "codemirror_mode": {
464 | "name": "ipython",
465 | "version": 3
466 | },
467 | "file_extension": ".py",
468 | "mimetype": "text/x-python",
469 | "name": "python",
470 | "nbconvert_exporter": "python",
471 | "pygments_lexer": "ipython3",
472 | "version": "3.7.6"
473 | }
474 | },
475 | "nbformat": 4,
476 | "nbformat_minor": 4
477 | }
478 |
--------------------------------------------------------------------------------
/MSCOCO/main.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os, time"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "TRIAL_NAME='trial'"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "CONF={\n",
28 | " 'niter':5,\n",
29 | " 'ntest':50,\n",
30 | " 'GPU':0,\n",
31 | " 'BS':4,\n",
32 | " 'test_BS':256,\n",
33 | " 'N_neg':3,\n",
34 | " 'name':TRIAL_NAME,\n",
35 | " 'nz':2048,\n",
36 | " 'ng':128,\n",
37 | " 'seed':10715,\n",
38 | " 'data_dir':'/DataSet/COCO', #Set it to where your COCO dataset is\n",
39 | " 'dataType':'train2017',\n",
40 | " 'valType':'val2017',\n",
41 | " 'testType':'test2017',\n",
42 | " 'max_len':512,\n",
43 | " 'hidden_size':768,\n",
44 | " 'bert_hdim':3072,\n",
45 | " 'LAMBDAs':0.005,\n",
46 | " 'use_super':False,\n",
47 | " 'test_classes':80,\n",
48 | " 'n_trials':5,\n",
49 | " 'bert_pretrained':True, #pretrain bert\n",
50 | " 'resnet_pretrained':True #pretrain resnet\n",
51 | "}"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=str(CONF['GPU'])"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": []
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "import texar.torch as tx\n",
77 | "import random\n",
78 | "import torch\n",
79 | "import torch.nn as nn\n",
80 | "import torch.nn.parallel\n",
81 | "from torch.nn import functional as F\n",
82 | "import torch.backends.cudnn as cudnn\n",
83 | "import torch.optim as optim\n",
84 | "import torch.utils.data\n",
85 | "import torchvision\n",
86 | "from torch.utils.tensorboard import SummaryWriter\n",
87 | "import torchvision.datasets as dset\n",
88 | "import torchvision.transforms as transforms\n",
89 | "import torchvision.utils as vutils\n",
90 | "import numpy as np\n",
91 | "from torch import autograd\n",
92 | "import multiprocessing\n",
93 | "from PIL import Image\n",
94 | "from sklearn import metrics"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "device = torch.device(\"cuda:0\" if True else \"cpu\")"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "time.sleep(2)\n",
113 | "import shutil\n",
114 | "tb_dir=os.path.join('./runs', CONF['name'] + \"_GLOBAL\")\n",
115 | "shutil.rmtree(tb_dir, ignore_errors=True)\n",
116 | "time.sleep(5)\n",
117 | "global_writer = SummaryWriter(log_dir=tb_dir)"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "random.seed(CONF['seed'])\n",
127 | "torch.manual_seed(CONF['seed'])\n",
128 | "np.random.seed(CONF['seed'])\n",
129 | "cudnn.benchmark = True"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "batch_size = CONF['BS']"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "T = transforms.Compose([\n",
148 | " transforms.RandomResizedCrop((224,224), scale=(0.3, 1.0), ratio=(0.75, 1.3333333333333333)),\n",
149 | " transforms.ColorJitter(brightness=.1, contrast=.05, saturation=.05, hue=.05),\n",
150 | " transforms.RandomHorizontalFlip(),\n",
151 | " transforms.ToTensor(),\n",
152 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
153 | "])\n",
154 | "T_test = transforms.Compose([\n",
155 | " transforms.Resize((224,224)),\n",
156 | " transforms.ToTensor(),\n",
157 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
158 | "])"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "from torchvision.datasets.vision import VisionDataset\n",
168 | "class CocoClassification(VisionDataset):\n",
169 | " \"\"\"`MS Coco Detection `_ Dataset.\n",
170 | "\n",
171 | " Args:\n",
172 | " root (string): Root directory where images are downloaded to.\n",
173 | " annFile (string): Path to json annotation file.\n",
174 | " transform (callable, optional): A function/transform that takes in an PIL image\n",
175 | " and returns a transformed version. E.g, ``transforms.ToTensor``\n",
176 | " target_transform (callable, optional): A function/transform that takes in the\n",
177 | " target and transforms it.\n",
178 | " transforms (callable, optional): A function/transform that takes input sample and its target as entry\n",
179 | " and returns a transformed version.\n",
180 | " \"\"\"\n",
181 | "\n",
182 | " def sample_class(self, k):\n",
183 | " if CONF['use_super']:\n",
184 | " self.classes = np.array([\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"])\n",
185 | " self.class_description = [\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"]\n",
186 | " return\n",
187 | " class_list = self.coco.getCatIds()\n",
188 | " self.classes = np.sort(np.random.choice(class_list, size=k, replace=False))\n",
189 | " self.class_description = self.coco.loadCats(self.classes)\n",
190 | " arr = []\n",
191 | " for catId in self.classes:\n",
192 | " arr+=self.coco.getImgIds(catIds=[catId])\n",
193 | " self.ids = sorted(list(set(arr)))\n",
194 | " \n",
195 | " def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):\n",
196 | " super(CocoClassification, self).__init__(root, transforms, transform, target_transform)\n",
197 | " from pycocotools.coco import COCO\n",
198 | " self.coco = COCO(annFile)\n",
199 | " self.sample_class(len(self.coco.getCatIds()))\n",
200 | "\n",
201 | " def __getitem__(self, index):\n",
202 | " \"\"\"\n",
203 | " Args:\n",
204 | " index (int): Index\n",
205 | "\n",
206 | " Returns:\n",
207 | " tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.\n",
208 | " \"\"\"\n",
209 | " coco = self.coco\n",
210 | " img_id = self.ids[index]\n",
211 | " ann_ids = coco.getAnnIds(imgIds=img_id)\n",
212 | " cat_ids = [ann['category_id'] for ann in coco.loadAnns(ann_ids)]\n",
213 | " target = coco.loadCats(cat_ids)\n",
214 | " if CONF['use_super']:\n",
215 | " target = np.array([x['supercategory'] for x in target])\n",
216 | " else:\n",
217 | " target = np.array([x['id'] for x in target if x['id'] in self.classes])\n",
218 | " targets = torch.FloatTensor([1 if (c in target) else 0 for c in self.classes])\n",
219 | " path = coco.loadImgs(img_id)[0]['file_name']\n",
220 | " img = Image.open(os.path.join(self.root, path)).convert('RGB')\n",
221 | " if self.transforms is not None:\n",
222 | " img, targets = self.transforms(img, targets)\n",
223 | "\n",
224 | " return img, targets\n",
225 | "\n",
226 | "\n",
227 | " def __len__(self):\n",
228 | " return len(self.ids)\n"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "metadata": {},
235 | "outputs": [],
236 | "source": [
237 | "dataset = dset.CocoCaptions(root = '{}/{}'.format(CONF['data_dir'],CONF['dataType']),\n",
238 | " annFile = '{}/annotations/captions_{}.json'.format(CONF['data_dir'],CONF['dataType']),\n",
239 | " transform=T)\n",
240 | "\n",
241 | "clas_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['dataType']),\n",
242 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['dataType']),\n",
243 | " transform=T_test)\n",
244 | "val_set = CocoClassification(root = '{}/{}'.format(CONF['data_dir'],CONF['valType']),\n",
245 | " annFile = '{}/annotations/instances_{}.json'.format(CONF['data_dir'],CONF['valType']),\n",
246 | " transform=T_test)"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size*(CONF['N_neg']+1),\n",
256 | " shuffle=True, num_workers=2, pin_memory=True, drop_last=True)"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=CONF['test_BS'],shuffle=False, num_workers=8, pin_memory=False)"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "hparams = {\n",
275 | " \"pretrained_model_name\": \"bert-base-uncased\",\n",
276 | " \"vocab_file\": None,\n",
277 | " \"max_len\": CONF['max_len'],\n",
278 | " \"unk_token\": \"[UNK]\",\n",
279 | " \"sep_token\": \"[SEP]\",\n",
280 | " \"pad_token\": \"[PAD]\",\n",
281 | " \"cls_token\": \"[CLS]\",\n",
282 | " \"mask_token\": \"[MASK]\",\n",
283 | " \"tokenize_chinese_chars\": True,\n",
284 | " \"do_lower_case\": True,\n",
285 | " \"do_basic_tokenize\": True,\n",
286 | " \"non_split_tokens\": None,\n",
287 | " \"name\": \"bert_tokenizer\",\n",
288 | "}"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "tokenizer = tx.data.BERTTokenizer(hparams=hparams, pretrained_model_name='bert-base-uncased')"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "bert_hparams = {\n",
307 | " 'embed': {'dim': CONF['hidden_size'], 'name': 'word_embeddings'},\n",
308 | " 'vocab_size': 30522,\n",
309 | " 'segment_embed': {'dim': CONF['hidden_size'], 'name': 'token_type_embeddings'},\n",
310 | " 'type_vocab_size': 2,\n",
311 | " 'position_embed': {'dim': CONF['hidden_size'], 'name': 'position_embeddings'},\n",
312 | " 'position_size': CONF['max_len'],\n",
313 | " 'encoder': {'dim': CONF['hidden_size'],\n",
314 | " 'embedding_dropout': 0.1,\n",
315 | " 'multihead_attention': {'dropout_rate': 0.1,\n",
316 | " 'name': 'self',\n",
317 | " 'num_heads': 6,\n",
318 | " 'num_units': CONF['hidden_size'],\n",
319 | " 'output_dim': CONF['hidden_size'],\n",
320 | " 'use_bias': True},\n",
321 | " 'name': 'encoder',\n",
322 | " 'num_blocks': 4,\n",
323 | " 'poswise_feedforward': {'layers': [{'kwargs': {'in_features': CONF['hidden_size'],\n",
324 | " 'out_features': CONF['bert_hdim'],\n",
325 | " 'bias': True},\n",
326 | " 'type': 'Linear'},\n",
327 | " {'type': 'BertGELU'},\n",
328 | " {'kwargs': {'in_features': CONF['bert_hdim'], 'out_features': CONF['hidden_size'], 'bias': True},\n",
329 | " 'type': 'Linear'}]},\n",
330 | " 'residual_dropout': 0.1,\n",
331 | " 'use_bert_config': True},\n",
332 | " 'hidden_size': CONF['hidden_size'],\n",
333 | " 'initializer': None,\n",
334 | " 'name': 'bert_encoder',\n",
335 | " 'pretrained_model_name':None}\n",
336 | "\n",
337 | "if CONF['bert_pretrained']:\n",
338 | " bert_hparams[\"pretrained_model_name\"]=\"bert-base-uncased\""
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": null,
344 | "metadata": {},
345 | "outputs": [],
346 | "source": [
347 | "class Flatten(nn.Module):\n",
348 | " def forward(self, x):\n",
349 | " x = x.view(x.size()[0], -1)\n",
350 | " return x"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": null,
356 | "metadata": {},
357 | "outputs": [],
358 | "source": [
359 | "class PC_Embedder(nn.Module):\n",
360 | "\n",
361 | "\n",
362 | " def __init__(self, k, nz, hparams):\n",
363 | " super(PC_Embedder, self).__init__()\n",
364 | " self.bert = tx.modules.BERTEncoder(hparams=hparams)\n",
365 | " self.q = nn.Linear(CONF['hidden_size'], nz)\n",
366 | " resnet50 = torchvision.models.resnet50(pretrained=CONF['resnet_pretrained'])\n",
367 | " modules=list(resnet50.children())[:-1]\n",
368 | " modules.append(Flatten())\n",
369 | " self.p = nn.Sequential(*modules)\n",
370 | " self.k = k+1\n",
371 | " self.G = nn.Sequential(\n",
372 | " nn.Linear(2048, 512, bias=True),\n",
373 | " nn.ReLU(),\n",
374 | " nn.Linear(512, nz, bias=True),\n",
375 | " )\n",
376 | "\n",
377 | " def forward(self, v):\n",
378 | " return self.p(v)\n",
379 | "\n",
380 | " def get_loss(self, v, v1, LAMBDA):\n",
381 | " inputs, segment_ids = v1\n",
382 | " _, v1 = self.bert(inputs=inputs, segment_ids=segment_ids)\n",
383 | " batch_size = v.shape[0]//self.k\n",
384 | " z = self.G(self.p(v))\n",
385 | " z = z.squeeze().view(batch_size, self.k, z.shape[1])[:,0,:]\n",
386 | " z_l = z.unsqueeze(1).expand(z.shape[0],self.k,z.shape[1]).contiguous()\n",
387 | " z_p = self.q(v1).view(z_l.shape).contiguous()\n",
388 | " l1 = F.log_softmax(torch.sum(z_l*z_p, dim = -1), dim=1)[:,0]\n",
389 | " l2 = torch.sum((z - z_p[:,0,:])**2, dim=-1)\n",
390 | " return torch.mean(- l1 + LAMBDA*l2, dim=0),l1,l2"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": null,
396 | "metadata": {},
397 | "outputs": [],
398 | "source": [
399 | "class Tracker(object):\n",
400 | "\n",
401 | " def __init__(self, VARS, ranks):\n",
402 | " self.var_dict = dict(zip(VARS, [(1000000.0 if ranks[i] else -1000000.0, int(ranks[i])) for i in range(len(VARS))]))\n",
403 | " \n",
404 | " def update(self, d):\n",
405 | " for k,v in d.items():\n",
406 | " o,r = self.var_dict[k]\n",
407 | " self.var_dict[k] = (np.minimum(v,o) if r else np.maximum(v,o), r)\n",
408 | " \n",
409 | " def return_dict(self):\n",
410 | " D = dict()\n",
411 | " for k,(v,_) in self.var_dict.items():\n",
412 | " D[k]=v\n",
413 | " return D"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": null,
419 | "metadata": {},
420 | "outputs": [],
421 | "source": [
422 | "test_class = nn.Linear(CONF['nz'],CONF['test_classes']).cuda()"
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": null,
428 | "metadata": {},
429 | "outputs": [],
430 | "source": [
431 | "criterion = nn.BCEWithLogitsLoss()"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": null,
437 | "metadata": {},
438 | "outputs": [],
439 | "source": [
440 | "from tqdm import notebook\n",
441 | "tnrange=notebook.tnrange\n",
442 | "tqdm_notebook = notebook.tqdm"
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "execution_count": null,
448 | "metadata": {},
449 | "outputs": [],
450 | "source": [
451 | "def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):\n",
452 | " '''\n",
453 | " Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case\n",
454 | " https://stackoverflow.com/q/32239577/395857\n",
455 | " '''\n",
456 | " acc_list = []\n",
457 | " for i in range(y_true.shape[0]):\n",
458 | " set_true = set( np.where(y_true[i])[0] )\n",
459 | " set_pred = set( np.where(y_pred[i])[0] )\n",
460 | " tmp_a = None\n",
461 | " if len(set_true) == 0 and len(set_pred) == 0:\n",
462 | " tmp_a = 1\n",
463 | " else:\n",
464 | " tmp_a = len(set_true.intersection(set_pred))/\\\n",
465 | " float( len(set_true.union(set_pred)) )\n",
466 | " acc_list.append(tmp_a)\n",
467 | " return np.mean(acc_list)"
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": null,
473 | "metadata": {},
474 | "outputs": [],
475 | "source": [
476 | "METRICS = ['train_avg', 'val_avg', 'subset_acc', 'hamming_loss', 'hamming_score',\n",
477 | " 'micro_f1', 'macro_f1', 'micro_roc_auc', 'macro_roc_auc',\n",
478 | " 'micro_precision', 'macro_precision', 'micro_recall', 'macro_recall']\n",
479 | "RANKS = [0, 0, 0, 1, 0,\n",
480 | " 0, 0, 0, 0,\n",
481 | " 0, 0, 1, 1]\n",
482 | "\n",
483 | "assert(len(METRICS)==len(RANKS))"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": null,
489 | "metadata": {},
490 | "outputs": [],
491 | "source": [
492 | "def trial(enc, opt, LAMBDA, trial_number):\n",
493 | " c=0\n",
494 | " # Clear any logs from previous runs\n",
495 | " import shutil\n",
496 | " tb_dir=os.path.join('./runs', CONF['name'] + \"_LAMBDA_{}_{}\".format(LAMBDA, trial_number))\n",
497 | " shutil.rmtree(tb_dir, ignore_errors=True)\n",
498 | " time.sleep(5)\n",
499 | " tracker = Tracker(METRICS, RANKS)\n",
500 | " writer = SummaryWriter(log_dir=tb_dir)\n",
501 | " for it in tnrange(CONF['niter'], desc=\"training with LAMBDA:{}\".format(LAMBDA)):\n",
502 | "\n",
503 | " test_class.reset_parameters()\n",
504 | " clas_set.sample_class(CONF['test_classes'])\n",
505 | " clas_loader = torch.utils.data.DataLoader(clas_set, batch_size=CONF['test_BS'],shuffle=True, num_workers=8, pin_memory=False)\n",
506 | " with torch.no_grad():\n",
507 | " enc.eval()\n",
508 | " test = []\n",
509 | " for img,lbl in tqdm_notebook(clas_loader,leave=False,desc=\"generating train set\"):\n",
510 | " test.append((enc(img.cuda()).data.cpu(), lbl))\n",
511 | " enc.train()\n",
512 | "\n",
513 | " opt_test = optim.Adam(test_class.parameters(), lr=1e-2)\n",
514 | " for test_it in tnrange(CONF['ntest'],leave=False,desc=\"training linear classifier\"):\n",
515 | " random.shuffle(test)\n",
516 | " for z,lbl in test:\n",
517 | " pred = test_class(z.cuda())\n",
518 | " loss = criterion(pred,lbl.cuda())\n",
519 | " opt_test.zero_grad()\n",
520 | " loss.backward()\n",
521 | " opt_test.step()\n",
522 | "\n",
523 | " with torch.no_grad():\n",
524 | " M = dict()\n",
525 | " count = 0\n",
526 | " corrects = torch.zeros(CONF['test_classes'])\n",
527 | " for z,lbl in test:\n",
528 | " count += z.shape[0]\n",
529 | " pred = torch.sigmoid(test_class(z.cuda()))>.5\n",
530 | " corrects += torch.sum(torch.eq(pred.cpu(), lbl), dim=0)\n",
531 | " acc = (corrects/float(count))\n",
532 | " writer.add_scalar(\"acc/train_avg\", torch.mean(acc).item(), global_step=it)\n",
533 | " writer.add_histogram('acc_train', acc.data.cpu().numpy(), global_step=it)\n",
534 | " M['train_avg'] = torch.mean(acc).item()\n",
535 | "\n",
536 | " count = 0\n",
537 | " corrects = torch.zeros(CONF['test_classes'])\n",
538 | " LBL, Y = [], []\n",
539 | " for img,lbl in val_loader:\n",
540 | " LBL.append(lbl.numpy())\n",
541 | " count += img.shape[0]\n",
542 | " z = enc(img.cuda())\n",
543 | " pred = torch.sigmoid(test_class(z))>.5\n",
544 | " Y.append(pred.data.cpu().numpy())\n",
545 | " corrects += torch.sum(torch.eq(pred.cpu(), lbl), dim=0)\n",
546 | " acc = (corrects/float(count))\n",
547 | " writer.add_scalar(\"acc/val_avg\", torch.mean(acc).item(), global_step=it)\n",
548 | " writer.add_histogram('acc_val', acc.data.cpu().numpy(), global_step=it)\n",
549 | "\n",
550 | " M['val_avg'] = torch.mean(acc).item()\n",
551 | " Y, LBL = np.concatenate(Y,axis=0), np.concatenate(LBL,axis=0)\n",
552 | " M['subset_acc'] = metrics.accuracy_score(LBL, Y, normalize=True)\n",
553 | " M['hamming_loss'] = metrics.hamming_loss(LBL, Y)\n",
554 | " M['hamming_score'] = hamming_score(LBL, Y)\n",
555 | " M['micro_f1'] = metrics.f1_score(LBL, Y, average='micro')\n",
556 | " M['macro_f1'] = metrics.f1_score(LBL, Y, average='macro')\n",
557 | " M['micro_roc_auc'] = metrics.roc_auc_score(LBL, Y, average='micro')\n",
558 | " M['macro_roc_auc'] = metrics.roc_auc_score(LBL, Y, average='macro')\n",
559 | " M['micro_precision'] = metrics.precision_score(LBL,Y,average='micro')\n",
560 | " M['macro_precision'] = metrics.precision_score(LBL,Y,average='macro')\n",
561 | " M['micro_recall'] = metrics.recall_score(LBL,Y,average='micro')\n",
562 | " M['macro_recall'] = metrics.recall_score(LBL,Y,average='macro')\n",
563 | " \n",
564 | " for k,v in M.items():\n",
565 | " writer.add_scalar(\"metrics/{}\".format(k), v, global_step=it)\n",
566 | " tracker.update(M)\n",
567 | " writer.flush()\n",
568 | "\n",
569 | " for img,ann in tqdm_notebook(train_loader, leave=False):\n",
570 | " ann = np.take_along_axis(np.array(ann), np.random.randint(0, len(ann)-1, size=(1,img.shape[0])),0).squeeze()\n",
571 | " inputs, segment_ids = [],[]\n",
572 | " for s in ann:\n",
573 | " x,y,_ = tokenizer.encode_text(s)\n",
574 | " inputs.append(x)\n",
575 | " segment_ids.append(y)\n",
576 | "\n",
577 | " inputs = torch.LongTensor(inputs).cuda()\n",
578 | " segment_ids = torch.LongTensor(segment_ids).cuda()\n",
579 | " img = img.cuda()\n",
580 | " loss, l1, l2 = enc.get_loss(img,(inputs,segment_ids), LAMBDA)\n",
581 | " opt.zero_grad()\n",
582 | " loss.backward()\n",
583 | " opt.step()\n",
584 | " writer.add_scalar(\"loss/loss\", loss.item(), global_step=c)\n",
585 | " writer.add_scalars(\"loss/parts\", {'l1':- l1.data.mean().item(), 'l2':l2.data.mean().item()}, global_step=c)\n",
586 | " c+=1\n",
587 | " torch.save(enc.state_dict(), \"./models/{}.pth\".format(CONF['name'] + \"_trial_{}_LAMBDA_{}_it_{}\".format(trial_number,LAMBDA,it)))\n",
588 | " writer.close()\n",
589 | " return tracker.return_dict()"
590 | ]
591 | },
592 | {
593 | "cell_type": "code",
594 | "execution_count": null,
595 | "metadata": {},
596 | "outputs": [],
597 | "source": [
598 | "enc = PC_Embedder(CONF['N_neg'], CONF['nz'], bert_hparams)\n",
599 | "enc.to(device)\n",
600 | "opt = optim.Adam(enc.parameters(), lr=1e-4)\n",
601 | "global_writer.add_hparams(hparam_dict={'lambda': CONF['LAMBDAs'], 'trial':t}, metric_dict=trial(enc, opt, CONF['LAMBDAs'], t))\n",
602 | "global_writer.close()"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": null,
608 | "metadata": {},
609 | "outputs": [],
610 | "source": []
611 | },
612 | {
613 | "cell_type": "code",
614 | "execution_count": null,
615 | "metadata": {},
616 | "outputs": [],
617 | "source": []
618 | },
619 | {
620 | "cell_type": "code",
621 | "execution_count": null,
622 | "metadata": {},
623 | "outputs": [],
624 | "source": []
625 | },
626 | {
627 | "cell_type": "code",
628 | "execution_count": null,
629 | "metadata": {},
630 | "outputs": [],
631 | "source": []
632 | }
633 | ],
634 | "metadata": {
635 | "kernelspec": {
636 | "display_name": "Python 3",
637 | "language": "python",
638 | "name": "python3"
639 | },
640 | "language_info": {
641 | "codemirror_mode": {
642 | "name": "ipython",
643 | "version": 3
644 | },
645 | "file_extension": ".py",
646 | "mimetype": "text/x-python",
647 | "name": "python",
648 | "nbconvert_exporter": "python",
649 | "pygments_lexer": "ipython3",
650 | "version": "3.7.6"
651 | }
652 | },
653 | "nbformat": 4,
654 | "nbformat_minor": 4
655 | }
656 |
--------------------------------------------------------------------------------
/Omniglot/README.md:
--------------------------------------------------------------------------------
1 | # Omniglot Experiments
2 |
3 | The code is adapted from [here](https://github.com/leftthomas/SimCLR)
4 |
5 |
6 | ## Usage
7 |
8 | ### Evaluating Self-supervised Representations
9 |
10 | Contrastive Learning Objective only
11 | ```
12 | python main.py --loss_type 1 --recon_param 0.0 --inver_param 0.0 --epochs 1000
13 | ```
14 |
15 | For other objective combinations, please refer to `main.py`; a combined-objective example is sketched at the end of this README.
16 |
17 | ### Measuring Information
18 |
19 | + Step 1: Train an Auto-Encoder (we assume the encoded features do not lose any information)
20 | ```
21 | python compute_MI_CondEntro.py --stage AE
22 | ```
23 |
24 | + Step 2: Estimate the raw information
25 | ```
26 | python compute_MI_CondEntro.py --stage Raw_Information
27 | ```
28 |
29 | + Step 3: Estimate the information for the SSL-learned representations
30 |
31 | - for Contrastive Learning Objective only
32 | ```
33 | python main.py --loss_type 1 --recon_param 0.0 --inver_param 0.0 --epochs 1000 --with_info
34 | ```
35 | - for Contrastive Learning Objective + Inverse Predictive Learning Objective
36 | ```
37 | python main.py --loss_type 2 --recon_param 0.0 --inver_param 1.0 --epochs 1000 --with_info
38 | ```
39 |
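40 | ### Other Objective Combinations
41 | 
42 | `main.py` exposes twelve `--loss_type` settings (see its argparse help for the full list). As a sketch, a run combining the contrastive (NCE), forward predictive (RevBCE), and inverse predictive objectives with the script's default weights would look like:
43 | ```
44 | python main.py --loss_type 6 --recon_param 0.001 --inver_param 0.001 --epochs 1000
45 | ```
46 | 
47 | `compute_MI_CondEntro.py` additionally accepts an `--additional_encoder` flag in the `Raw_Information` stage, which trains an extra copy of the normalized encoder alongside the information estimators; the weights shown above are only the script defaults, not tuned values.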
--------------------------------------------------------------------------------
/Omniglot/compute_MI_CondEntro.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models.resnet import resnet18, resnet34, resnet50
5 | import math
6 | import utils
7 | import argparse
8 | from torch.utils.data import DataLoader
9 | import torch.optim as optim
10 | from tqdm import tqdm
11 | import pandas as pd
12 |
13 |
14 | # ## Utility Functions
15 |
16 | class Lambda(nn.Module):
17 | def __init__(self, func):
18 | super(Lambda, self).__init__()
19 | self.func = func
20 |
21 | def forward(self, x):
22 | return self.func(x)
23 | def mlp(dim, hidden_dim, output_dim, layers=1, batch_norm=False):
24 | if batch_norm:
25 | seq = [nn.Linear(dim, hidden_dim), nn.BatchNorm1d(num_features=hidden_dim),\
26 | nn.ReLU(inplace=True)]
27 | for _ in range(layers):
28 | seq += [nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(num_features=hidden_dim),\
29 | nn.ReLU(inplace=True)]
30 | else:
31 | seq = [nn.Linear(dim, hidden_dim), nn.ReLU(inplace=True)]
32 | for _ in range(layers):
33 | seq += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
34 | seq += [nn.Linear(hidden_dim, output_dim)]
35 |
36 | return nn.Sequential(*seq)
37 |
38 |
39 | def init_models(stage, additional_encoder=False):
40 | if stage == 'AE':
41 | norm_encoder = Encoder(normalize=True).cuda()
42 | norm_decoder = Decoder().cuda()
43 | encoder = Encoder(normalize=False).cuda()
44 | decoder = Decoder().cuda()
45 |
46 | params = list(norm_encoder.parameters()) + list(norm_decoder.parameters()) +\
47 | list(encoder.parameters()) + list(decoder.parameters())
48 |
49 | norm_encoder.train()
50 | norm_decoder.train()
51 | encoder.train()
52 | decoder.train()
53 |
54 | models = {
55 | 'norm_encoder': norm_encoder,
56 | 'norm_decoder': norm_decoder,
57 | 'encoder': encoder,
58 | 'decoder': decoder,
59 | }
60 | elif stage == 'Raw_Information' or stage == 'Feature_Information':
61 | if stage == 'Raw_Information':
62 | norm_encoder = Encoder(normalize=True).cuda()
63 | norm_encoder.load_state_dict(torch.load('results/norm_encoder.pth'))
64 | norm_encoder.eval()
65 | if additional_encoder:
66 | feat_encoder = Encoder(normalize=True).cuda()
67 | feat_encoder.load_state_dict(torch.load('results/norm_encoder.pth'))
68 | feat_encoder.train()
69 | params = list(feat_encoder.parameters())
70 | else:
71 | feat_encoder = None
72 | params = []
73 | else:
74 | norm_encoder = None
75 | feat_encoder = None
76 | params = []
77 |
78 | encoder = Encoder(normalize=False).cuda()
79 | encoder.load_state_dict(torch.load('results/encoder.pth'))
80 |
81 | mi_z_z_model = MI_Z_Z_Model().cuda()
82 | mi_z_t_model = MI_Z_T_Model().cuda()
83 | cond_z_t_model = Cond_Z_T_Model().cuda()
84 | cond_z_z_model = Cond_Z_Z_Model().cuda()
85 |
86 | params = params + list(mi_z_z_model.parameters()) + list(mi_z_t_model.parameters()) +\
87 | list(cond_z_t_model.parameters()) + list(cond_z_z_model.parameters())
88 |
89 | encoder.eval()
90 | mi_z_z_model.train()
91 | mi_z_t_model.train()
92 | cond_z_t_model.train()
93 | cond_z_z_model.train()
94 |
95 | models = {
96 | 'encoder': encoder,
97 | 'norm_encoder': norm_encoder,
98 | 'feat_encoder': feat_encoder,
99 | 'mi_z_z_model': mi_z_z_model,
100 | 'mi_z_t_model': mi_z_t_model,
101 | 'cond_z_t_model': cond_z_t_model,
102 | 'cond_z_z_model': cond_z_z_model,
103 | }
104 | optimizer = optim.Adam(params, lr=1e-3)
105 |
106 | return models, optimizer
107 |
108 |
109 | # ## Auto-Encoding Structure (for one-to-one mapping)
110 |
111 | class Encoder(nn.Module):
112 | def __init__(self, normalize=False):
113 | super(Encoder, self).__init__()
114 | self.f = nn.Sequential(
115 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28
116 | #nn.BatchNorm2d(num_features=128),
117 | nn.ReLU(inplace=True),
118 | #nn.MaxPool2d(kernel_size=2, stride=2),
119 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28
120 | #nn.BatchNorm2d(num_features=128),
121 | nn.ReLU(inplace=True),
122 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 14
123 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 14
124 | #nn.BatchNorm2d(num_features=128),
125 | nn.ReLU(inplace=True),
126 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 7
127 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 7
128 | #nn.BatchNorm2d(num_features=128),
129 | nn.ReLU(inplace=True),
130 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3
131 | nn.Flatten(),
132 | nn.Linear(9*128, 1024),
133 | )
134 | self.normalize = normalize
135 |
136 | def forward(self, _input):
137 | feature = self.f(_input)
138 | if self.normalize:
139 | return F.normalize(feature, dim=-1)
140 | else:
141 | return feature
142 | class Decoder(nn.Module):
143 | def __init__(self):
144 | super(Decoder, self).__init__()
145 | self.f = nn.Sequential(
146 | nn.Linear(1024, 9*128, bias=False),
147 | #nn.BatchNorm1d(num_features=9*128),
148 | nn.ReLU(inplace=True), # (9*128 -> 3*3*128)
149 | Lambda(lambda x: x.view(-1, 128, 3, 3)),
150 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=0,
151 | output_padding=0, bias=False), # out: 7
152 | #nn.BatchNorm2d(num_features=128),
153 | nn.ReLU(inplace=True),
154 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1,
155 | output_padding=1, bias=False), # out: 14
156 | #nn.BatchNorm2d(num_features=128),
157 | nn.ReLU(inplace=True),
158 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1,
159 | output_padding=1, bias=False), # out: 28
160 | #nn.BatchNorm2d(num_features=128),
161 | nn.ReLU(inplace=True),
162 | nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1,
163 | output_padding=0, bias=True), # out: 28
164 | )
165 |
166 | def forward(self, _input):
167 | return self.f(_input)
168 |
169 |
170 | # ## Information Functions
171 |
172 | class MI_Z_T_Model(nn.Module):
173 | def __init__(self):
174 | super(MI_Z_T_Model, self).__init__()
175 | self._g = mlp(1024, 512, 512)
176 | self._h = mlp(964, 512, 512)
177 |
178 | def forward(self, z, t):
179 | t = F.one_hot(t, num_classes=964)
180 |
181 | scores = torch.matmul(self._g(z), self._h(t.float()).t())
182 | return scores
183 |
184 |
185 | class MI_Z_Z_Model(nn.Module):
186 | def __init__(self):
187 | super(MI_Z_Z_Model, self).__init__()
188 | self._g = mlp(1024, 512, 512)
189 | self._h = mlp(1024, 512, 512)
190 |
191 | def forward(self, z1, z2):
192 | scores = torch.matmul(self._g(z1), self._h(z2).t())
193 | return scores
194 |
195 |
196 | class Cond_Z_T_Model(nn.Module):
197 | def __init__(self):
198 | super(Cond_Z_T_Model, self).__init__()
199 | self._g = mlp(964, 512, 1024)
200 |
201 | def forward(self, t):
202 | t = F.one_hot(t, num_classes=964)
203 | recon_z = self._g(t.float())
204 | return recon_z
205 |
206 |
207 | class Cond_Z_Z_Model(nn.Module):
208 | def __init__(self):
209 | super(Cond_Z_Z_Model, self).__init__()
210 | self._g = mlp(1024, 512, 1024)
211 |
212 | def forward(self, z):
213 | recon_z = self._g(z)
214 | return recon_z
215 |
216 |
217 | # ## Training
218 |
219 | def AE_loss(x, s, encoder, decoder):
220 | zx = encoder(x)
221 | zs = encoder(s)
222 | hat_x = decoder(zx)
223 | hat_s = decoder(zs)
224 | return F.binary_cross_entropy_with_logits(hat_x, x) +\
225 | F.binary_cross_entropy_with_logits(hat_s, s)
226 | def Enc_z(x, s, encoder):
227 | return encoder(x), encoder(s)
228 | def AE_Step(pos_1, pos_2, models, optimizer):
229 | norm_encoder, norm_decoder = models['norm_encoder'], models['norm_decoder']
230 | encoder, decoder = models['encoder'], models['decoder']
231 |
232 | optimizer.zero_grad()
233 | loss = 0.5*AE_loss(pos_1, pos_2, norm_encoder, norm_decoder) +\
234 | 0.5*AE_loss(pos_1, pos_2, encoder, decoder)
235 | loss.backward()
236 | optimizer.step()
237 |
238 | return loss.item()
239 |
240 |
241 | # Maximization
242 | def MI_Estimator(ft1, ft2, model):
243 | '''
244 |     ft1, ft2: random variables (batched samples)
245 |     model: takes ft1 and ft2 and outputs a batch_size x batch_size score matrix
246 | '''
247 | scores = model(ft1, ft2)
248 |
249 | # optimal critic f(ft1, ft2) = log {p(ft1, ft2)/p(ft1)p(ft2)}
250 | def js_fgan_lower_bound_obj(scores):
251 | """Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016)."""
252 | scores_diag = scores.diag()
253 | first_term = -F.softplus(-scores_diag).mean()
254 | n = scores.size(0)
255 | second_term = (torch.sum(F.softplus(scores)) -
256 | torch.sum(F.softplus(scores_diag))) / (n * (n - 1.))
257 | return first_term - second_term
258 |     # if the critic scores equal the log density ratio, their diagonal mean is a direct MI estimate
259 | def direct_log_density_ratio_mi(scores):
260 | return scores.diag().mean()
261 |
262 | train_val = js_fgan_lower_bound_obj(scores)
263 | eval_val = direct_log_density_ratio_mi(scores)
264 |
265 | with torch.no_grad():
266 | eval_train = eval_val - train_val
267 |
268 |     return train_val + eval_train  # value equals eval_val; gradients flow only through train_val
269 |
270 |
271 | # Minimization
272 | def Conditional_Entropy(ft1, ft2, model):
273 | '''
274 | Calculating H(ft2|ft1) by min_Q H[P(ft2|ft1), Q(ft2|ft1)]
275 | ft1: discrete or continuous
276 | ft2: continuous (k-dim.)
277 | We assume Q(ft2|ft1) is Gaussian.
278 |     model (Q): takes ft1, outputs the reconstructed ft2
279 | '''
280 | hat_ft2 = model(ft1)
281 |
282 | # sigma = l2_norm of ft2 (we let it be 1)
283 | # when Q = Normal(mu(ft1), sigma^2I) -> -logQ = log(sqrt((2*pi)^k sigma^(2k))) +
284 | # 0.5*1/(sigma^2)*(y-mu)^T(y-mu)
285 |     # H[P(ft2|ft1), Q(ft2|ft1)] = E_{P(ft1,ft2)}[-log Q]
286 | dim = ft2.shape[1]
287 | bsz = ft2.shape[0]
288 |
289 | #cond_entropy = 0.5*dim*math.log(2*math.pi) + 0.5*(F.mse_loss(hat_ft2, ft2, reduction='sum')/bsz)
290 | #return cond_entropy
291 | scaled_cond_entropy = F.mse_loss(hat_ft2, ft2, reduction='sum')/bsz
292 | return scaled_cond_entropy
293 |
294 |
295 |
296 | def Information_Step(pos_1, pos_2, t, models, optimizer, zx=None, zs=None):
297 | feat_encoder, norm_encoder, encoder,\
298 | mi_z_z_model, mi_z_t_model, cond_z_t_model, cond_z_z_model =\
299 | models['feat_encoder'], models['norm_encoder'], models['encoder'], models['mi_z_z_model'],\
300 | models['mi_z_t_model'], models['cond_z_t_model'], models['cond_z_z_model']
301 |
302 | ae_zx, ae_zs = Enc_z(pos_1, pos_2, encoder)
303 | if zx is None and zs is None:
304 | if feat_encoder is not None:
305 | I_zx, I_zs = Enc_z(pos_1, pos_2, feat_encoder)
306 | else:
307 | I_zx, I_zs = Enc_z(pos_1, pos_2, norm_encoder)
308 | H_zx, H_zs = Enc_z(pos_1, pos_2, norm_encoder)
309 | else:
310 | I_zx, I_zs = zx, zs
311 | H_zx, H_zs = zx, zs
312 |
313 | I_Z_T = 0.5*MI_Estimator(I_zx, t, mi_z_t_model) +\
314 | 0.5*MI_Estimator(I_zs, t, mi_z_t_model)
315 |
316 | I_Z_S = 0.5*MI_Estimator(I_zx, ae_zs, mi_z_z_model) +\
317 | 0.5*MI_Estimator(I_zs, ae_zx, mi_z_z_model)
318 |
319 | H_Z_T = 0.5*Conditional_Entropy(t, H_zx, cond_z_t_model) +\
320 | 0.5*Conditional_Entropy(t, H_zs, cond_z_t_model)
321 |
322 | H_Z_S = 0.5*Conditional_Entropy(ae_zs, H_zx, cond_z_z_model) +\
323 | 0.5*Conditional_Entropy(ae_zx, H_zs, cond_z_z_model)
324 |
325 | optimizer.zero_grad()
326 | loss = -I_Z_S - I_Z_T + H_Z_T + H_Z_S
327 | loss.backward()
328 | optimizer.step()
329 |
330 | return I_Z_S.item(), I_Z_T.item(), H_Z_T.item(), H_Z_S.item()
331 |
332 |
333 | # ## Script for Calculating I(X;S), I(X;S|T), I(X;T), H(X|T), H(X|S)
334 |
335 | if __name__ == '__main__':
336 |     parser = argparse.ArgumentParser(description='Estimate mutual information and conditional entropy')
337 | parser.add_argument('--batch_size', default=482, type=int, help='Number of images in each mini-batch\
338 | (964/2=482 for omniglot and 512 for cifar)')
339 | parser.add_argument('--epochs', default=23000, type=int, help='Number of sweeps over the dataset to train')
340 | parser.add_argument('--stage', default='AE', type=str, help='AE or Raw_Information')
341 | parser.add_argument('--additional_encoder', default=False, action='store_true')
342 |
343 | # args parse
344 | args = parser.parse_args()
345 | batch_size, epochs, stage, additional_encoder = args.batch_size, args.epochs, args.stage, args.additional_encoder
346 |
347 | train_data = utils.Our_Omniglot(root='data', background=True, transform=utils.omniglot_train_transform,
348 | character_target_transform=None, alphabet_target_transform=None, download=True,
349 | contrast_training=True)
350 |
351 |
352 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
353 | drop_last=True)
354 |
355 | # model setup and optimizer config
356 | models, optimizer = init_models(stage, additional_encoder)
357 |
358 | if stage == 'Raw_Information':
359 | results = {'I(X;S)': [], 'I(X;T)': [], 'H(X|T)': [], 'H(X|S)': []}
360 |
361 | for epoch in range(1, epochs + 1):
362 | train_bar = tqdm(train_loader)
363 | total_num = 0
364 | if stage == 'Raw_Information':
365 | I_X_S_total, I_X_T_total, H_X_T_total, H_X_S_total =\
366 | 0.0, 0.0, 0.0, 0.0
367 | elif stage == 'AE':
368 | AE_Loss_total = 0.0
369 |
370 | for pos_1, pos_2, target in train_bar:
371 | pos_1, pos_2, target = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True),\
372 | target.cuda(non_blocking=True)
373 | total_num += 1
374 | if stage == 'Raw_Information':
375 | I_X_S, I_X_T, H_X_T, H_X_S = \
376 | Information_Step(pos_1, pos_2, target, models, optimizer, zx=None, zs=None)
377 |
378 | I_X_S_total += I_X_S
379 | I_X_T_total += I_X_T
380 | H_X_T_total += H_X_T
381 | H_X_S_total += H_X_S
382 | elif stage == 'AE':
383 | AE_Loss_total += AE_Step(pos_1, pos_2, models, optimizer)
384 |
385 | if stage == 'Raw_Information':
386 | print('Epoch: {}, I(X;S): {}, I(X;T): {}, H(X|T): {}, H(X|S): {}'\
387 | .format(epoch, I_X_S_total / total_num, I_X_T_total / total_num,\
388 | H_X_T_total / total_num, H_X_S_total / total_num))
389 |
390 | results['I(X;S)'].append(I_X_S_total / total_num)
391 | results['I(X;T)'].append(I_X_T_total / total_num)
392 | results['H(X|T)'].append(H_X_T_total / total_num)
393 | results['H(X|S)'].append(H_X_S_total / total_num)
394 |
395 | # save statistics
396 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
397 |
398 | if additional_encoder:
399 | data_frame.to_csv('results/Raw_Information_additional_encoder.csv', index_label='epoch')
400 | else:
401 | data_frame.to_csv('results/Raw_Information.csv', index_label='epoch')
402 | elif stage == 'AE':
403 | print('Epoch: {}, AE_Loss: {}'.format(epoch, AE_Loss_total / total_num))
404 |
405 | # save encoder
406 | torch.save(models['norm_encoder'].state_dict(), 'results/norm_encoder.pth')
407 | torch.save(models['encoder'].state_dict(), 'results/encoder.pth')
408 |
409 |
410 | def information(epoch, train_loader, inner_epochs, net, models, optimizer):
411 | net.eval()
412 | I_Z_S_total, I_Z_T_total, H_Z_T_total, H_Z_S_total =\
413 | 0.0, 0.0, 0.0, 0.0
414 | total_num = 0
415 |
416 | for _in in range(inner_epochs):
417 | train_bar = tqdm(train_loader)
418 | for pos_1, pos_2, target in train_bar:
419 | pos_1, pos_2, target = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True),\
420 | target.cuda(non_blocking=True)
421 | total_num += 1
422 | zx, _ = net(pos_1)
423 | zs, _ = net(pos_2)
424 | I_Z_S, I_Z_T, H_Z_T, H_Z_S = \
425 | Information_Step(pos_1, pos_2, target, models, optimizer, zx=zx, zs=zs)
426 |
427 | I_Z_S_total += I_Z_S
428 | I_Z_T_total += I_Z_T
429 | H_Z_T_total += H_Z_T
430 | H_Z_S_total += H_Z_S
431 |
432 | print('Epoch: {}, Inner Epoch: {}, I(Z;S): {}, I(Z;T): {}, H(Z|T): {}, H(Z|S): {}'\
433 | .format(epoch, _in, I_Z_S, I_Z_T, H_Z_T, H_Z_S))
434 |
435 | return I_Z_S_total/total_num , I_Z_T_total/total_num, H_Z_T_total/total_num,\
436 | H_Z_S_total/total_num
437 |
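438 | # Usage sketch: `information` is driven from `main.py` when training runs with
439 | # `--with_info`. Assuming a trained contrastive `net` whose forward returns
440 | # (feature, out), the information terms for its representations are tracked as:
441 | #
442 | #   models, optimizer = init_models('Feature_Information')
443 | #   I_Z_S, I_Z_T, H_Z_T, H_Z_S = information(epoch, train_loader, inner_epochs, net, models, optimizer)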
--------------------------------------------------------------------------------
/Omniglot/linear.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import pandas as pd
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | from thop import profile, clever_format
9 | from torch.utils.data import DataLoader
10 | from torchvision.datasets import CIFAR10
11 | from tqdm import tqdm
12 |
13 | import utils
14 | from model import Model, Omniglot_Model
15 |
16 |
17 | class Net(nn.Module):
18 | def __init__(self, num_class, pretrained_path, resnet_depth=18):
19 | super(Net, self).__init__()
20 |
21 | if resnet_depth == 18:
22 | resnet_output_dim = 512
23 | elif resnet_depth == 34:
24 | resnet_output_dim = 512
25 | elif resnet_depth == 50:
26 | resnet_output_dim = 2048
27 |
28 | # encoder
29 | self.f = Model(resnet_depth=resnet_depth).f
30 | # classifier
31 | self.fc = nn.Linear(resnet_output_dim, num_class, bias=True)
32 | self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
33 |
34 | def forward(self, x):
35 | x = self.f(x)
36 | feature = torch.flatten(x, start_dim=1)
37 | out = self.fc(feature)
38 | return out
39 |
40 |
41 | class Omniglot_Net(nn.Module):
42 | def __init__(self, num_class, pretrained_path):
43 | super(Omniglot_Net, self).__init__()
44 |
45 | # encoder
46 | self.f = Omniglot_Model().f
47 | # classifier
48 | self.fc = nn.Sequential(
49 | nn.Linear(1024, num_class, bias=False),
50 | #nn.BatchNorm1d(256),
51 | #nn.ReLU(inplace=True),
52 | #nn.Linear(256, 256, bias=True),
53 | #nn.BatchNorm1d(256),
54 | #nn.ReLU(inplace=True),
55 | #nn.Linear(256, num_class, bias=True)
56 | )
57 |
58 |
59 |
60 | self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
61 |
62 | def forward(self, x):
63 | feature = self.f(x)
64 | feature = F.normalize(feature, dim=-1)
65 | out = self.fc(feature)
66 | return out
67 |
68 |
69 | # train or test for one epoch
70 | def train_val(net, data_loader, train_optimizer):
71 | is_train = train_optimizer is not None
72 | net.train() if is_train else net.eval()
73 |
74 | total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
75 | with (torch.enable_grad() if is_train else torch.no_grad()):
76 | for data, target in data_bar:
77 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
78 | out = net(data)
79 | loss = loss_criterion(out, target)
80 |
81 | if is_train:
82 | train_optimizer.zero_grad()
83 | loss.backward()
84 | train_optimizer.step()
85 |
86 | total_num += data.size(0)
87 | total_loss += loss.item() * data.size(0)
88 | prediction = torch.argsort(out, dim=-1, descending=True)
89 | total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
90 | total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
91 |
92 | data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
93 | .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num,
94 | total_correct_1 / total_num * 100, total_correct_5 / total_num * 100))
95 |
96 | return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100
97 |
98 |
99 | if __name__ == '__main__':
100 | parser = argparse.ArgumentParser(description='Linear Evaluation')
101 | parser.add_argument('--model_path', type=str, default='results/128_0.5_200_512_500_model.pth',
102 | help='The pretrained model path')
103 | parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch')
104 | parser.add_argument('--epochs', type=int, default=1000, help='Number of sweeps over the dataset to train')
105 | parser.add_argument('--resnet_depth', default=18, type=int, help='The depth of the resnet')
106 | parser.add_argument('--dataset', default='omniglot', type=str, help='omniglot or cifar')
107 |
108 | args = parser.parse_args()
109 | model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs
110 | if args.dataset == 'cifar':
111 | resnet_depth = args.resnet_depth
112 |
113 | train_data = CIFAR10(root='data', train=True, transform=utils.train_transform, download=True)
114 | test_data = CIFAR10(root='data', train=False, transform=utils.test_transform, download=True)
115 | else:
116 | train_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_train_transform,
117 | character_target_transform=None, alphabet_target_transform=None, download=True,
118 | eval_split_train=True, out_character=False, contrast_training=False)
119 | test_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_test_transform,
120 | character_target_transform=None, alphabet_target_transform=None, download=True,
121 | eval_split_train=False, out_character=False, contrast_training=False)
122 |
123 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
124 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
125 |
126 |
127 | if args.dataset == 'cifar':
128 |         model = Net(num_class=len(train_data.classes), pretrained_path=model_path, resnet_depth=resnet_depth).cuda()
129 | else:
130 | #model = Omniglot_Net(num_class=20, pretrained_path=model_path).cuda()
131 | model = Omniglot_Net(num_class=659, pretrained_path=model_path).cuda()
132 | for param in model.f.parameters():
133 | param.requires_grad = False
134 |
135 | if args.dataset == 'cifar':
136 | flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
137 | else:
138 | flops, params = profile(model, inputs=(torch.randn(1, 1, 28, 28).cuda(),))
139 | flops, params = clever_format([flops, params])
140 | print('# Model Params: {} FLOPs: {}'.format(params, flops))
141 | optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)
142 | #optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=1e-6)
143 | loss_criterion = nn.CrossEntropyLoss()
144 | results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [],
145 | 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []}
146 |
147 | save_name_pre = model_path.split('.pth')[0]
148 |
149 | best_acc = 0.0
150 | for epoch in range(1, epochs + 1):
151 | train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer)
152 | results['train_loss'].append(train_loss)
153 | results['train_acc@1'].append(train_acc_1)
154 | results['train_acc@5'].append(train_acc_5)
155 | test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None)
156 | results['test_loss'].append(test_loss)
157 | results['test_acc@1'].append(test_acc_1)
158 | results['test_acc@5'].append(test_acc_5)
159 | # save statistics
160 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
161 | data_frame.to_csv('{}_linear_statistics.csv'.format(save_name_pre), index_label='epoch')
162 | if test_acc_1 > best_acc:
163 | best_acc = test_acc_1
164 | torch.save(model.state_dict(), '{}_linear_model.pth'.format(save_name_pre))
165 |
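166 | # Usage sketch (the checkpoint path below is a placeholder for a saved encoder state_dict):
167 | #   python linear.py --dataset omniglot --model_path results/<pretrained_encoder>.pth --epochs 1000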
--------------------------------------------------------------------------------
/Omniglot/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import pandas as pd
5 | import torch
6 | import torch.optim as optim
7 | from thop import profile, clever_format
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 | import torch.nn.functional as F
11 | import math
12 |
13 | import utils
14 | from model import Model, Omniglot_Model, Recon_Omniglot_Model
15 |
16 | from compute_MI_CondEntro import init_models, information
17 |
18 |
19 | def contrastive_loss(out_1, out_2, _type='NCE'):
20 | # compute loss
21 | if _type == 'NCE':
22 | # [2*B, D]
23 | out = torch.cat([out_1, out_2], dim=0)
24 | # [2*B, 2*B]
25 | sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
26 | mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool()
27 | # [2*B, 2*B-1]
28 | sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)
29 |
30 | pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
31 | # [2*B]
32 | pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
33 | loss = (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()
34 | elif _type == 'JS':
35 | temperature_JS = temperature
36 |
37 | scores_12 = torch.mm(out_1, out_2.t().contiguous()) / temperature_JS
38 | first_term = -F.softplus(-scores_12.diag()).mean()
39 |
40 | n = scores_12.size(0)
41 | second_term_12 = (torch.sum(F.softplus(scores_12)) -
42 | torch.sum(F.softplus(scores_12.diag()))) / (n * (n - 1.))
43 | scores_11 = torch.mm(out_1, out_1.t().contiguous()) / temperature_JS
44 | second_term_11 = (torch.sum(F.softplus(scores_11)) -
45 | torch.sum(F.softplus(scores_11.diag()))) / (n * (n - 1.))
46 | scores_22 = torch.mm(out_2, out_2.t().contiguous()) / temperature_JS
47 | second_term_22 = (torch.sum(F.softplus(scores_22)) -
48 | torch.sum(F.softplus(scores_22.diag()))) / (n * (n - 1.))
49 | second_term = (second_term_11 + second_term_22 + second_term_12*2.) / 4.
50 | loss = -1. * (first_term - second_term)
51 | return loss
52 |
53 |
54 | def inverse_predictive_loss(feature_1, feature_2):
55 | # symmetric
56 | return F.mse_loss(feature_1, feature_2) + F.mse_loss(feature_2, feature_1)
57 |
58 |
59 | # Use MSE_Loss here (assuming Gaussian)
60 | # Other losses, e.g. binary cross-entropy, assume a Bernoulli observation model
61 | def forward_predictive_loss(target, recon, _type='RevBCE'):
62 | if _type == 'RevBCE':
63 | # empirically good
64 | return F.binary_cross_entropy_with_logits(target, recon)
65 | elif _type == 'BCE':
66 | # assuming factorized bernoulli
67 | return F.binary_cross_entropy_with_logits(recon, target)
68 | elif _type == 'MSE':
69 |         # assuming diagonal Gaussian
70 | # target has [0,1]
71 | # change it to [-\infty, \infty]
72 | # inverse of sigmoid: x = ln(y/(1-y))
73 | target = target.clamp(min=1e-4, max=1. - 1e-4)
74 | target = torch.log(target / (1.-target))
75 | return F.mse_loss(recon, target)
76 |
77 |
78 | # train for one epoch to learn unique features
79 | def train(epoch, net, data_loader, train_optimizer, recon_net, loss_type, recon_optimizer, info_dct=None):
80 | net.train()
81 | if recon_net is not None:
82 | recon_net.train()
83 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
84 |
85 | for pos_1, pos_2, target in train_bar:
86 | pos_1, pos_2, target = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True),\
87 | target.cuda(non_blocking=True)
88 | #norm = False if loss_type == 7 else True
89 | norm = True
90 |
91 | feature_1, out_1 = net(pos_1, norm)
92 | feature_2, out_2 = net(pos_2, norm)
93 |
94 | if loss_type == 1 or loss_type == 2 or loss_type == 5 or loss_type == 6 or loss_type == 7\
95 | or loss_type == 11 or loss_type == 12:
96 | contr_type = 'NCE' if not loss_type == 7 else 'JS'
97 | contr_loss = contrastive_loss(out_1, out_2, _type = contr_type)
98 | if loss_type == 2 or loss_type == 4 or loss_type == 6 or loss_type == 10\
99 | or loss_type == 12:
100 |             inver_loss = inverse_predictive_loss(feature_1, feature_2)
101 | if loss_type == 3 or loss_type == 4 or loss_type == 5 or loss_type == 6 or loss_type == 8\
102 | or loss_type == 9 or loss_type == 10 or loss_type == 11 or loss_type == 12:
103 | recon_for_2 = recon_net(feature_1)
104 | recon_for_1 = recon_net(feature_2)
105 | if loss_type == 3:
106 | recon_type = 'BCE'
107 | elif loss_type == 4 or loss_type == 5 or loss_type == 6 or loss_type == 8:
108 | recon_type = 'RevBCE'
109 | elif loss_type == 9 or loss_type == 10 or loss_type == 11 or loss_type == 12:
110 | recon_type = 'MSE'
111 | recon_loss = forward_predictive_loss(pos_1, recon_for_1, _type=recon_type) +\
112 | forward_predictive_loss(pos_2, recon_for_2, _type=recon_type)
113 | recon_optimizer.zero_grad()
114 | train_optimizer.zero_grad()
115 |
116 | if loss_type == 1 or loss_type == 7:
117 | loss = contr_loss
118 | elif loss_type == 2:
119 | loss = contr_loss + inver_param*inver_loss
120 | elif loss_type == 3 or loss_type == 8 or loss_type == 9:
121 | loss = recon_param*recon_loss
122 | elif loss_type == 4 or loss_type == 10:
123 | loss = recon_param*recon_loss + inver_param*inver_loss
124 | elif loss_type == 5 or loss_type == 11:
125 | loss = contr_loss + recon_param*recon_loss
126 | elif loss_type == 6 or loss_type == 12:
127 | loss = contr_loss + recon_param*recon_loss +\
128 | inver_param*inver_loss
129 |
130 |
131 | loss.backward()
132 |
133 | train_optimizer.step()
134 | if recon_optimizer is not None:
135 | recon_optimizer.step()
136 |
137 | total_num += batch_size
138 | total_loss += loss.item() * batch_size
139 | train_bar.set_description('Train Epoch: [{}/{}], loss_type: {}, Loss: {:.4f}'.format(\
140 | epoch, epochs, loss_type, total_loss / total_num))
141 |
142 |
143 |
144 | if info_dct is not None:
145 | inner_epochs = 200 if epoch < 100 else 80
146 | I_Z_S, I_Z_T, H_Z_T, H_Z_S = information(epoch, data_loader, inner_epochs, net,\
147 | info_dct['info_models'], info_dct['info_optimizer'])
148 |
149 | info_dct['info_results']['I(Z;S)'].append(I_Z_S)
150 | info_dct['info_results']['I(Z;T)'].append(I_Z_T)
151 | info_dct['info_results']['H(Z|T)'].append(H_Z_T)
152 | info_dct['info_results']['H(Z|S)'].append(H_Z_S)
153 |
154 | return total_loss / total_num
155 |
156 |
157 | # test for one epoch, use weighted knn to find the most similar images' label to assign the test image
158 | def test(net, memory_data_loader, test_data_loader):
159 | net.eval()
160 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
161 | with torch.no_grad():
162 | # generate feature bank
163 | for data, _, target in tqdm(memory_data_loader, desc='Feature extracting'):
164 | feature, out = net(data.cuda(non_blocking=True))
165 | feature_bank.append(feature)
166 | # [D, N]
167 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
168 | # [N]
169 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
170 | # loop test data to predict the label by weighted knn search
171 | test_bar = tqdm(test_data_loader)
172 | for data, _, target in test_bar:
173 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
174 | feature, out = net(data)
175 |
176 | total_num += data.size(0)
177 | # compute cos similarity between each feature vector and feature bank ---> [B, N]
178 | sim_matrix = torch.mm(feature, feature_bank)
179 | # [B, K]
180 | sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
181 | # [B, K]
182 | sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices)
183 | sim_weight = (sim_weight / temperature).exp()
184 |
185 | # counts for each class
186 | one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device)
187 | # [B*K, C]
188 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
189 | # weighted score ---> [B, C]
190 | pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)
191 |
192 | pred_labels = pred_scores.argsort(dim=-1, descending=True)
193 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
194 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
195 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
196 | .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100))
197 |
198 | return total_top1 / total_num * 100, total_top5 / total_num * 100
199 |
200 |
201 | # test for one epoch, use weighted knn to find the most similar images' label to assign the test image
202 | def omniglot_test(net, memory_data_loader, test_data_loader):
203 | net.eval()
204 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
205 | with torch.no_grad():
206 | # generate feature bank
207 | for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
208 | feature, out = net(data.cuda(non_blocking=True))
209 | feature_bank.append(feature)
210 | # [D, N]
211 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
212 | # [N]
213 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
214 | # loop test data to predict the label by weighted knn search
215 | test_bar = tqdm(test_data_loader)
216 | for data, target in test_bar:
217 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
218 | feature, out = net(data)
219 |
220 | total_num += data.size(0)
221 | # compute cos similarity between each feature vector and feature bank ---> [B, N]
222 | sim_matrix = torch.mm(feature, feature_bank)
223 | # [B, K]
224 | sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
225 | # [B, K]
226 | sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices)
227 | sim_weight = (sim_weight / temperature).exp()
228 |
229 | # counts for each class
230 | one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device)
231 | # [B*K, C]
232 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
233 | # weighted score ---> [B, C]
234 | pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)
235 |
236 | pred_labels = pred_scores.argsort(dim=-1, descending=True)
237 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
238 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
239 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
240 | .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100))
241 |
242 | return total_top1 / total_num * 100, total_top5 / total_num * 100
243 |
244 |
245 | if __name__ == '__main__':
246 | parser = argparse.ArgumentParser(description='Omniglot Experiments')
247 | parser.add_argument('--temperature', default=0.1, type=float, help='Temperature used in softmax\
248 | (0.1 for omniglot and 0.5 for cifar)')
249 | parser.add_argument('--k', default=1, type=int, help='Top k most similar images used to predict the label\
250 | (1 for omniglot and 200 for cifar)')
251 | parser.add_argument('--batch_size', default=482, type=int, help='Number of images in each mini-batch\
252 | (964/2=482 for omniglot and 512 for cifar)')
253 | parser.add_argument('--epochs', default=1000, type=int, help='Number of sweeps over the dataset to train')
254 | parser.add_argument('--resnet_depth', default=18, type=int, help='The depth of the resnet\
255 | (only for cifar)')
256 | parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for latent vector\
257 | (only for cifar)')
258 | parser.add_argument('--dataset', default='omniglot', type=str, help='omniglot or cifar')
259 | parser.add_argument('--trial', default=99, type=int, help='number of trial')
260 | parser.add_argument('--loss_type', default=1, type=int, help='1: only contrast (NCE),\
261 | 2: contrast (NCE) + inverse_pred, 3: only forward_pred (BCE),\
262 | 4: forward_pred (RevBCE) + inverse_pred, 5: contrast (NCE) + forward_pred (RevBCE),\
263 | 6: contrast (NCE) + forward_pred (RevBCE) + inverse_pred,\
264 | 7: only contrast (JS), 8: only forward_pred (RevBCE),\
265 | 9: only forward_pred (MSE), 10: forward_pred (MSE) + inverse_pred,\
266 | 11: contrast (NCE) + forward_pred (MSE),\
267 | 12: contrast (NCE) + forward_pred (MSE) + inverse_pred')
268 | parser.add_argument('--inver_param', default=0.001, type=float, help='Hyper_param for inverse_pred')
269 | parser.add_argument('--recon_param', default=0.001, type=float, help='Hyper_param for forward_pred')
270 | parser.add_argument('--with_info', default=False, action='store_true')
271 |
272 | # args parse
273 | args = parser.parse_args()
274 | feature_dim, temperature, k = args.feature_dim, args.temperature, args.k
275 | batch_size, epochs = args.batch_size, args.epochs
276 | resnet_depth = args.resnet_depth
277 | trial = args.trial
278 |
279 | recon_param = args.recon_param
280 | inver_param = args.inver_param
281 |
282 | # data prepare
283 | if args.dataset == 'cifar':
284 | train_data = utils.CIFAR10Pair(root='data', train=True, transform=utils.train_transform, download=True)
285 | else:
286 | # our self-supervised signal construction strategy
287 | train_data = utils.Our_Omniglot(root='data', background=True, transform=utils.omniglot_train_transform,
288 | character_target_transform=None, alphabet_target_transform=None, download=True,
289 | contrast_training=True)
290 | # self-supervised signal construction strategy in SimCLR
291 | #train_data = utils.Our_Omniglot_v2(root='data', background=True, transform=utils.omniglot_train_transform,
292 | # character_target_transform=None, alphabet_target_transform=None, download=True,
293 | # contrast_training=True)
294 |
295 |
296 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
297 | drop_last=True)
298 |
299 | if args.dataset == 'cifar':
300 | memory_data = utils.CIFAR10Pair(root='data', train=True, transform=utils.test_transform, download=True)
301 | else:
302 | memory_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_test_transform,
303 | character_target_transform=None, alphabet_target_transform=None, download=True,
304 | eval_split_train=True, out_character=True, contrast_training=False)
305 | memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
306 |
307 |
308 | if args.dataset == 'cifar':
309 | test_data = utils.CIFAR10Pair(root='data', train=False, transform=utils.test_transform, download=True)
310 | else:
311 | test_data = utils.Our_Omniglot(root='data', background=False, transform=utils.omniglot_test_transform,
312 | character_target_transform=None, alphabet_target_transform=None, download=True,
313 | eval_split_train=False, out_character=True, contrast_training=False)
314 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
315 |
316 | # calculating information
317 | if args.with_info:
318 | info_models, info_optimizer = init_models('Feature_Information')
319 | info_results = {'I(Z;S)': [], 'I(Z;T)': [], 'H(Z|T)': [], 'H(Z|S)': []}
320 | info_dct = {
321 | 'info_models': info_models,
322 | 'info_optimizer': info_optimizer,
323 | 'info_results': info_results,
324 | }
325 | else:
326 | info_dct = None
327 |
328 | # model setup and optimizer config
329 | if args.dataset == 'cifar':
330 | model = Model(feature_dim, resnet_depth=resnet_depth).cuda()
331 | recon_model = None
332 | else:
333 | model = Omniglot_Model().cuda()
334 | recon_model = Recon_Omniglot_Model().cuda() if args.loss_type >= 3 else None
335 |
336 | if args.dataset == 'cifar':
337 | flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
338 | else:
339 | flops, params = profile(model, inputs=(torch.randn(1, 1, 28, 28).cuda(),))
340 | #flops, params = profile(model, inputs=(torch.randn(1, 1, 56, 56).cuda(),))
341 | #flops, params = profile(model, inputs=(torch.randn(1, 1, 105, 105).cuda(),))
342 | if recon_model is not None:
343 | recon_flops, recon_params = profile(recon_model, inputs=(torch.randn(1, 1024).cuda(),))
344 | flops, params = clever_format([flops, params])
345 | print('# Model Params: {} FLOPs: {}'.format(params, flops))
346 | if recon_model is not None:
347 | recon_flops, recon_params = clever_format([recon_flops, recon_params])
348 | print('# Recon_Model Params: {} FLOPs: {}'.format(recon_params, recon_flops))
349 | optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
350 | recon_optimizer = optim.Adam(recon_model.parameters(), lr=1e-3, weight_decay=1e-6) if recon_model is not None \
351 | else None
352 | #optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=1e-6)
353 | #milestone1, milestone2 = int(args.epochs*0.4), int(args.epochs*0.7)
354 | #lr_decay = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[milestone1, milestone2], gamma=0.1)
355 | if args.dataset == 'cifar':
356 | c = len(memory_data.classes)
357 | else:
358 | c = 659 #c = 20
359 |
360 | # training loop
361 | results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []}
362 | if args.dataset == 'cifar':
363 | save_name_pre = '{}_{}_{}_{}_{}_{}_{}'.format(args.dataset, resnet_depth, feature_dim, temperature, k, \
364 | batch_size, epochs)
365 | else:
366 | save_name_pre = '{}_{}_{}_{}_{}'.format(args.dataset, args.loss_type, recon_param, inver_param, trial)
367 | if not os.path.exists('results'):
368 | os.mkdir('results')
369 | best_acc = 0.0
370 | for epoch in range(1, epochs + 1):
371 | if(epoch%4)==1:
372 | train_loss = train(epoch, model, train_loader, optimizer, recon_model, args.loss_type, recon_optimizer, info_dct)
373 | else:
374 | train_loss = train(epoch, model, train_loader, optimizer, recon_model, args.loss_type, recon_optimizer, None)
375 | #lr_decay.step()
376 | results['train_loss'].append(train_loss)
377 |
378 | if args.dataset == 'cifar':
379 | test_acc_1, test_acc_5 = test(model, memory_loader, test_loader)
380 | else:
381 | test_acc_1, test_acc_5 = omniglot_test(model, memory_loader, test_loader)
382 |
383 | results['test_acc@1'].append(test_acc_1)
384 | results['test_acc@5'].append(test_acc_5)
385 |
386 | # save statistics
387 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
388 | data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
389 |
390 | if (epoch%4)==1 and info_dct is not None:
391 | info_data_frame = pd.DataFrame(data=info_dct['info_results'], index=range(1, epoch + 1, 4))
392 | if args.loss_type==1:
393 |                 info_data_frame.to_csv('results/Feature_Information.csv', index_label='epoch')
394 | elif args.loss_type==2:
395 |                 info_data_frame.to_csv('results/Feature_Information_min_H.csv', index_label='epoch')
396 |
397 | if args.dataset == 'cifar':
398 | if test_acc_1 > best_acc:
399 | best_acc = test_acc_1
400 | torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre))
401 |
--------------------------------------------------------------------------------
/Omniglot/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models.resnet import resnet18, resnet34, resnet50
5 |
6 |
7 | class Model(nn.Module):
8 | def __init__(self, feature_dim=128, resnet_depth=18):
9 | super(Model, self).__init__()
10 |
11 | self.f = []
12 | if resnet_depth == 18:
13 | my_resnet = resnet18()
14 | resnet_output_dim = 512
15 | elif resnet_depth == 34:
16 | my_resnet = resnet34()
17 | resnet_output_dim = 512
18 | elif resnet_depth == 50:
19 | my_resnet = resnet50()
20 | resnet_output_dim = 2048
21 |
22 | for name, module in my_resnet.named_children():
23 | if name == 'conv1':
24 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
25 | if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
26 | self.f.append(module)
27 | # encoder
28 | self.f = nn.Sequential(*self.f)
29 | # projection head
30 | self.g = nn.Sequential(nn.Linear(resnet_output_dim, 512, bias=False), nn.BatchNorm1d(512),
31 | nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))
32 |
33 | def forward(self, x):
34 | x = self.f(x)
35 | feature = torch.flatten(x, start_dim=1)
36 | out = self.g(feature)
37 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
38 |
39 |
40 | # for 105x105 size
41 | '''
42 | class Omniglot_Model(nn.Module):
43 | def __init__(self):
44 | super(Omniglot_Model, self).__init__()
45 | # encoder
46 | self.f = nn.Sequential(
47 | nn.Conv2d(in_channels=1, out_channels=64, kernel_size=10, stride=1, padding=0), # out: 96
48 | nn.BatchNorm2d(num_features=64),
49 | nn.ReLU(inplace=True),
50 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 48
51 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=0), # out: 42
52 | nn.BatchNorm2d(num_features=128),
53 | nn.ReLU(inplace=True),
54 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 21
55 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=1, padding=0), # out: 18
56 | nn.BatchNorm2d(num_features=128),
57 | nn.ReLU(inplace=True),
58 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 9
59 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=1, padding=0), # out: 6
60 | nn.BatchNorm2d(num_features=128),
61 | nn.ReLU(inplace=True),
62 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3
63 | nn.Flatten(),
64 | nn.Linear(9*128, 1024),
65 | )
66 | # projection head
67 | self.g = nn.Identity()
68 |
69 | def forward(self, x):
70 | feature = self.f(x)
71 | out = self.g(feature)
72 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
73 | '''
74 |
75 |
76 | # for 28x28 size (using max_pool)
77 | class Omniglot_Model(nn.Module):
78 | def __init__(self):
79 | super(Omniglot_Model, self).__init__()
80 | # encoder
81 | self.f = nn.Sequential(
82 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28
83 | nn.BatchNorm2d(num_features=128),
84 | nn.ReLU(inplace=True),
85 | #nn.MaxPool2d(kernel_size=2, stride=2),
86 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28
87 | nn.BatchNorm2d(num_features=128),
88 | nn.ReLU(inplace=True),
89 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 14
90 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 14
91 | nn.BatchNorm2d(num_features=128),
92 | nn.ReLU(inplace=True),
93 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 7
94 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 7
95 | nn.BatchNorm2d(num_features=128),
96 | nn.ReLU(inplace=True),
97 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3
98 | nn.Flatten(),
99 | nn.Linear(9*128, 1024),
100 | )
101 | # projection head
102 | self.g = nn.Identity()
103 |
104 | def forward(self, x, norm=True):
105 | feature = self.f(x)
106 | out = self.g(feature)
107 | if norm:
108 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
109 | else:
110 | return F.normalize(feature, dim=-1), out
111 |
112 |
113 | # for 28x28 size (not using maxpool)
114 | '''
115 | class Omniglot_Model(nn.Module):
116 | def __init__(self):
117 | super(Omniglot_Model, self).__init__()
118 | # encoder
119 | self.f = nn.Sequential(
120 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28
121 | nn.BatchNorm2d(num_features=128),
122 | nn.ReLU(inplace=True),
123 | #nn.MaxPool2d(kernel_size=2, stride=2),
124 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1), # out: 14
125 | nn.BatchNorm2d(num_features=128),
126 | nn.ReLU(inplace=True),
127 | #nn.MaxPool2d(kernel_size=2, stride=2), # out: 14
128 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1), # out: 7
129 | nn.BatchNorm2d(num_features=128),
130 | nn.ReLU(inplace=True),
131 | #nn.MaxPool2d(kernel_size=2, stride=2), # out: 7
132 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=0), # out: 3
133 | nn.BatchNorm2d(num_features=128),
134 | nn.ReLU(inplace=True),
135 | #nn.MaxPool2d(kernel_size=2, stride=2), # out: 3
136 | nn.Flatten(),
137 | nn.Linear(9*128, 1024),
138 | )
139 | # projection head
140 | self.g = nn.Identity()
141 |
142 | def forward(self, x):
143 | feature = self.f(x)
144 | out = self.g(feature)
145 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
146 | '''
147 |
148 |
149 | # +
150 | # for 28x28 size
151 | class Lambda(nn.Module):
152 | def __init__(self, func):
153 | super(Lambda, self).__init__()
154 | self.func = func
155 |
156 | def forward(self, x):
157 | return self.func(x)
158 |
159 | class Recon_Omniglot_Model(nn.Module):
160 | def __init__(self):
161 | super(Recon_Omniglot_Model, self).__init__()
162 |         # reconstructor (approximately the inverse of the encoder)
163 | self.f = nn.Sequential(
164 | nn.Linear(1024, 9*128, bias=False),
165 | nn.BatchNorm1d(num_features=9*128),
166 | nn.ReLU(inplace=True), # (9*128 -> 3*3*128)
167 | Lambda(lambda x: x.view(-1, 128, 3, 3)),
168 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=0,
169 | output_padding=0, bias=False), # out: 7
170 | nn.BatchNorm2d(num_features=128),
171 | nn.ReLU(inplace=True),
172 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1,
173 | output_padding=1, bias=False), # out: 14
174 | nn.BatchNorm2d(num_features=128),
175 | nn.ReLU(inplace=True),
176 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1,
177 | output_padding=1, bias=False), # out: 28
178 | nn.BatchNorm2d(num_features=128),
179 | nn.ReLU(inplace=True),
180 | nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1,
181 | output_padding=0, bias=True), # out: 28
182 | #nn.Sigmoid(),
183 | )
184 |
185 | def forward(self, x):
186 | recon = self.f(x)
187 | return recon
188 |
189 |
190 | # -
191 |
192 | # for 56x56 size
193 | '''
194 | class Omniglot_Model(nn.Module):
195 | def __init__(self):
196 | super(Omniglot_Model, self).__init__()
197 | # encoder
198 | self.f = nn.Sequential(
199 | nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 56
200 | nn.BatchNorm2d(num_features=128),
201 | nn.ReLU(inplace=True),
202 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 28
203 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 28
204 | nn.BatchNorm2d(num_features=128),
205 | nn.ReLU(inplace=True),
206 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 14
207 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 14
208 | nn.BatchNorm2d(num_features=128),
209 | nn.ReLU(inplace=True),
210 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 7
211 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # out: 7
212 | nn.BatchNorm2d(num_features=128),
213 | nn.ReLU(inplace=True),
214 | nn.MaxPool2d(kernel_size=2, stride=2), # out: 3
215 | nn.Flatten(),
216 | nn.Linear(9*128, 1024),
217 | )
218 | # projection head
219 | self.g = nn.Identity()
220 |
221 | def forward(self, x):
222 | feature = self.f(x)
223 | out = self.g(feature)
224 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
225 | '''
226 |
--------------------------------------------------------------------------------
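A minimal smoke-test sketch of the modules defined in `model.py` above. The import path, batch size, and input resolutions are illustrative assumptions; the expected shapes follow directly from the layer definitions (for the decoder, the spatial size grows 3 -> 7 -> 14 -> 28 via out = (in - 1)*stride - 2*padding + kernel_size + output_padding):

```python
# Illustrative shape check (assumes model.py is importable as `model`; not part of the repo).
import torch
from model import Model, Omniglot_Model, Recon_Omniglot_Model

# ResNet-18 encoder with a 3x3 stem and no max-pool: 3x32x32 -> 512-d feature, 128-d projection.
cifar_model = Model(feature_dim=128, resnet_depth=18)
feature, projection = cifar_model(torch.randn(4, 3, 32, 32))
print(feature.shape, projection.shape)   # torch.Size([4, 512]) torch.Size([4, 128])

# Omniglot encoder: 1x28x28 -> 1024-d feature; the projection head g is an identity.
omniglot_model = Omniglot_Model()
feature, out = omniglot_model(torch.randn(4, 1, 28, 28))
print(feature.shape, out.shape)          # torch.Size([4, 1024]) torch.Size([4, 1024])

# Reconstruction decoder: 1024-d code back to a 1x28x28 image.
recon_model = Recon_Omniglot_Model()
recon = recon_model(feature)
print(recon.shape)                       # torch.Size([4, 1, 28, 28])
```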
/Omniglot/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from torchvision import transforms
3 | from torchvision.datasets import CIFAR10, Omniglot
4 |
5 | # +
6 | import cv2
7 | import numpy as np
8 |
9 | from torchvision.datasets.utils import check_integrity, list_dir, list_files
10 | from os.path import join
11 |
12 |
13 | # -
14 |
15 | # np.random.seed(0)
16 |
17 | class GaussianBlur(object):
18 | # Implements Gaussian blur as described in the SimCLR paper
19 | def __init__(self, kernel_size, min=0.1, max=2.0):
20 | self.min = min
21 | self.max = max
22 | # kernel size is set to be 10% of the image height/width
23 | self.kernel_size = kernel_size
24 |
25 | def __call__(self, sample):
26 | sample = np.array(sample)
27 |
28 | # blur the image with a 50% chance
29 | prob = np.random.random_sample()
30 |
31 | if prob < 0.5:
32 | sigma = (self.max - self.min) * np.random.random_sample() + self.min
33 | sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
34 |
35 | return sample
36 |
37 |
38 | class CIFAR10Pair(CIFAR10):
39 | """CIFAR10 Dataset.
40 | """
41 |
42 | def __getitem__(self, index):
43 | img, target = self.data[index], self.targets[index]
44 | img = Image.fromarray(img)
45 |
46 | if self.transform is not None:
47 | pos_1 = self.transform(img)
48 | pos_2 = self.transform(img)
49 |
50 | if self.target_transform is not None:
51 | target = self.target_transform(target)
52 |
53 | return pos_1, pos_2, target
54 |
55 | class Our_Omniglot(Omniglot):
56 | '''
57 | The code is adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/omniglot.py
58 | [Usage]
59 | contrastive_training_data = Our_Omniglot(root='data', background=True, transform=None,
60 | character_target_transform=None, alphabet_target_transform=None, download=True,
61 | contrast_training=True)
62 | classifier_train_data = Our_Omniglot(root='data', background=False, transform=None,
63 | character_target_transform=None, alphabet_target_transform=None, download=True,
64 | eval_split_train=True, out_character=False, contrast_training=False)
65 | classifier_test_data = Our_Omniglot(root='data', background=False, transform=None,
66 | character_target_transform=None, alphabet_target_transform=None, download=True,
67 | eval_split_train=False, out_character=False, contrast_training=False)
68 | '''
69 | def __init__(self, root, background=True, transform=None, character_target_transform=None,
70 | alphabet_target_transform=None, download=False, eval_split_train=True, out_character=False,
71 | contrast_training=True):
72 | super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
73 | target_transform=character_target_transform)
74 | self.background = background
75 |
76 | if download:
77 | self.download()
78 |
79 | if not self._check_integrity():
80 | raise RuntimeError('Dataset not found or corrupted.' +
81 | ' You can use download=True to download it')
82 |
83 | self.character_target_transform = character_target_transform
84 | self.alphabet_target_transform = alphabet_target_transform
85 |
86 |
87 | self.target_folder = join(self.root, self._get_target_folder())
88 | self._alphabets = list_dir(self.target_folder)
89 | self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
90 | for a in self._alphabets], [])
91 | self._character_images = [[(image, idx, self._alphabets.index(character.split('/')[0]))
92 | for image in list_files(join(self.target_folder, character), '.png')]
93 | for idx, character in enumerate(self._characters)]
94 | self._flat_character_images = sum(self._character_images, [])
95 |
96 | self.contrast_training = contrast_training
97 |
98 | # we adopt contrastive training in the background split
99 | if self.contrast_training:
100 | # 20 samples per character
101 | self._flat_character_images = np.array(self._flat_character_images).reshape(-1,20,3)
102 | self.out_character = out_character
103 | # we adopt standard classification training in the evaluation split
104 | else:
105 | # 20 samples per character
106 | self._flat_character_images = np.array(self._flat_character_images).reshape(-1,20,3)
107 | if eval_split_train:
108 | self._flat_character_images = self._flat_character_images[:,:5,:]
109 | else:
110 | self._flat_character_images = self._flat_character_images[:,5:,:]
111 | self._flat_character_images = self._flat_character_images.reshape(-1,3)
112 | self.out_character = out_character
113 | if self.out_character:
114 | self.targets = self._flat_character_images[:,1].astype(np.int64)
115 | else:
116 | self.targets = self._flat_character_images[:,2].astype(np.int64)
117 |
118 | def __getitem__(self, index):
119 | """
120 | Args:
121 | index (int): Index
122 | Returns:
123 | when contrastive training:
124 |                 tuple: (image0, image1, target)
125 |                     where image0 and image1 belong to the same character class and target is
126 |                     the character (or alphabet) class index, depending on out_character
127 |             when not contrastive training:
128 |                 tuple: (image, target)
129 |                     where target is the character (or alphabet) class index, depending on out_character.
130 | """
131 | if self.contrast_training:
132 | random_idx = np.random.randint(20, size=2)
133 | image_name_0, character_class_0, alphabet_class_0 = self._flat_character_images[index,random_idx[0]]
134 | character_class_0, alphabet_class_0 = int(character_class_0), int(alphabet_class_0)
135 | image_name_1, character_class_1, alphabet_class_1 = self._flat_character_images[index,random_idx[1]]
136 | character_class_1, alphabet_class_1 = int(character_class_1), int(alphabet_class_1)
137 | image_path_0 = join(self.target_folder, self._characters[character_class_0], image_name_0)
138 | image_0 = Image.open(image_path_0, mode='r').convert('L')
139 | image_path_1 = join(self.target_folder, self._characters[character_class_1], image_name_1)
140 | image_1 = Image.open(image_path_1, mode='r').convert('L')
141 |
142 | if self.transform:
143 | image_0 = self.transform(image_0)
144 | image_1 = self.transform(image_1)
145 |
146 | if self.character_target_transform:
147 | character_class_0 = self.character_target_transform(character_class_0)
148 | # character_class_1 = self.character_target_transform(character_class_1)
149 | if self.alphabet_target_transform:
150 | alphabet_class_0 = self.alphabet_target_transform(alphabet_class_0)
151 | # alphabet_class_1 = self.alphabet_target_transform(alphabet_class_1)
152 |
153 | if self.out_character:
154 | return image_0, image_1, character_class_0#, character_class_1, alphabet_class_0, alphabet_class_1
155 | else:
156 | return image_0, image_1, alphabet_class_0#, character_class_1, alphabet_class_0, alphabet_class_1
157 | else:
158 | image_name, character_class, alphabet_class = self._flat_character_images[index]
159 | character_class, alphabet_class = int(character_class), int(alphabet_class)
160 | image_path = join(self.target_folder, self._characters[character_class], image_name)
161 | image = Image.open(image_path, mode='r').convert('L')
162 |
163 | if self.transform:
164 | image = self.transform(image)
165 |
166 | if self.character_target_transform:
167 | character_class = self.character_target_transform(character_class)
168 | if self.alphabet_target_transform:
169 | alphabet_class = self.alphabet_target_transform(alphabet_class)
170 |
171 | if self.out_character:
172 | return image, character_class
173 | else:
174 | return image, alphabet_class
175 |
176 |
177 | class Our_Omniglot_v2(Omniglot):
178 | '''
179 | The code is adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/omniglot.py
180 | [Usage]
181 |     contrastive_training_data = Our_Omniglot_v2(root='data', background=True, transform=None,
182 |                      character_target_transform=None, alphabet_target_transform=None, download=True,
183 |                      contrast_training=True)
184 |     classifier_train_data = Our_Omniglot_v2(root='data', background=False, transform=None,
185 |                      character_target_transform=None, alphabet_target_transform=None, download=True,
186 |                      eval_split_train=True, out_character=False, contrast_training=False)
187 |     classifier_test_data = Our_Omniglot_v2(root='data', background=False, transform=None,
188 | character_target_transform=None, alphabet_target_transform=None, download=True,
189 | eval_split_train=False, out_character=False, contrast_training=False)
190 | '''
191 | def __init__(self, root, background=True, transform=None, character_target_transform=None,
192 | alphabet_target_transform=None, download=False, eval_split_train=True, out_character=True,
193 | contrast_training=True):
194 | super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
195 | target_transform=character_target_transform)
196 | self.background = background
197 |
198 | if download:
199 | self.download()
200 |
201 | if not self._check_integrity():
202 | raise RuntimeError('Dataset not found or corrupted.' +
203 | ' You can use download=True to download it')
204 |
205 | self.character_target_transform = character_target_transform
206 | self.alphabet_target_transform = alphabet_target_transform
207 |
208 |
209 | self.target_folder = join(self.root, self._get_target_folder())
210 | self._alphabets = list_dir(self.target_folder)
211 | self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
212 | for a in self._alphabets], [])
213 | self._character_images = [[(image, idx, self._alphabets.index(character.split('/')[0]))
214 | for image in list_files(join(self.target_folder, character), '.png')]
215 | for idx, character in enumerate(self._characters)]
216 | self._flat_character_images = sum(self._character_images, [])
217 |
218 | self.contrast_training = contrast_training
219 |
220 | # 20 samples per character
221 | self._flat_character_images = np.array(self._flat_character_images).reshape(-1,20,3)
222 | if eval_split_train:
223 | self._flat_character_images = self._flat_character_images[:,:5,:]
224 | else:
225 | self._flat_character_images = self._flat_character_images[:,5:,:]
226 | self._flat_character_images = self._flat_character_images.reshape(-1,3)
227 | self.out_character = out_character
228 | if self.out_character:
229 | self.targets = self._flat_character_images[:,1].astype(np.int64)
230 | else:
231 | self.targets = self._flat_character_images[:,2].astype(np.int64)
232 |
233 | def __getitem__(self, index):
234 | """
235 | Args:
236 | index (int): Index
237 | Returns:
238 | when contrastive training:
239 |                 tuple: (image0, image1, target)
240 |                     where image0 and image1 are the same image under two different augmentations
241 |                     and target is the character (or alphabet) class index, depending on out_character
242 |             when not contrastive training:
243 |                 tuple: (image, target)
244 |                     where target is the character (or alphabet) class index, depending on out_character.
245 | """
246 | image_name, character_class, alphabet_class = self._flat_character_images[index]
247 | character_class, alphabet_class = int(character_class), int(alphabet_class)
248 | image_path = join(self.target_folder, self._characters[character_class], image_name)
249 | image = Image.open(image_path, mode='r').convert('L')
250 |
251 | if self.character_target_transform:
252 | character_class = self.character_target_transform(character_class)
253 | if self.alphabet_target_transform:
254 | alphabet_class = self.alphabet_target_transform(alphabet_class)
255 |
256 | if self.contrast_training:
257 | if self.transform:
258 | image_0 = self.transform(image)
259 | image_1 = self.transform(image)
260 |
261 | if self.out_character:
262 | return image_0, image_1, character_class
263 | else:
264 | return image_0, image_1, alphabet_class
265 | else:
266 | if self.transform:
267 | image = self.transform(image)
268 |
269 | if self.out_character:
270 | return image, character_class
271 | else:
272 | return image, alphabet_class
273 |
274 | # GaussianBlur is not used for CIFAR10
275 |
276 | train_transform = transforms.Compose([
277 | transforms.RandomResizedCrop(32),
278 | transforms.RandomHorizontalFlip(p=0.5),
279 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
280 | transforms.RandomGrayscale(p=0.2),
281 | #GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
282 | transforms.ToTensor(),
283 | #transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
284 | ])
285 |
286 | omniglot_train_transform = transforms.Compose([
287 | transforms.RandomAffine(degrees=10.0, translate=(0.1, 0.1)),
288 | #transforms.RandomResizedCrop(105, scale=(0.85, 1.0), ratio=(0.8, 1.25)),
289 | #transforms.RandomResizedCrop(56, scale=(0.85, 1.0), ratio=(0.8, 1.25)),
290 | transforms.RandomResizedCrop(28, scale=(0.85, 1.0), ratio=(0.8, 1.25)),
291 | transforms.ToTensor(),
292 | lambda x: 1. - x,
293 | ])
294 |
295 | test_transform = transforms.Compose([
296 | transforms.ToTensor(),
297 | #transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
298 | ])
299 |
300 | omniglot_test_transform = transforms.Compose([
301 | #transforms.Resize(105),
302 | #transforms.Resize(56),
303 | transforms.Resize(28),
304 | transforms.ToTensor(),
305 | lambda x: 1. - x,
306 | ])
307 |
--------------------------------------------------------------------------------
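A minimal sketch of how the dataset wrappers and transforms in `utils.py` above might be combined. The data root, batch size, and example indexing are illustrative assumptions; the actual training scripts may wire this differently:

```python
# Illustrative usage (assumes utils.py is importable as `utils`; not part of the repo).
from torch.utils.data import DataLoader
from utils import Our_Omniglot, omniglot_train_transform, omniglot_test_transform

# Background split in contrastive mode: each item is a pair of images drawn from
# the same character, plus a class label (alphabet here, since out_character defaults to False).
train_data = Our_Omniglot(root='data', background=True,
                          transform=omniglot_train_transform,
                          download=True, contrast_training=True)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

pos_1, pos_2, target = next(iter(train_loader))
print(pos_1.shape, pos_2.shape)   # torch.Size([32, 1, 28, 28]) for each view

# Evaluation split in classification mode: 5 images per character for training,
# the remaining 15 for testing.
eval_train = Our_Omniglot(root='data', background=False,
                          transform=omniglot_test_transform,
                          download=True, eval_split_train=True,
                          out_character=False, contrast_training=False)
image, alphabet_label = eval_train[0]
print(image.shape, alphabet_label)   # torch.Size([1, 28, 28]) and an alphabet index
```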
/README.md:
--------------------------------------------------------------------------------
1 | This repo contains the code to reproduce the results in the paper.
2 |
3 | [PDF link](https://arxiv.org/pdf/2006.05576.pdf)
4 |
5 | I will make this repo much more informative at a later date. In the meantime, any questions regarding the paper or the experiments are welcome.
6 |
--------------------------------------------------------------------------------