├── __init__.py
├── models
│   ├── __init__.py
│   ├── utils.py
│   ├── vggnet.py
│   └── resnet.py
├── self_supervision
│   ├── __init__.py
│   ├── relative_patch_location.py
│   ├── jigsaw.py
│   ├── self_supervision_lib.py
│   ├── patch_model_preprocess.py
│   ├── rotation.py
│   ├── supervised.py
│   ├── exemplar.py
│   ├── linear_eval.py
│   └── patch_utils.py
├── config
│   ├── supervised
│   │   └── imagenet.sh
│   ├── rotation
│   │   └── imagenet.sh
│   ├── evaluation
│   │   ├── rotation_or_exemplar.sh
│   │   └── jigsaw_or_relative_patch_location.sh
│   ├── exemplar
│   │   └── imagenet.sh
│   ├── jigsaw
│   │   └── imagenet.sh
│   └── relative_patch_location
│       └── imagenet.sh
├── CONTRIBUTING.md
├── setup.py
├── trainer.py
├── README.md
├── utils.py
├── preprocess.py
├── datasets.py
├── LICENSE
└── train_and_eval.py

/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/self_supervision/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/config/supervised/imagenet.sh:
--------------------------------------------------------------------------------
#!/bin/bash -eu
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/sh

python train_and_eval.py \
  --task supervised \
  --architecture resnet50 \
  --filters_factor 4 \
  --weight_decay 1e-4 \
  --dataset imagenet \
  --train_split trainval \
  --val_split test \
  --batch_size 256 \
  --eval_batch_size 10 \
  \
  --preprocessing inception_preprocess \
  --resize_size 224 \
  \
  --lr 0.1 \
  --lr_scale_batch_size 256 \
  --decay_epochs 30,60,80 \
  --epochs 90 \
  --warmup_epochs 5 \
  "$@"

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).

--------------------------------------------------------------------------------
/config/rotation/imagenet.sh:
--------------------------------------------------------------------------------
#!/bin/bash -eu
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/sh

python train_and_eval.py \
  --task rotation \
  --dataset imagenet \
  --train_split train \
  --val_split val \
  --batch_size 64 \
  --eval_batch_size 16 \
  \
  --architecture revnet50 \
  --filters_factor 16 \
  --last_relu True \
  \
  --preprocessing inception_preprocess,rotate \
  --resize_size 224,224 \
  \
  --lr 0.1 \
  --lr_scale_batch_size 256 \
  --decay_epochs 15,25 \
  --epochs 35 \
  --warmup_epochs 5 \
  \
  --serving_input_shape None,224,224,3 \
  "$@"

--------------------------------------------------------------------------------
/config/evaluation/rotation_or_exemplar.sh:
--------------------------------------------------------------------------------
#!/bin/bash -eu
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/sh

# For the reported results, try the following parameters on 4x4 TPUs:
#   batch_size: 2048
#   decay_epochs: 480,500
#   epochs: 520
python train_and_eval.py \
  --task linear_eval \
  --dataset imagenet \
  --train_split trainval \
  --val_split test \
  --batch_size 512 \
  --eval_batch_size 32 \
  \
  --preprocessing resize_small,crop,-1_to_1 \
  --crop_size 224,224 \
  --smaller_size 256 \
  --pool_mode max \
  \
  --lr 0.1 \
  --decay_epochs 30,50 \
  --epochs 70 \
  --lr_scale_batch_size 256 \
  --hub_module ~/Downloads/module \
  "$@"

--------------------------------------------------------------------------------
/config/exemplar/imagenet.sh:
--------------------------------------------------------------------------------
#!/bin/bash -eu
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/sh

python train_and_eval.py \
  --task exemplar \
  --dataset imagenet \
  --train_split train \
  --val_split val \
  --batch_size 64 \
  --eval_batch_size 16 \
  \
  --architecture resnet50 \
  --filters_factor 12 \
  --last_relu True \
  --mode v1 \
  \
  --preprocessing to_gray,crop_inception_preprocess_patches,standardization \
  --resize_size 224,224 \
  --grayscale_probability 0.66 \
  --embed_dim 1000 \
  --margin 0.5 \
  --num_of_inception_patches 8 \
  \
  --lr 0.1 \
  --lr_scale_batch_size 256 \
  --decay_epochs 15,25 \
  --epochs 35 \
  --warmup_epochs 5 \
  \
  --serving_input_shape None,224,224,3 \
  "$@"

--------------------------------------------------------------------------------
/config/jigsaw/imagenet.sh:
--------------------------------------------------------------------------------
#!/bin/bash -eu
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/sh

python train_and_eval.py \
  --task jigsaw \
  --dataset imagenet \
  --train_split train \
  --val_split val \
  --batch_size 128 \
  --eval_batch_size 8 \
  \
  --architecture resnet50 \
  --filters_factor 8 \
  --last_relu True \
  --mode v1 \
  \
  --preprocessing resize,to_gray,crop,crop_patches,standardization \
  --resize_size 292,292 \
  --crop_size 255 \
  --grayscale_probability 0.66 \
  --splits_per_side 3 \
  --patch_jitter 21 \
  --embed_dim 1000 \
  \
  --lr 0.1 \
  --lr_scale_batch_size 256 \
  --decay_epochs 15,25 \
  --epochs 35 \
  --warmup_epochs 5 \
  \
  --serving_input_shape None,64,64,3 \
  "$@"

--------------------------------------------------------------------------------
/config/relative_patch_location/imagenet.sh:
--------------------------------------------------------------------------------
#!/bin/bash -eu
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/sh

python train_and_eval.py \
  --task relative_patch_location \
  --dataset imagenet \
  --train_split train \
  --val_split val \
  --batch_size 128 \
  --eval_batch_size 8 \
  \
  --architecture resnet50 \
  --filters_factor 8 \
  --last_relu True \
  --mode v1 \
  \
  --preprocessing resize,to_gray,crop,crop_patches,standardization \
  --resize_size 292,292 \
  --crop_size 255 \
  --grayscale_probability 0.66 \
  --splits_per_side 3 \
  --patch_jitter 21 \
  --embed_dim 1000 \
  \
  --lr 0.1 \
  --lr_scale_batch_size 256 \
  --decay_epochs 15,25 \
  --epochs 35 \
  --warmup_epochs 5 \
  \
  --serving_input_shape None,64,64,3 \
  "$@"

--------------------------------------------------------------------------------
/config/evaluation/jigsaw_or_relative_patch_location.sh:
--------------------------------------------------------------------------------
#!/bin/bash -eu
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/sh

# For the reported results, try the following parameters on 4x4 TPUs:
#   batch_size: 2048
#   decay_epochs: 480,500
#   epochs: 520
python train_and_eval.py \
  --task linear_eval \
  --dataset imagenet \
  --train_split trainval \
  --val_split test \
  --batch_size 512 \
  --eval_batch_size 32 \
  \
  --preprocessing resize_small,crop,crop_patches,standardization \
  --resize_size 256,256 \
  --crop_size 192,192 \
  --smaller_size 224 \
  --patch_jitter 0 \
  --splits_per_side 3 \
  --pool_mode max \
  --combine_patches avg_pool \
  \
  --lr 0.1 \
  --decay_epochs 30,50 \
  --epochs 70 \
  --lr_scale_batch_size 256 \
  --hub_module ~/Downloads/module \
  "$@"

--------------------------------------------------------------------------------
/self_supervision/relative_patch_location.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements the relative patch location prediction task.
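
Given a grid of patches cropped from an image, the model sees a patch together
with one of its neighbours and predicts which of the possible relative
positions the neighbour occupies, following Doersch et al., 2015.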
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from self_supervision import patch_utils


def model_fn(data, mode):
  """Produces a loss for the relative patch location task.

  Args:
    data: Dict of inputs ("image" being the image)
    mode: model's mode: training, eval or prediction

  Returns:
    EstimatorSpec
  """
  images = data['image']

  # Patch locations
  perms, num_classes = patch_utils.generate_patch_locations()
  labels = tf.tile(list(range(num_classes)), tf.shape(images)[:1])

  return patch_utils.creates_estimator_model(
      images, labels, perms, num_classes, mode)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Install self supervised learning package."""

from setuptools import find_packages
from setuptools import setup

setup(
    name='self_supervised_learning',
    version='1.0',
    description=('Self Supervised Learning - code from "Revisiting '
                 'Self-Supervised Visual Representation Learning" paper'),
    author='Google LLC',
    author_email='no-reply@google.com',
    url='http://github.com/TODO',
    license='Apache 2.0',
    packages=find_packages(),
    package_data={
    },
    scripts=[
    ],
    install_requires=[
        'future',
        'numpy',
        'absl-py',
        'tensorflow',
        'tensorflow-hub',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    keywords='tensorflow self supervised learning',
)

--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for NN models.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import absl.flags as flags

import models.resnet
import models.vggnet

FLAGS = flags.FLAGS


def get_net(num_classes=None):  # pylint: disable=missing-docstring
  architecture = FLAGS.architecture

  if 'vgg19' in architecture:
    net = functools.partial(
        models.vggnet.vgg19,
        filters_factor=FLAGS.get_flag_value('filters_factor', 8))
  else:
    if 'resnet50' in architecture:
      net = models.resnet.resnet50
    elif 'revnet50' in architecture:
      net = models.resnet.revnet50
    else:
      raise ValueError('Unsupported architecture: %s' % architecture)

    net = functools.partial(
        net,
        filters_factor=FLAGS.get_flag_value('filters_factor', 4),
        last_relu=FLAGS.get_flag_value('last_relu', True),
        mode=FLAGS.get_flag_value('mode', 'v2'))

  if FLAGS.task in ('jigsaw', 'relative_patch_location'):
    net = functools.partial(net, root_conv_stride=1, strides=(2, 2, 1))

  # Few things that are common across all models.
  net = functools.partial(
      net, num_classes=num_classes,
      weight_decay=FLAGS.get_flag_value('weight_decay', 1e-4))

  return net

--------------------------------------------------------------------------------
/self_supervision/jigsaw.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements the jigsaw puzzle task.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from self_supervision import patch_utils

FLAGS = tf.flags.FLAGS


def model_fn(data, mode):
  """Produces a loss for the jigsaw task.

  Args:
    data: Dict of inputs ("image" being the image)
    mode: model's mode: training, eval or prediction

  Returns:
    EstimatorSpec
  """
  images = data['image']

  # Patch locations
  perms, num_classes = patch_utils.load_permutations()
  labels = list(range(num_classes))

  # Selects a subset of permutations for training. There are two methods:
  # 1. For each image, select 16 permutations independently.
  # 2. For each batch of images, select the same 16 permutations.
  # Here we use method 2, for simplicity.
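  # E.g. (illustration): with 100 loaded permutations and perm_subset_size=16,
  # the code below shuffles the indices [0..99], keeps the first 16 as this
  # batch's labels, and gathers the corresponding 16 permutations, so the
  # whole batch shares a single 16-way classification problem.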
  if mode in [tf.estimator.ModeKeys.TRAIN]:
    perm_subset_size = FLAGS.get_flag_value('perm_subset_size', 8)
    indices = list(range(num_classes))
    indices = tf.random_shuffle(indices)
    labels = indices[:perm_subset_size]
    perms = tf.gather(perms, labels, axis=0)
    tf.logging.info('subsample %s' % perms)

  labels = tf.tile(labels, tf.shape(images)[:1])

  return patch_utils.creates_estimator_model(
      images, labels, perms, num_classes, mode)

--------------------------------------------------------------------------------
/self_supervision/self_supervision_lib.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generates training data with self supervision.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from self_supervision import exemplar
from self_supervision import jigsaw
from self_supervision import linear_eval
from self_supervision import relative_patch_location
from self_supervision import rotation
from self_supervision import supervised


def get_self_supervision_model(self_supervision):
  """Gets self supervised training data and labels."""

  mapping = {
      "linear_eval": linear_eval.model_fn,
      "supervised": supervised.model_fn,

      "rotation": rotation.model_fn,
      "jigsaw": jigsaw.model_fn,
      "relative_patch_location": relative_patch_location.model_fn,
      "exemplar": exemplar.model_fn,
  }

  model_fn = mapping.get(self_supervision)
  if model_fn is None:
    raise ValueError("Unknown self-supervision: %s" % self_supervision)

  def _model_fn(features, labels, mode, params):
    """Returns the EstimatorSpec to run the model.

    Args:
      features: Dict of inputs ("image" being the image).
      labels: unused but required by Estimator API.
      mode: model's mode: training, eval or prediction
      params: required by Estimator API, contains TPU local `batch_size`.

    Returns:
      EstimatorSpec

    Raises:
      ValueError when the self_supervision is unknown.
    """
    del labels, params  # unused
    tf.logging.info("Calling model_fn in mode %s with data:", mode)
    tf.logging.info(features)
    return model_fn(features, mode)

  return _model_fn

--------------------------------------------------------------------------------
/models/vggnet.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements VGG model.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import tensorflow as tf


def convbnrelu(x, filters, size, training, **convkw):
  x = tf.layers.conv2d(x, kernel_size=size, filters=filters,
                       use_bias=False, **convkw)
  x = tf.layers.batch_normalization(x, fused=True, training=training)
  x = tf.nn.relu(x)
  return x


def vgg19(x, is_training, num_classes=1000,  # pylint: disable=missing-docstring
          filters_factor=8,
          weight_decay=5e-4):
  # NOTE: default weight_decay here is as in the VGGNet paper, which is
  # different from the ResNet and RevNet models.
  # NOTE: Another difference is that we are using BatchNorm, and because of
  # that, we are not using Dropout in the final FC layers.

  regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
  conv3 = functools.partial(convbnrelu, size=3, training=is_training,
                            kernel_regularizer=regularizer, padding='same')
  fc = functools.partial(convbnrelu, training=is_training,
                         kernel_regularizer=regularizer, padding='valid')

  end_points = {}

  # After long discussion, we settled on filters_factor=8 being the default and
  # thus needing to match the vanilla VGGNet, which starts with 64.
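  # With the default filters_factor=8 this gives w = 64 below, the width of
  # the first convolutional block in the standard VGG19.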
  w = 8 * filters_factor  # w stands for width (a la wide-resnet)
  x = conv3(x, w)
  x = conv3(x, w)
  end_points['block1'] = x
  x = tf.layers.max_pooling2d(x, pool_size=2, strides=2)  # 112x112
  x = conv3(x, 2*w)
  x = conv3(x, 2*w)
  end_points['block2'] = x
  x = tf.layers.max_pooling2d(x, pool_size=2, strides=2)  # 56x56
  x = conv3(x, 4*w)
  x = conv3(x, 4*w)
  x = conv3(x, 4*w)
  x = conv3(x, 4*w)
  end_points['block3'] = x
  x = tf.layers.max_pooling2d(x, pool_size=2, strides=2)  # 28x28
  x = conv3(x, 8*w)
  x = conv3(x, 8*w)
  x = conv3(x, 8*w)
  x = conv3(x, 8*w)
  end_points['block4'] = x
  x = tf.layers.max_pooling2d(x, pool_size=2, strides=2)  # 14x14
  x = conv3(x, 8*w)
  x = conv3(x, 8*w)
  x = conv3(x, 8*w)
  x = conv3(x, 8*w)
  end_points['block5'] = x

  x = tf.layers.max_pooling2d(x, pool_size=2, strides=2)  # 7x7
  x = fc(x, 512*filters_factor, x.get_shape().as_list()[-3:-1])
  end_points['fc6'] = x
  x = fc(x, 512*filters_factor, (1, 1))

  end_points['pre_logits'] = x
  x = tf.layers.conv2d(x, kernel_size=1, filters=num_classes,
                       use_bias=True, kernel_regularizer=regularizer)
  x = tf.squeeze(x, [1, 2])
  end_points['logits'] = x

  return x, end_points

--------------------------------------------------------------------------------
/self_supervision/patch_model_preprocess.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-docstring
"""Preprocessing methods for self supervised representation learning.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf

import utils as utils


def crop(image, is_training, crop_size):
  h, w, c = crop_size[0], crop_size[1], image.shape[-1]

  if is_training:
    return tf.random_crop(image, [h, w, c])
  else:
    # Central crop for now. (See Table 5 in Appendix of
    # https://arxiv.org/pdf/1703.07737.pdf for why)
    dy = (tf.shape(image)[0] - h)//2
    dx = (tf.shape(image)[1] - w)//2
    return tf.image.crop_to_bounding_box(image, dy, dx, h, w)


def image_to_patches(image, is_training, split_per_side, patch_jitter=0):
  """Crops split_per_side x split_per_side patches from input image.

  Args:
    image: input image tensor with shape [h, w, c].
    is_training: is training flag.
    split_per_side: split of patches per image side.
    patch_jitter: jitter of each patch from each grid.

  Returns:
    Patches tensor with shape [patch_count, hc, wc, c].
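
    For example, a 255x255 input with split_per_side=3 and patch_jitter=21
    yields patch_count=9 patches of size 64x64, in row-major order (top-left
    to bottom-right).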
  """
  h, w, _ = image.get_shape().as_list()

  h_grid = h // split_per_side
  w_grid = w // split_per_side
  h_patch = h_grid - patch_jitter
  w_patch = w_grid - patch_jitter

  tf.logging.info(
      "Crop patches - image size: (%d, %d), split_per_side: %d, "
      "grid_size: (%d, %d), patch_size: (%d, %d), patch_jitter: %d",
      h, w, split_per_side, h_grid, w_grid, h_patch, w_patch, patch_jitter)

  patches = []
  for i in range(split_per_side):
    for j in range(split_per_side):

      p = tf.image.crop_to_bounding_box(image, i * h_grid, j * w_grid, h_grid,
                                        w_grid)
      # Trick: crop a small tile from pixel cell, to avoid edge continuity.
      if h_patch < h_grid or w_patch < w_grid:
        p = crop(p, is_training, [h_patch, w_patch])

      patches.append(p)

  return tf.stack(patches)


def get_crop_patches_fn(is_training, split_per_side, patch_jitter=0):
  """Gets a function which crops split_per_side x split_per_side patches.

  Args:
    is_training: is training flag.
    split_per_side: split of patches per image side.
    patch_jitter: jitter of each patch from each grid. E.g. a 255x255 input
      image with split_per_side=3 is split into a 3x3 grid of 85x85 cells,
      and patches are cropped from each cell with size
      (grid_size-patch_jitter, grid_size-patch_jitter).

  Returns:
    A function returning a name-to-tensor dict. This function crops
    split_per_side x split_per_side patches from the "image" tensor in the
    input data dict.
  """

  def _crop_patches_pp(data):
    image = data["image"]

    image_to_patches_fn = functools.partial(
        image_to_patches,
        is_training=is_training,
        split_per_side=split_per_side,
        patch_jitter=patch_jitter)
    image = utils.tf_apply_to_image_or_images(image_to_patches_fn, image)

    data["image"] = image
    return data
  return _crop_patches_pp

--------------------------------------------------------------------------------
/self_supervision/rotation.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Produces rotations for input images.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf
import tensorflow_hub as hub

from models.utils import get_net
import trainer
import utils


FLAGS = tf.flags.FLAGS


def apply_model(image_fn,  # pylint: disable=missing-docstring
                is_training,
                num_outputs,
                make_signature=False):

  # Image tensor needs to be created lazily in order to satisfy tf-hub
  # restriction: all tensors should be created inside tf-hub helper function.
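  # (hub.create_module_spec re-runs this helper in a fresh graph for every
  # tag set, so a tensor built eagerly outside of it would belong to the
  # wrong graph.)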
  images = image_fn()

  net = get_net(num_classes=num_outputs)

  output, end_points = net(images, is_training)

  if make_signature:
    hub.add_signature(inputs={'image': images}, outputs=output)
    hub.add_signature(
        name='representation',
        inputs={'image': images},
        outputs=end_points)
  return output


def model_fn(data, mode):
  """Produces a loss for the rotation task.

  Args:
    data: Dict of inputs containing, among others, "image" and "label."
    mode: model's mode: training, eval or prediction

  Returns:
    EstimatorSpec
  """
  num_angles = 4
  images = data['image']

  if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
    images = tf.reshape(images, [-1] + images.get_shape().as_list()[-3:])
    with tf.variable_scope('module'):
      image_fn = lambda: images
      logits = apply_model(
          image_fn=image_fn,
          is_training=(mode == tf.estimator.ModeKeys.TRAIN),
          num_outputs=num_angles,
          make_signature=False)
  else:
    input_shape = utils.str2intlist(
        FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
    image_fn = lambda: tf.placeholder(shape=input_shape,  # pylint: disable=g-long-lambda
                                      dtype=tf.float32)
    apply_model_function = functools.partial(
        apply_model,
        image_fn=image_fn,
        num_outputs=num_angles,
        make_signature=True)
    tf_hub_module_spec = hub.create_module_spec(apply_model_function,
                                                [(utils.TAGS_IS_TRAINING, {
                                                    'is_training': True
                                                }),
                                                 (set(), {
                                                     'is_training': False
                                                 })])
    tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
    hub.register_module_for_export(tf_hub_module, export_name='module')
    logits = tf_hub_module(images)

    return trainer.make_estimator(mode, predictions=logits)

  labels = tf.reshape(data['label'], [-1])

  # build loss and accuracy
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(loss)

  eval_metrics = (
      lambda labels, logits: {  # pylint: disable=g-long-lambda
          'accuracy': tf.metrics.accuracy(
              labels=labels,
              predictions=tf.argmax(logits, axis=-1))},
      [labels, logits])
  return trainer.make_estimator(mode, loss, eval_metrics, logits)

--------------------------------------------------------------------------------
/self_supervision/supervised.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements fully-supervised model.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf
import tensorflow_hub as hub

import datasets
from models.utils import get_net
import trainer
import utils

FLAGS = tf.flags.FLAGS


def apply_model(image_fn,  # pylint: disable=missing-docstring
                is_training,
                num_outputs,
                make_signature=False):

  # Image tensor needs to be created lazily in order to satisfy tf-hub
  # restriction: all tensors should be created inside tf-hub helper function.
  images = image_fn()

  net = get_net(num_classes=num_outputs)

  output, end_points = net(images, is_training)

  if make_signature:
    hub.add_signature(inputs={'image': images}, outputs=output)
    hub.add_signature(inputs={'image': images}, outputs=end_points,
                      name='representation')
  return output


def model_fn(data, mode):
  """Produces a loss for the fully-supervised task.

  Args:
    data: Dict of inputs containing, among others, "image" and "label."
    mode: model's mode: training, eval or prediction

  Returns:
    EstimatorSpec
  """
  images = data['image']

  # In predict mode (called once at the end of training), we only instantiate
  # the model in order to export a tf.hub module for it.
  # This will then make saving and loading much easier down the line.
  if mode == tf.estimator.ModeKeys.PREDICT:
    input_shape = utils.str2intlist(
        FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
    apply_model_function = functools.partial(
        apply_model,
        image_fn=lambda: tf.placeholder(shape=input_shape, dtype=tf.float32),  # pylint: disable=g-long-lambda
        num_outputs=datasets.get_num_classes(),
        make_signature=True)
    tf_hub_module_spec = hub.create_module_spec(
        apply_model_function,
        [(utils.TAGS_IS_TRAINING, {'is_training': True}),
         (set(), {'is_training': False})])
    tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
    hub.register_module_for_export(tf_hub_module, export_name='module')
    logits = tf_hub_module(images)

    # There is no training happening anymore, only prediction and model export.
    return trainer.make_estimator(mode, predictions=logits)

  # From here on, we are either in train or eval modes.
  # Create the model in the 'module' name scope so it matches nicely with
  # tf.hub's requirements for import/export later.
  with tf.variable_scope('module'):
    logits = apply_model(
        image_fn=lambda: images,
        is_training=(mode == tf.estimator.ModeKeys.TRAIN),
        num_outputs=datasets.get_num_classes(),
        make_signature=False)

  labels = data['label']

  # build loss and accuracy
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(loss)

  # Gets a metric_fn which evaluates the "top1_accuracy" and "top5_accuracy".
  # The resulting metrics are named "top1_accuracy_{tensor_name}",
  # "top5_accuracy_{tensor_name}".
  metrics_fn = utils.get_classification_metrics(['logits'])
  # A tuple of metric_fn and a list of tensors to be evaluated by TPUEstimator.
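  # (TPUEstimator pulls these tensors off the device and calls metrics_fn on
  # the host CPU during evaluation.)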
  eval_metrics_tuple = (metrics_fn, [labels, logits])

  return trainer.make_estimator(mode, loss, eval_metrics_tuple)

--------------------------------------------------------------------------------
/self_supervision/exemplar.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Exemplar implementation with triplet semihard loss."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf
import tensorflow_hub as hub

import utils
from models.utils import get_net
from trainer import make_estimator

FLAGS = tf.flags.FLAGS


def apply_model(image_fn,
                is_training,
                num_outputs,
                make_signature=False):
  """Creates the patch based model output from patches representations.

  Args:
    image_fn: function returns image tensor.
    is_training: is training flag used for batch norm and drop out.
    num_outputs: number of output classes.
    make_signature: whether to create signature for hub module.

  Returns:
    out: output tensor with shape [n*m, 1, 1, num_outputs].

  Raises:
    ValueError: An error occurred when the architecture is unknown.
  """
  # Image tensor needs to be created lazily in order to satisfy tf-hub
  # restriction: all tensors should be created inside tf-hub helper function.
  images = image_fn()

  # Note: get_net already binds weight_decay via functools.partial, so it must
  # not be passed again when calling the returned net.
  net = get_net(num_classes=num_outputs)
  out, end_points = net(images, is_training)

  print(end_points)

  if len(out.get_shape().as_list()) == 4:
    out = tf.squeeze(out, [1, 2])

  if make_signature:
    hub.add_signature(inputs={'image': images}, outputs=out)
    hub.add_signature(
        name='representation',
        inputs={'image': images},
        outputs=end_points)
  return out


def repeat(x, times):
  """Exactly like np.repeat."""
  return tf.reshape(tf.tile(tf.expand_dims(x, -1), [1, times]), [-1])


def model_fn(data, mode):
  """Produces a loss for the exemplar task supervision.

  Args:
    data: Dict of inputs containing, among others, "image" and "label."
    mode: model's mode: training, eval or prediction

  Returns:
    EstimatorSpec
  """
  images = data['image']
  batch_size = tf.shape(images)[0]
  print(' +++ Mode: %s, data: %s' % (mode, data))

  embed_dim = FLAGS.embed_dim
  patch_count = images.get_shape().as_list()[1]

  images = tf.reshape(
      images, shape=[-1] + images.get_shape().as_list()[-3:])

  if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
    images = tf.reshape(images, [-1] + images.get_shape().as_list()[-3:])
    with tf.variable_scope('module'):
      image_fn = lambda: images
      logits = apply_model(
          image_fn=image_fn,
          is_training=(mode == tf.estimator.ModeKeys.TRAIN),
          num_outputs=embed_dim,
          make_signature=False)
  else:
    input_shape = utils.str2intlist(
        FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
    image_fn = lambda: tf.placeholder(shape=input_shape,  # pylint: disable=g-long-lambda
                                      dtype=tf.float32)
    apply_model_function = functools.partial(
        apply_model,
        image_fn=image_fn,
        num_outputs=embed_dim,
        make_signature=True)

    tf_hub_module_spec = hub.create_module_spec(apply_model_function,
                                                [(utils.TAGS_IS_TRAINING, {
                                                    'is_training': True
                                                }),
                                                 (set(), {
                                                     'is_training': False
                                                 })],
                                                drop_collections=['summaries'])
    tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
    hub.register_module_for_export(tf_hub_module, export_name='module')
    logits = tf_hub_module(images)
    return make_estimator(mode, predictions=logits)

  labels = repeat(tf.range(batch_size), patch_count)
  norm_logits = tf.nn.l2_normalize(logits, axis=0)
  loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
      labels, norm_logits, margin=FLAGS.margin)

  return make_estimator(mode, loss, predictions=logits)

--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base trainer class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import datasets
import utils

FLAGS = tf.flags.FLAGS


def get_lr(global_step, base_lr, steps_per_epoch,  # pylint: disable=missing-docstring
           decay_epochs, lr_decay_factor, warmup_epochs):

  warmup_lr = 0.0
  if warmup_epochs > 0:
    warmup_lr = (tf.cast(global_step, tf.float32) *
                 (base_lr / (warmup_epochs * steps_per_epoch)))

  normal_lr = tf.train.piecewise_constant(
      global_step,
      [e * steps_per_epoch for e in decay_epochs],
      [base_lr * (lr_decay_factor ** i) for i in range(len(decay_epochs) + 1)]
  )

  lr = tf.cond(tf.less(global_step, warmup_epochs * steps_per_epoch),
               lambda: warmup_lr,
               lambda: normal_lr)

  return lr


# TODO(akolesnikov): add more logging
class Trainer(object):
  """Base trainer class."""

  def __init__(self,
               update_batchnorm_params=True):
    self.update_batchnorm_params = update_batchnorm_params

    split = FLAGS.get_flag_value('train_split', 'train')
    num_samples = datasets.get_count(split)
    steps_per_epoch = num_samples // FLAGS.batch_size

    global_step = tf.train.get_or_create_global_step()
    self.global_step_inc = tf.assign_add(global_step, 1)

    # lr_scale_batch_size defines a canonical batch size that is coupled with
    # the initial learning rate. If the actual batch size differs from the
    # canonical one, the learning rate is scaled linearly. This is very
    # convenient, as it allows varying the batch size without recomputing the
    # learning rate.
    lr_factor = 1.0
    if FLAGS.get_flag_value('lr_scale_batch_size', 0):
      lr_factor = FLAGS.batch_size / float(FLAGS.lr_scale_batch_size)

    deps = FLAGS.get_flag_value('decay_epochs', None)
    decay_epochs = utils.str2intlist(deps) if deps else [FLAGS.epochs]

    self.lr = get_lr(
        global_step,
        base_lr=FLAGS.lr * lr_factor,
        steps_per_epoch=steps_per_epoch,
        decay_epochs=decay_epochs,
        lr_decay_factor=FLAGS.get_flag_value('lr_decay_factor', 0.1),
        warmup_epochs=FLAGS.get_flag_value('warmup_epochs', 0))

    # TODO(marvinritter): Re-enable summaries with support for TPU training.
    # tf.summary.scalar('learning_rate', self.lr)

  def get_train_op(self, loss,  # pylint: disable=missing-docstring
                   var_list=None,
                   add_reg_loss=True,
                   use_tpu=False):

    if add_reg_loss:
      l2_loss = tf.reduce_sum(tf.losses.get_regularization_losses())
      loss += l2_loss

    optimizer = FLAGS.get_flag_value('optimizer', 'sgd')
    if optimizer == 'sgd':
      optimizer = tf.train.MomentumOptimizer(learning_rate=self.lr,
                                             momentum=0.9)
    elif optimizer == 'adam':
      optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
    else:
      raise ValueError('Unknown optimizer: %s' % optimizer)

    if use_tpu:
      # Wrap optimizer in CrossShardOptimizer which takes care of
      # synchronizing the weight updates between TPU cores.
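      # (It does this by cross-replica-summing/averaging the gradients across
      # all cores before each core applies the same update.)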
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    opt_step = optimizer.minimize(loss, var_list=var_list,
                                  colocate_gradients_with_ops=True)

    if self.update_batchnorm_params:
      opt_step = tf.group([opt_step] +
                          tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    opt_step = tf.group([opt_step, self.global_step_inc])

    return opt_step


def make_estimator(mode, loss=None, eval_metrics=None, predictions=None):
  """Returns an EstimatorSpec (maybe TPU) for all modes."""

  # Always use TPUEstimator, even when not using TPU, then it's (almost) no-op.
  spec_type = tf.contrib.tpu.TPUEstimatorSpec

  if mode == tf.estimator.ModeKeys.PREDICT:
    assert predictions is not None, 'Need to pass `predictions` arg.'
    return spec_type(mode=mode, predictions=predictions)

  if mode == tf.estimator.ModeKeys.EVAL:
    return spec_type(mode=mode, loss=loss, eval_metrics=eval_metrics)

  if mode == tf.estimator.ModeKeys.TRAIN:
    assert loss is not None, 'Need to pass `loss` arg.'
    trainer = Trainer(update_batchnorm_params=True)
    train_op = trainer.get_train_op(loss, use_tpu=FLAGS.use_tpu)
    return spec_type(mode=mode, loss=loss, train_op=train_op)

  raise ValueError('Unsupported mode %s' % mode)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Revisiting self-supervised visual representation learning

Tensorflow implementation of experiments from
[our paper on unsupervised visual representation learning](http://arxiv.org/abs/1901.09005).

If you find this repository useful in your research, please consider citing:

```
@inproceedings{kolesnikov2019revisiting,
  title={Revisiting self-supervised visual representation learning},
  author={Kolesnikov, Alexander and Zhai, Xiaohua and Beyer, Lucas},
  journal={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  month={June},
  year={2019}
}
```

## Overview

This codebase allows reproducing the core experiments from our paper. It
contains our re-implementation of four self-supervised representation learning
techniques, utility code for running training and evaluation loops (including
on TPUs) and an implementation of standard CNN models, such as ResNet v1,
ResNet v2 and VGG19.

Specifically, we provide a re-implementation of the following self-supervised
representation learning techniques:

1. [Unsupervised Representation Learning by Predicting Image Rotations](https://arxiv.org/abs/1803.07728)
2. [Unsupervised Visual Representation Learning by Context Prediction](https://arxiv.org/abs/1505.05192)
3. [Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles](https://arxiv.org/abs/1603.09246)
4. [Discriminative Unsupervised Feature Learning with Exemplar Convolutional
   Neural Networks](https://arxiv.org/abs/1406.6909)

## Usage instructions

In the paper we train self-supervised models using 32 or 128 TPU cores. We
evaluate the resulting representations by training a logistic regression model
on 32 TPU cores.

In this codebase we provide configurations for training/evaluation of our
models using an 8 TPU core setup, as this setup is more affordable for public
TPU users through the Google Cloud API. These configurations produce results
close to those reported in the paper, which used more TPU chips.

For debugging or running small experiments we also support training and
evaluation using a single GPU device.

### Preparing data

Please refer to the
[instructions in the slim library](https://github.com/tensorflow/models/blob/master/research/inception/README.md#getting-started)
for downloading and preprocessing ImageNet data.

### Clone the repository and install dependencies

```
git clone https://github.com/google/revisiting-self-supervised
cd revisiting-self-supervised
python -m pip install -e . --user
```

We depend on some external files that need to be downloaded and placed in the
root repository folder. You can run the following commands to download them:

```
wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/preprocessing/inception_preprocessing.py
wget https://github.com/MehdiNoroozi/JigsawPuzzleSolver/raw/master/permutations_100_max.bin
```

### Running locally on a single GPU

Run any experiment by running the corresponding shell script with the following
options, here exemplified for the fully supervised experiment:

```
./config/supervised/imagenet.sh \
  --workdir <workdir> \
  --nouse_tpu \
  --master='' \
  --dataset_dir <dataset dir>
```

### Running on Google Cloud using TPUs

#### Step 1:

Create your own TPU cloud instance by following the
[official documentation](https://cloud.google.com/tpu/).

#### Step 2:

Clone the repository and install dependencies as described above.

#### Step 3:

Run the self-supervised model training script with TPUs. For example:

```
gsutil mb gs://<bucket>
export TPU_NAME=<tpu name>
config/supervised/imagenet.sh --workdir gs://<workdir> --dataset_dir gs://<dataset dir>
```

After/during training, run the self-supervised model evaluation script with
TPUs. It generates the loss and metric on the validation set, and exports a hub
module under the directory `gs://<workdir>/export/hub/<timestamp>/module`:

```
config/supervised/imagenet.sh --workdir gs://<workdir> --dataset_dir gs://<dataset dir> --run_eval
```

Note that `<tpu name>` is set by the user when creating the Cloud TPU
node. Moreover, the ImageNet data and the working directory should be placed in
a Google Cloud bucket storage.

#### Step 4:

Evaluate the self-supervised models with logistic regression.
120 | the exported hub module from step 3 above as an additional argument:
121 | 
122 | ```
123 | gsutil mb gs://<bucket>
124 | export TPU_NAME=<tpu-name>
125 | config/evaluation/rotation_or_exemplar.sh --workdir gs://<eval_workdir> --dataset_dir gs://<dataset_dir> --hub_module gs://<workdir>/export/hub/<timestamp>/module
126 | 
127 | config/evaluation/rotation_or_exemplar.sh --workdir gs://<eval_workdir> --dataset_dir gs://<dataset_dir> --hub_module gs://<workdir>/export/hub/<timestamp>/module --run_eval
128 | ```
129 | 
130 | You can start TensorBoard to visualize the training/evaluation progress:
131 | 
132 | ```
133 | tensorboard --port 2222 --logdir gs://<workdir>
134 | ```
135 | 
136 | ## Pretrained models
137 | 
138 | If you want to download and try our best self-supervised models, please see this [IPython
139 | notebook](https://colab.research.google.com/drive/1HdApkScZpulQrACrPKZiKYHhy7MeR3iN).
140 | 
141 | 
142 | ## Authors
143 | 
144 | - [Alexander Kolesnikov](https://github.com/kolesman)
145 | - [Xiaohua Zhai](https://sites.google.com/site/xzhai89/)
146 | - [Lucas Beyer](http://lucasb.eyer.be/)
147 | - [Marvin Ritter](https://github.com/Marvin182)
148 | 
149 | ### This is not an official Google product
--------------------------------------------------------------------------------
/self_supervision/linear_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | """Implements the linear evaluation model.
18 | """
19 | 
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 | 
24 | import collections
25 | import functools
26 | import os
27 | 
28 | import numpy as np
29 | import tensorflow as tf
30 | import tensorflow_hub as hub
31 | 
32 | import datasets
33 | from self_supervision.patch_utils import get_patch_representation
34 | from trainer import make_estimator
35 | import utils
36 | 
37 | FLAGS = tf.flags.FLAGS
38 | 
39 | 
40 | def apply_fractional_pooling(taps, target_features=9000, mode='adaptive_max'):
41 |   """Applies fractional pooling to each of `taps`.
42 | 
43 |   Args:
44 |     taps: A dict of names:tensors to be pooled.
45 |     target_features: If the input tensor has more than this number of features,
46 |       perform fractional pooling to reduce it to this amount.
47 |     mode: one of {'adaptive_max', 'adaptive_avg', 'max', 'avg'}
48 | 
49 |   Returns:
50 |     tensors: An ordered dict of tensors, each with at most target_features
51 |       features.
52 | 
53 |   Raises:
54 |     ValueError: mode is unexpected.
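
  For intuition, an illustrative example (shapes hypothetical): a
  [B, 7, 7, 2048] tap has 7 * 7 * 2048 = 100352 features, which exceeds
  target_features=9000, so it is adaptively pooled down to [B, 2, 2, 2048],
  i.e. 8192 features, while a [B, 2048] tap is small enough and passes
  through unchanged.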
54 |   """
55 |   out_tensors = collections.OrderedDict()
56 |   for k, t in sorted(taps.items()):
57 |     if len(t.get_shape().as_list()) == 2:
58 |       t = t[:, None, None, :]
59 |     _, h, w, num_channels = t.get_shape().as_list()
60 |     if h * w * num_channels > target_features:
61 |       t = utils.adaptive_pool(t, target_features, mode)
62 |       _, h, w, num_channels = t.get_shape().as_list()
63 |     out_tensors[k] = t
64 | 
65 |   return out_tensors
66 | 
67 | 
68 | def add_linear_heads(rep_tensors, n_out):
69 |   """Adds a linear head to each of rep_tensors.
70 | 
71 |   Args:
72 |     rep_tensors: A dict of names:tensors to which to attach the head.
73 |     n_out: The number of features the head should map to.
74 | 
75 |   Returns:
76 |     tensors: An ordered dict like `rep_tensors` but with the head's output as value.
77 |   """
78 |   for k in list(rep_tensors.keys()):
79 |     t = rep_tensors[k]
80 |     t = tf.reshape(
81 |         t, tf.stack([-1, 1, 1, tf.reduce_prod(t.shape[1:])]))
82 |     t = tf.layers.conv2d(
83 |         t,
84 |         filters=n_out,
85 |         kernel_size=1,
86 |         padding='valid',
87 |         activation=None,
88 |         kernel_regularizer=tf.contrib.layers.l2_regularizer(
89 |             scale=FLAGS.get_flag_value('weight_decay', 0.)))
90 |     rep_tensors[k] = tf.squeeze(t, [1, 2])
91 | 
92 |   return rep_tensors
93 | 
94 | 
95 | def add_mlp_heads(rep_tensors, n_out, is_training):
96 |   """Adds an MLP head to each of rep_tensors.
97 | 
98 |   Args:
99 |     rep_tensors: A dict of names:tensors to which to attach the head.
100 |     n_out: The number of features the head should map to.
101 |     is_training: whether in training mode.
102 | 
103 |   Returns:
104 |     tensors: An ordered dict like `rep_tensors` but with the head's output as value.
105 |   """
106 |   kernel_regularizer = tf.contrib.layers.l2_regularizer(
107 |       scale=FLAGS.get_flag_value('weight_decay', 0.))
108 |   channels_hidden = FLAGS.get_flag_value('channels_hidden', 1000)
109 |   for k, t in rep_tensors.items():
110 |     t = tf.reshape(t, [-1, 1, 1, np.prod(t.shape[1:])])
111 |     t = tf.layers.conv2d(
112 |         t, channels_hidden, kernel_size=1, padding='VALID',
113 |         activation=tf.nn.relu, kernel_regularizer=kernel_regularizer)
114 |     t = tf.layers.dropout(t, rate=0.5, training=is_training)
115 |     t = tf.layers.conv2d(
116 |         t, n_out, kernel_size=1, padding='VALID',
117 |         kernel_regularizer=kernel_regularizer)
118 |     rep_tensors[k] = tf.squeeze(t, [1, 2])
119 | 
120 |   return rep_tensors
121 | 
122 | 
123 | def model_fn(data, mode):
124 |   """Produces a loss for the representation evaluation task.
125 | 
126 |   Args:
127 |     data: Dict of inputs containing, among others, "image" and "label."
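      The "image" entry is either rank-4 (NHWC) for plain images or rank-5
      (NPHWC) for patch-based inputs; see the input_rank check below.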
128 |     mode: model's mode: training, eval or prediction.
129 | 
130 |   Returns:
131 |     EstimatorSpec
132 | 
133 |   Raises:
134 |     ValueError: Unexpected FLAGS.eval_model
135 |   """
136 |   images = data['image']
137 |   tf.logging.info('model_fn(): features=%s, mode=%s', images, mode)
138 | 
139 |   input_rank = len(images.get_shape().as_list())
140 |   image_input_rank = 4  # NHWC
141 |   patch_input_rank = 5  # NPHWC
142 |   assert input_rank in [image_input_rank, patch_input_rank], (
143 |       'Unsupported input rank: %d' % input_rank)
144 | 
145 |   module = hub.Module(os.path.expanduser(str(FLAGS.hub_module)))
146 | 
147 |   if mode == tf.estimator.ModeKeys.PREDICT:
148 |     return make_estimator(
149 |         mode,
150 |         predictions=module(
151 |             images,
152 |             signature=FLAGS.get_flag_value('signature', 'representation'),
153 |             as_dict=True))
154 | 
155 |   is_training = (mode == tf.estimator.ModeKeys.TRAIN)
156 |   pooling_fn = functools.partial(
157 |       apply_fractional_pooling,
158 |       mode=FLAGS.get_flag_value('pool_mode', 'adaptive_max'))
159 |   target_features = 9000
160 |   if input_rank == patch_input_rank:
161 |     out_tensors = get_patch_representation(
162 |         images,
163 |         module,
164 |         patch_preprocess=None,
165 |         is_training=is_training,
166 |         target_features=target_features,
167 |         combine_patches=FLAGS.get_flag_value('combine_patches', 'avg_pool'),
168 |         signature='representation',
169 |         pooling_fn=pooling_fn)
170 |   else:
171 |     out_tensors = module(
172 |         images,
173 |         signature=FLAGS.get_flag_value('signature', 'representation'),
174 |         as_dict=True)
175 |     out_tensors = pooling_fn(out_tensors, target_features=target_features)
176 | 
177 |   eval_model = FLAGS.get_flag_value('eval_model', 'linear')
178 |   if eval_model == 'linear':
179 |     out_logits = add_linear_heads(out_tensors, datasets.get_num_classes())
180 |   elif eval_model == 'mlp':
181 |     out_logits = add_mlp_heads(out_tensors, datasets.get_num_classes(),
182 |                                is_training=is_training)
183 |   else:
184 |     raise ValueError('Unsupported eval_model %s.' % eval_model)
185 | 
186 |   # Build loss and accuracy.
187 |   labels = data['label']
188 |   losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
189 |                                                            logits=logits)
190 |             for logits in out_logits.values()]
191 |   loss = tf.add_n([tf.reduce_mean(loss) for loss in losses])
192 | 
193 |   metrics_fn = utils.get_classification_metrics(
194 |       tensor_names=out_logits.keys())
195 |   # A tuple of metric_fn and a list of tensors to be evaluated by TPUEstimator.
196 |   eval_metrics_tuple = (metrics_fn, [labels] + list(out_logits.values()))
197 | 
198 |   return make_estimator(mode, loss, eval_metrics_tuple)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | """Util functions for representation learning.
18 | """
19 | 
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 | 
24 | import collections
25 | import csv
26 | import re
27 | 
28 | import numpy as np
29 | import tensorflow as tf
30 | 
31 | 
32 | INPUT_DATA_STR = "input_data"
33 | IS_TRAINING_STR = "is_training"
34 | REPR_PREFIX_STR = "representation_"
35 | TAGS_IS_TRAINING = ["is_training"]
36 | 
37 | 
38 | def adaptive_pool(inp, num_target_dimensions=9000, mode="adaptive_max"):
39 |   """Adaptive pooling layer.
40 | 
41 |   This layer performs adaptive pooling, such that the total
42 |   dimensionality of the output is not bigger than num_target_dimensions.
43 | 
44 |   Args:
45 |     inp: input tensor
46 |     num_target_dimensions: maximum number of output dimensions
47 |     mode: one of {"adaptive_max", "adaptive_avg", "max", "avg"}
48 | 
49 |   Returns:
50 |     Result of the pooling operation
51 | 
52 |   Raises:
53 |     ValueError: mode is unexpected.
54 |   """
55 | 
56 |   size, _, k = inp.get_shape().as_list()[1:]
57 |   if mode in ["adaptive_max", "adaptive_avg"]:
58 |     if mode == "adaptive_max":
59 |       pool_fn = tf.nn.fractional_max_pool
60 |     else:
61 |       pool_fn = tf.nn.fractional_avg_pool
62 | 
63 |     # Find the optimal target output tensor size.
64 |     target_size = (num_target_dimensions / float(k)) ** 0.5
65 |     if (abs(num_target_dimensions - k * np.floor(target_size) ** 2) <
66 |         abs(num_target_dimensions - k * np.ceil(target_size) ** 2)):
67 |       target_size = max(np.floor(target_size), 1.0)
68 |     else:
69 |       target_size = max(np.ceil(target_size), 1.0)
70 | 
71 |     # Get optimal stride. Subtract epsilon to ensure correct rounding in
72 |     # pool_fn.
73 |     stride = size / target_size - 1.0e-5
74 | 
75 |     # Make sure that the stride is valid.
76 |     stride = max(stride, 1)
77 |     stride = min(stride, size)
78 | 
79 |     result = pool_fn(inp, [1, stride, stride, 1])[0]
80 |   elif mode in ["max", "avg"]:
81 |     if mode == "max":
82 |       pool_fn = tf.contrib.layers.max_pool2d
83 |     else:
84 |       pool_fn = tf.contrib.layers.avg_pool2d
85 |     total_size = float(np.prod(inp.get_shape()[1:].as_list()))
86 |     stride = int(np.ceil(np.sqrt(total_size / num_target_dimensions)))
87 |     stride = min(max(1, stride), size)
88 | 
89 |     result = pool_fn(inp, kernel_size=stride, stride=stride)
90 |   else:
91 |     raise ValueError("Not supported %s pool." % mode)
92 | 
93 |   return result
94 | 
95 | 
96 | def append_multiple_rows_to_csv(dictionaries, csv_path):
97 |   """Writes multiple rows to a csv file from a list of dictionaries.
98 | 
99 |   Args:
100 |     dictionaries: a list of dictionaries, mapping from csv header to value.
101 |     csv_path: path to the result csv file.
102 |   """
103 | 
104 |   keys = set([])
105 |   for d in dictionaries:
106 |     keys.update(d.keys())
107 | 
108 |   if not tf.gfile.Exists(csv_path):
109 |     with tf.gfile.Open(csv_path, "w") as f:
110 |       writer = csv.DictWriter(f, sorted(keys))
111 |       writer.writeheader()
112 |       f.flush()
113 | 
114 |   with tf.gfile.Open(csv_path, "a") as f:
115 |     writer = csv.DictWriter(f, sorted(keys))
116 |     writer.writerows(dictionaries)
117 |     f.flush()
118 | 
119 | 
120 | def concat_dicts(dict_list):
121 |   """Given a list of dicts, merges them into a single dict.
122 | 
123 |   This function takes a list of dictionaries as input and merges all
124 |   these dictionaries into a single dictionary by concatenating the values
125 |   (along the first axis) that correspond to the same key.
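
  For example (illustrative values): [{"a": t1}, {"a": t2}], where t1 has
  shape [2, 5] and t2 has shape [3, 5], becomes
  {"a": tf.concat([t1, t2], axis=0)} with shape [5, 5].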
126 | 
127 |   Args:
128 |     dict_list: list of dictionaries
129 | 
130 |   Returns:
131 |     d: merged dictionary
132 |   """
133 |   d = collections.defaultdict(list)
134 |   for e in dict_list:
135 |     for k, v in e.items():
136 |       d[k].append(v)
137 |   for k in d:
138 |     d[k] = tf.concat(d[k], axis=0)
139 |   return d
140 | 
141 | 
142 | def str2intlist(s, repeats_if_single=None):
143 |   """Parses a config's "1,2,3"-style string into a list of ints.
144 | 
145 |   Args:
146 |     s: The string to be parsed, or possibly already an int.
147 |     repeats_if_single: If s is already an int or is a single element list,
148 |       repeat it this many times to create the list.
149 | 
150 |   Returns:
151 |     A list of integers based on `s`.
152 |   """
153 |   if isinstance(s, int):
154 |     result = [s]
155 |   else:
156 |     result = [int(i.strip()) if i != "None" else None
157 |               for i in s.split(",")]
158 |   if repeats_if_single is not None and len(result) == 1:
159 |     result *= repeats_if_single
160 |   return result
161 | 
162 | 
163 | def tf_apply_to_image_or_images(fn, image_or_images):
164 |   """Applies a function to a single image or each image in a batch of them.
165 | 
166 |   Args:
167 |     fn: the function to apply, receives an image, returns an image.
168 |     image_or_images: Either a single image, or a batch of images.
169 | 
170 |   Returns:
171 |     The result of applying the function to the image or batch of images.
172 | 
173 |   Raises:
174 |     ValueError: if the input is of rank smaller than 3.
175 |   """
176 |   static_rank = len(image_or_images.get_shape().as_list())
177 |   if static_rank == 3:  # A single image: HWC
178 |     return fn(image_or_images)
179 |   elif static_rank == 4:  # A batch of images: BHWC
180 |     return tf.map_fn(fn, image_or_images)
181 |   elif static_rank > 4:  # A batch of images: ...HWC
182 |     input_shape = tf.shape(image_or_images)
183 |     h, w, c = image_or_images.get_shape().as_list()[-3:]
184 |     image_or_images = tf.reshape(image_or_images, [-1, h, w, c])
185 |     image_or_images = tf.map_fn(fn, image_or_images)
186 |     return tf.reshape(image_or_images, input_shape)
187 |   else:
188 |     raise ValueError("Unsupported image rank: %d" % static_rank)
189 | 
190 | 
191 | def tf_apply_with_probability(p, fn, x):
192 |   """Applies function `fn` to input `x` with probability `p`."""
193 |   return tf.cond(
194 |       tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), p),
195 |       lambda: fn(x),
196 |       lambda: x)
197 | 
198 | 
199 | def expand_glob(glob_patterns):
200 |   checkpoints = []
201 |   for pattern in glob_patterns:
202 |     checkpoints.extend(tf.gfile.Glob(pattern))
203 |   assert checkpoints, "There are no checkpoints in " + str(glob_patterns)
204 |   return checkpoints
205 | 
206 | 
207 | def get_latest_hub_per_task(hub_module_paths):
208 |   """Gets the latest hub module for each task.
209 | 
210 |   The hub module path should match the format ".*/hub/[0-9]*/module/.*".
211 |   Example usage:
212 |   get_latest_hub_per_task(expand_glob(["/cns/el-d/home/dune/representation/"
213 |                                        "xzhai/1899361/*/export/hub/*/module/"]))
214 |   returns the 4 latest hub modules from 4 tasks, respectively.
215 | 
216 |   Args:
217 |     hub_module_paths: a list of hub module paths.
218 | 
219 |   Returns:
220 |     A list of latest hub modules for each task.
221 | 222 | """ 223 | task_to_path = {} 224 | for path in hub_module_paths: 225 | task_name, module_name = path.split("/hub/") 226 | timestamp = int(re.findall(r"([0-9]*)/module", module_name)[0]) 227 | current_path = task_to_path.get(task_name, "0/module") 228 | current_timestamp = int(re.findall(r"([0-9]*)/module", current_path)[0]) 229 | if current_timestamp < timestamp: 230 | task_to_path[task_name] = path 231 | return sorted(task_to_path.values()) 232 | 233 | 234 | def get_classification_metrics(tensor_names): 235 | """Gets classification eval metric on input logits and labels. 236 | 237 | Args: 238 | tensor_names: a list of tensor names for _metrics input tensors. 239 | 240 | Returns: 241 | A function computes the metric result, from input logits and labels. 242 | """ 243 | 244 | def _top_k_accuracy(k, labels, logits): 245 | in_top_k = tf.nn.in_top_k(predictions=logits, targets=labels, k=k) 246 | return tf.metrics.mean(tf.cast(in_top_k, tf.float32)) 247 | 248 | def _metrics(labels, *tensors): 249 | """Computes the metric from logits and labels. 250 | 251 | Args: 252 | labels: ground truth labels. 253 | *tensors: tensors to be evaluated. 254 | 255 | Returns: 256 | Result dict mapping from the metric name to the list of result tensor and 257 | update_op used by tf.metrics. 258 | """ 259 | metrics = {} 260 | assert len(tensor_names) == len(tensors), "Names must match tensors." 261 | for i in range(len(tensors)): 262 | tensor = tensors[i] 263 | name = tensor_names[i] 264 | for k in (1, 5): 265 | metrics["top%d_accuracy_%s" % (k, name)] = _top_k_accuracy( 266 | k, labels, tensor) 267 | 268 | return metrics 269 | 270 | return _metrics 271 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # pylint: disable=missing-docstring 18 | """Preprocessing methods. 
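
Each get_*_preprocess function returns a function that maps a `data` dict
(with at least an "image" entry) to a transformed `data` dict. Preprocessing
steps are specified as a comma-separated string of names and applied in
sequence, for example "resize,-1_to_1" (an illustrative combination); see
get_preprocess_fn below.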
19 | """
20 | 
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 | 
25 | import functools
26 | import inception_preprocessing
27 | import tensorflow as tf
28 | 
29 | import self_supervision.patch_model_preprocess as pp_lib
30 | import utils
31 | 
32 | 
33 | FLAGS = tf.flags.FLAGS
34 | 
35 | 
36 | def get_inception_preprocess(is_training, im_size):
37 |   def _inception_preprocess(data):
38 |     data["image"] = inception_preprocessing.preprocess_image(
39 |         data["image"], im_size[0], im_size[1], is_training,
40 |         add_image_summaries=False)
41 |     return data
42 |   return _inception_preprocess
43 | 
44 | 
45 | def get_resize_small(smaller_size):
46 |   """Resizes the smaller side to `smaller_size`, keeping the aspect ratio."""
47 |   def _resize_small_pp(data):
48 |     image = data["image"]
49 |     # A single image: HWC
50 |     # A batch of images: BHWC
51 |     h, w = tf.shape(image)[-3], tf.shape(image)[-2]
52 | 
53 |     # Figure out the necessary h/w.
54 |     ratio = tf.to_float(smaller_size) / tf.to_float(tf.minimum(h, w))
55 |     h = tf.to_int32(tf.round(tf.to_float(h) * ratio))
56 |     w = tf.to_int32(tf.round(tf.to_float(w) * ratio))
57 | 
58 |     # NOTE: use align_corners=False for AREA resize, but True for Bilinear.
59 |     # See also https://github.com/tensorflow/tensorflow/issues/6720
60 |     static_rank = len(image.get_shape().as_list())
61 |     if static_rank == 3:  # A single image: HWC
62 |       data["image"] = tf.image.resize_area(image[None], [h, w])[0]
63 |     elif static_rank == 4:  # A batch of images: BHWC
64 |       data["image"] = tf.image.resize_area(image, [h, w])
65 |     return data
66 |   return _resize_small_pp
67 | 
68 | 
69 | def get_crop(is_training, crop_size):
70 |   """Returns a random (or central at test-time) crop of `crop_size`."""
71 |   def _crop_pp(data):
72 |     crop_fn = functools.partial(
73 |         pp_lib.crop, is_training=is_training, crop_size=crop_size)
74 |     data["image"] = utils.tf_apply_to_image_or_images(crop_fn, data["image"])
75 | 
76 |     return data
77 |   return _crop_pp
78 | 
79 | 
80 | def get_inception_crop(is_training, **kw):
81 |   # kw of interest are: aspect_ratio_range, area_range.
82 |   # Note that the image is not resized yet here.
83 |   def _inception_crop_pp(data):
84 |     if is_training:
85 |       image = data["image"]
86 |       begin, size, _ = tf.image.sample_distorted_bounding_box(
87 |           tf.shape(image), tf.zeros([0, 0, 4], tf.float32),
88 |           use_image_if_no_bounding_boxes=True, **kw)
89 |       data["image"] = tf.slice(image, begin, size)
90 |       # Unfortunately, the above operation loses the depth dimension, so we
91 |       # need to restore it manually.
92 |       data["image"].set_shape([None, None, image.shape[-1]])
93 |     return data
94 |   return _inception_crop_pp
95 | 
96 | 
97 | def get_random_flip_lr(is_training):
98 |   def _random_flip_lr_pp(data):
99 |     if is_training:
100 |       data["image"] = utils.tf_apply_to_image_or_images(
101 |           tf.image.random_flip_left_right, data["image"])
102 |     return data
103 |   return _random_flip_lr_pp
104 | 
105 | 
106 | def get_resize_preprocess(im_size, randomize_resize_method=False):
107 | 
108 |   def _resize(image, method, align_corners):
109 | 
110 |     def _process():
111 |       # The resized_images are of type float32 and might fall outside of range
112 |       # [0, 255].
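      # Callers that need a fixed range can follow this with a value-range
      # step such as "0_to_1" or "-1_to_1" (see get_value_range_preprocess).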
113 | resized = tf.cast( 114 | tf.image.resize_images( 115 | image, im_size, method, align_corners=align_corners), 116 | dtype=tf.float32) 117 | return resized 118 | 119 | return _process 120 | 121 | def _resize_pp(data): 122 | im = data["image"] 123 | 124 | if randomize_resize_method: 125 | # pick random resizing method 126 | r = tf.random_uniform([], 0, 3, dtype=tf.int32) 127 | im = tf.case({ 128 | tf.equal(r, tf.cast(0, r.dtype)): 129 | _resize(im, tf.image.ResizeMethod.BILINEAR, True), 130 | tf.equal(r, tf.cast(1, r.dtype)): 131 | _resize(im, tf.image.ResizeMethod.NEAREST_NEIGHBOR, True), 132 | tf.equal(r, tf.cast(2, r.dtype)): 133 | _resize(im, tf.image.ResizeMethod.BICUBIC, True), 134 | # NOTE: use align_corners=False for AREA resize, but True for the 135 | # others. See https://github.com/tensorflow/tensorflow/issues/6720 136 | tf.equal(r, tf.cast(3, r.dtype)): 137 | _resize(im, tf.image.ResizeMethod.AREA, False), 138 | }) 139 | else: 140 | im = tf.image.resize_images(im, im_size) 141 | data["image"] = im 142 | return data 143 | 144 | return _resize_pp 145 | 146 | 147 | def get_rotate_preprocess(): 148 | """Returns a function that does 90deg rotations and sets according labels.""" 149 | 150 | def _rotate_pp(data): 151 | data["label"] = tf.constant([0, 1, 2, 3]) 152 | # We use our own instead of tf.image.rot90 because that one broke 153 | # internally shortly before deadline... 154 | data["image"] = tf.stack([ 155 | data["image"], 156 | tf.transpose(tf.reverse_v2(data["image"], [1]), [1, 0, 2]), 157 | tf.reverse_v2(data["image"], [0, 1]), 158 | tf.reverse_v2(tf.transpose(data["image"], [1, 0, 2]), [1]), 159 | ]) 160 | return data 161 | 162 | return _rotate_pp 163 | 164 | 165 | def get_value_range_preprocess(vmin=-1, vmax=1, dtype=tf.float32): 166 | """Returns a function that sends [0,255] image to [vmin,vmax].""" 167 | 168 | def _value_range_pp(data): 169 | img = tf.cast(data["image"], dtype) 170 | img = vmin + (img / tf.constant(255.0, dtype)) * (vmax - vmin) 171 | data["image"] = img 172 | return data 173 | return _value_range_pp 174 | 175 | 176 | def get_standardization_preprocess(): 177 | def _standardization_pp(data): 178 | # Trick: normalize each patch to avoid low level statistics. 179 | data["image"] = utils.tf_apply_to_image_or_images( 180 | tf.image.per_image_standardization, data["image"]) 181 | return data 182 | return _standardization_pp 183 | 184 | 185 | def get_inception_preprocess_patches(is_training, resize_size, num_of_patches): 186 | 187 | def _inception_preprocess_patches(data): 188 | patches = [] 189 | for _ in range(num_of_patches): 190 | patches.append( 191 | inception_preprocessing.preprocess_image( 192 | data["image"], 193 | resize_size[0], 194 | resize_size[1], 195 | is_training, 196 | add_image_summaries=False)) 197 | patches = tf.stack(patches) 198 | data["image"] = patches 199 | return data 200 | 201 | return _inception_preprocess_patches 202 | 203 | 204 | def get_to_gray_preprocess(grayscale_probability): 205 | 206 | def _to_gray(image): 207 | # Transform to grayscale by taking the mean of RGB. 
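    # Tiling the mean back to three identical channels keeps the HWC shape
    # with C=3, so downstream models see the same shape as for RGB inputs.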
208 | return tf.tile(tf.reduce_mean(image, axis=2, keepdims=True), [1, 1, 3]) 209 | 210 | def _to_gray_pp(data): 211 | data["image"] = utils.tf_apply_to_image_or_images( 212 | lambda img: utils.tf_apply_with_probability( # pylint:disable=g-long-lambda 213 | grayscale_probability, _to_gray, img), 214 | data["image"]) 215 | return data 216 | 217 | return _to_gray_pp 218 | 219 | 220 | def get_preprocess_fn(fn_names, is_training): 221 | """Returns preprocessing function. 222 | 223 | Args: 224 | fn_names: name of a preprocessing function. 225 | is_training: Whether this should be run in train or eval mode. 226 | Returns: 227 | preprocessing function 228 | 229 | Raises: 230 | ValueError: if preprocessing function name is unknown 231 | """ 232 | 233 | def _fn(data): 234 | def expand(fn_name): 235 | if fn_name == "plain_preprocess": 236 | yield lambda x: x 237 | elif fn_name == "0_to_1": 238 | yield get_value_range_preprocess(0, 1) 239 | elif fn_name == "-1_to_1": 240 | yield get_value_range_preprocess(-1, 1) 241 | elif fn_name == "resize": 242 | yield get_resize_preprocess( 243 | utils.str2intlist(FLAGS.resize_size, 2), 244 | is_training and FLAGS.get_flag_value("randomize_resize_method", 245 | False)) 246 | elif fn_name == "resize_small": 247 | yield get_resize_small(FLAGS.smaller_size) 248 | elif fn_name == "crop": 249 | yield get_crop(is_training, 250 | utils.str2intlist(FLAGS.crop_size, 2)) 251 | elif fn_name == "central_crop": 252 | yield get_crop(False, utils.str2intlist(FLAGS.crop_size, 2)) 253 | elif fn_name == "inception_crop": 254 | yield get_inception_crop(is_training) 255 | elif fn_name == "flip_lr": 256 | yield get_random_flip_lr(is_training) 257 | elif fn_name == "crop_inception_preprocess_patches": 258 | yield get_inception_preprocess_patches( 259 | is_training, utils.str2intlist(FLAGS.resize_size, 2), 260 | FLAGS.num_of_inception_patches) 261 | elif fn_name == "to_gray": 262 | yield get_to_gray_preprocess( 263 | FLAGS.get_flag_value("grayscale_probability", 1.0)) 264 | elif fn_name == "crop_patches": 265 | yield pp_lib.get_crop_patches_fn( 266 | is_training, 267 | split_per_side=FLAGS.splits_per_side, 268 | patch_jitter=FLAGS.get_flag_value("patch_jitter", 0)) 269 | elif fn_name == "standardization": 270 | yield get_standardization_preprocess() 271 | elif fn_name == "rotate": 272 | yield get_rotate_preprocess() 273 | 274 | # Below this line specific combos decomposed. 275 | # It would be nice to move them to the configs at some point. 276 | 277 | elif fn_name == "inception_preprocess": 278 | yield get_inception_preprocess( 279 | is_training, utils.str2intlist(FLAGS.resize_size, 2)) 280 | else: 281 | raise ValueError("Not supported preprocessing %s" % fn_name) 282 | 283 | # Apply all the individual steps in sequence. 284 | tf.logging.info("Data before pre-processing:\n%s", data) 285 | for fn_name in fn_names.split(","): 286 | for p in expand(fn_name.strip()): 287 | data = p(data) 288 | tf.logging.info("Data after `%s`:\n%s", p, data) 289 | return data 290 | 291 | return _fn 292 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | """Datasets class to provide images and labels in tf batch.
18 | """
19 | 
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 | 
24 | import abc
25 | import os
26 | 
27 | import tensorflow as tf
28 | 
29 | from preprocess import get_preprocess_fn
30 | 
31 | FLAGS = tf.flags.FLAGS
32 | 
33 | 
34 | class AbstractDataset(object):
35 |   """Base class for datasets using the simplified input pipeline."""
36 | 
37 |   def __init__(self,
38 |                filenames,
39 |                reader,
40 |                num_epochs,
41 |                shuffle,
42 |                shuffle_buffer_size=10000,
43 |                random_seed=None,
44 |                num_reader_threads=64,
45 |                drop_remainder=True):
46 |     """Creates a new dataset. Sub-classes have to implement _parse_fn().
47 | 
48 |     Args:
49 |       filenames: A list of filenames.
50 |       reader: A dataset reader, e.g. `tf.data.TFRecordDataset`,
51 |         `tf.data.TextLineDataset` or `tf.data.FixedLengthRecordDataset`.
52 |       num_epochs: An int or `None`. Number of epochs to cycle
53 |         through the dataset before stopping. If set to `None` this will read
54 |         samples indefinitely.
55 |       shuffle: A boolean. Whether output data are
56 |         shuffled.
57 |       shuffle_buffer_size: `int`, number of examples in the buffer for
58 |         shuffling.
59 |       random_seed: Optional int. Random seed for the shuffle operation.
60 |       num_reader_threads: An int, defaults to 64. Number of threads reading
61 |         from files. When `shuffle` is False, the number of threads is set
62 |         to 1.
63 |       drop_remainder: If true, then the last incomplete batch is dropped.
64 |     """
65 |     self.filenames = filenames
66 |     self.reader = reader
67 |     self.num_reader_threads = num_reader_threads
68 |     self.num_epochs = num_epochs
69 |     self.shuffle = shuffle
70 |     self.shuffle_buffer_size = shuffle_buffer_size
71 |     self.random_seed = random_seed
72 |     self.drop_remainder = drop_remainder
73 | 
74 |     # Additional options for optimizing TPU input pipelines.
75 |     self.num_parallel_batches = 8
76 | 
77 |   def _make_source_dataset(self):
78 |     """Reads the files in self.filenames and returns a `tf.data.Dataset`.
79 | 
80 |     This does not parse the examples!
81 | 
82 |     Returns:
83 |       `tf.data.Dataset` repeated for self.num_epochs and shuffled if
84 |       self.shuffle is `True`. Files are always read in parallel and sloppy.
85 |     """
86 |     # Shuffle the filenames to ensure better randomization.
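    # This file-level shuffle is complemented by the record-level shuffle
    # with self.shuffle_buffer_size further below.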
87 |     dataset = tf.data.Dataset.list_files(self.filenames, shuffle=self.shuffle,
88 |                                           seed=self.random_seed)
89 | 
90 |     dataset = dataset.repeat(self.num_epochs)
91 | 
92 |     def fetch_dataset(filename):
93 |       buffer_size = 8 * 1024 * 1024  # 8 MiB per file
94 |       dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size)
95 |       return dataset
96 | 
97 |     # Read the data from disk in parallel.
98 |     dataset = dataset.apply(
99 |         tf.data.experimental.parallel_interleave(
100 |             fetch_dataset,
101 |             cycle_length=self.num_reader_threads,
102 |             sloppy=self.shuffle and self.random_seed is None))
103 | 
104 |     if self.shuffle:
105 |       dataset = dataset.shuffle(self.shuffle_buffer_size, seed=self.random_seed)
106 |     return dataset
107 | 
108 |   @abc.abstractmethod
109 |   def _parse_fn(self, value):
110 |     """Parses an image and its label from a serialized TFExample.
111 | 
112 |     Args:
113 |       value: serialized string containing a TFExample.
114 | 
115 |     Returns:
116 |       Returns a tuple of (image, label) from the TFExample.
117 |     """
118 |     raise NotImplementedError
119 | 
120 |   def input_fn(self, params):
121 |     """Input function which provides a single batch for train or eval.
122 | 
123 |     Args:
124 |       params: `dict` of parameters passed from the `TPUEstimator`.
125 |         `params['batch_size']` is provided and should be used as the effective
126 |         batch size.
127 | 
128 |     Returns:
129 |       A `tf.data.Dataset` object.
130 |     """
131 |     # Retrieves the batch size for the current shard. The # of shards is
132 |     # computed according to the input pipeline deployment. See
133 |     # tf.contrib.tpu.RunConfig for details.
134 |     batch_size = params['batch_size']
135 | 
136 |     dataset = self._make_source_dataset()
137 | 
138 |     # Use the fused map-and-batch operation.
139 |     #
140 |     # For XLA, we must use fixed shapes. Because we repeat the source training
141 |     # dataset indefinitely, we can use `drop_remainder=True` to get fixed-size
142 |     # batches without dropping any training examples.
143 |     #
144 |     # When evaluating, `drop_remainder=True` prevents accidentally evaluating
145 |     # the same image twice by dropping the final batch if it is less than a full
146 |     # batch size. As long as this validation is done with a consistent batch size,
147 |     # exactly the same images will be used.
148 |     dataset = dataset.apply(
149 |         tf.data.experimental.map_and_batch(
150 |             self._parse_fn,
151 |             batch_size=batch_size,
152 |             num_parallel_batches=self.num_parallel_batches,
153 |             drop_remainder=self.drop_remainder))
154 | 
155 |     # Prefetch overlaps in-feed with training.
156 |     dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
157 |     return dataset
158 | 
159 | 
160 | def generate_sharded_filenames(filename):
161 |   base, count = filename.split('@')
162 |   count = int(count)
163 |   return ['{}-{:05d}-of-{:05d}'.format(base, i, count)
164 |           for i in range(count)]
165 | 
166 | 
167 | class DatasetImagenet(AbstractDataset):
168 |   """Provides train/val/trainval/test splits for Imagenet data.
169 | 
170 |   -> trainval split represents the official Imagenet train split.
171 |   -> train split is derived by taking the first 984 of 1024 shards of
172 |      the official training data.
173 |   -> val split is derived by taking the last 40 shards of the official
174 |      training data.
175 |   -> test split represents the official Imagenet test split.
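
  Note that the test split is built from the official validation shards, while
  the val split is carved out of the official training shards; see the
  filenames dict in __init__.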
176 |   """
177 | 
178 |   COUNTS = {'train': 1231121,
179 |             'val': 50046,
180 |             'trainval': 1281167,
181 |             'test': 50000}
182 | 
183 |   NUM_CLASSES = 1000
184 |   IMAGE_KEY = 'image/encoded'
185 |   LABEL_KEY = 'image/class/label'
186 | 
187 |   FEATURE_MAP = {
188 |       IMAGE_KEY: tf.FixedLenFeature(shape=[], dtype=tf.string),
189 |       LABEL_KEY: tf.FixedLenFeature(shape=[], dtype=tf.int64)
190 |   }
191 | 
192 |   LABEL_OFFSET = 1
193 | 
194 |   def __init__(self,
195 |                split_name,
196 |                preprocess_fn,
197 |                num_epochs,
198 |                shuffle,
199 |                random_seed=None,
200 |                drop_remainder=True):
201 |     """Initializes the dataset object.
202 | 
203 |     Args:
204 |       split_name: A string split name, to load from the dataset.
205 |       preprocess_fn: Preprocess a single example. The example is already
206 |         parsed into a dictionary.
207 |       num_epochs: An int or `None`. Number of epochs to cycle
208 |         through the dataset before stopping. If set to `None` this will read
209 |         samples indefinitely.
210 |       shuffle: A boolean. Whether output data are
211 |         shuffled.
212 |       random_seed: Optional int. Random seed for the shuffle operation.
213 |       drop_remainder: If true, then the last incomplete batch is dropped.
214 |     """
215 |     # This is an instance-variable instead of a class-variable because it
216 |     # depends on FLAGS, which is not parsed yet at class-parse-time.
217 |     files = os.path.join(os.path.expanduser(FLAGS.dataset_dir),
218 |                          '%s@%i')
219 |     filenames = {
220 |         'train': generate_sharded_filenames(files % ('train', 1024))[:-40],
221 |         'val': generate_sharded_filenames(files % ('train', 1024))[-40:],
222 |         'trainval': generate_sharded_filenames(files % ('train', 1024)),
223 |         'test': generate_sharded_filenames(files % ('validation', 128))
224 |     }
225 | 
226 |     super(DatasetImagenet, self).__init__(
227 |         filenames=filenames[split_name],
228 |         reader=tf.data.TFRecordDataset,
229 |         num_epochs=num_epochs,
230 |         shuffle=shuffle,
231 |         random_seed=random_seed,
232 |         drop_remainder=drop_remainder)
233 |     self.split_name = split_name
234 |     self.preprocess_fn = preprocess_fn
235 | 
236 |   def _parse_fn(self, value):
237 |     """Parses an image and its label from a serialized TFExample.
238 | 
239 |     Args:
240 |       value: serialized string containing a TFExample.
241 | 
242 |     Returns:
243 |       Returns a tuple of (image, label) from the TFExample.
244 |     """
245 |     example = tf.parse_single_example(value, self.FEATURE_MAP)
246 |     image = tf.image.decode_jpeg(example[self.IMAGE_KEY], channels=3)
247 |     # Subtract LABEL_OFFSET so that labels are in [0, 1000).
248 |     label = tf.cast(example[self.LABEL_KEY], tf.int32) - self.LABEL_OFFSET
249 | 
250 |     return self.preprocess_fn({'image': image, 'label': label})
251 | 
252 | 
253 | DATASET_MAP = {
254 |     'imagenet': DatasetImagenet,
255 | }
256 | 
257 | 
258 | def get_data(params,
259 |              split_name,
260 |              is_training,
261 |              shuffle=True,
262 |              num_epochs=None,
263 |              drop_remainder=False):
264 |   """Produces image/label tensors for a given dataset.
265 | 
266 |   Args:
267 |     params: dictionary with `batch_size` entry (thanks TPU...).
268 |     split_name: data split, e.g. train, val, test
269 |     is_training: whether to run pre-processing in train or test mode.
270 |     shuffle: if True, shuffles the data.
271 |     num_epochs: number of epochs. If None, proceeds indefinitely.
272 |     drop_remainder: Drop remaining examples in the last dataset batch. It is
273 |       useful for third party checkpoints with fixed batch size.
274 | 275 | Returns: 276 | image, label, example counts 277 | """ 278 | dataset = DATASET_MAP[FLAGS.dataset] 279 | preprocess_fn = get_preprocess_fn(FLAGS.preprocessing, is_training) 280 | 281 | return dataset( 282 | split_name=split_name, 283 | preprocess_fn=preprocess_fn, 284 | num_epochs=num_epochs, 285 | shuffle=shuffle, 286 | random_seed=FLAGS.get_flag_value('random_seed', None), 287 | drop_remainder=drop_remainder).input_fn(params) 288 | 289 | 290 | def get_count(split_name): 291 | return DATASET_MAP[FLAGS.dataset].COUNTS[split_name] 292 | 293 | 294 | def get_num_classes(): 295 | return DATASET_MAP[FLAGS.dataset].NUM_CLASSES 296 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2019 Google LLC 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /self_supervision/patch_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | """Utils for patch based image processing."""
18 | 
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 | 
23 | import functools
24 | import struct
25 | import numpy as np
26 | import tensorflow as tf
27 | import tensorflow_hub as hub
28 | 
29 | import preprocess
30 | import utils
31 | from models.utils import get_net
32 | from trainer import make_estimator
33 | 
34 | FLAGS = tf.flags.FLAGS
35 | 
36 | PATCH_H_COUNT = 3
37 | PATCH_W_COUNT = 3
38 | PATCH_COUNT = PATCH_H_COUNT * PATCH_W_COUNT
39 | 
40 | 
41 | # It's supposed to be in the root folder, which is also pwd when running, if the
42 | # instructions in the README are followed. Hence not a flag.
43 | PERMUTATION_PATH = 'permutations_100_max.bin'
44 | 
45 | 
46 | def apply_model(image_fn,
47 |                 is_training,
48 |                 num_outputs,
49 |                 perms,
50 |                 make_signature=False):
51 |   """Creates the patch based model output from patch representations.
52 | 
53 |   Args:
54 |     image_fn: function that returns the image tensor.
55 |     is_training: is training flag used for batch norm and drop out.
56 |     num_outputs: number of output classes.
57 |     perms: numpy array with shape [m, k], with elements in range
58 |       [0, PATCH_COUNT). k stands for the number of patches used in a
59 |       permutation, m for the number of permutations. Each permutation is used
60 |       to concat the patch inputs [n*PATCH_COUNT, h, w, c] into a tensor with
61 |       shape [n*m, h, w, c*k].
62 |     make_signature: whether to create a signature for the hub module.
63 | 
64 |   Returns:
65 |     out: output tensor with shape [n*m, 1, 1, num_outputs].
66 | 
67 |   Raises:
68 |     ValueError: An error occurred when the architecture is unknown.
69 |   """
70 |   images = image_fn()
71 | 
72 |   net = get_net(num_classes=FLAGS.get_flag_value('embed_dim', 1000))
73 |   out, end_points = net(images, is_training,
74 |                         weight_decay=FLAGS.get_flag_value('weight_decay', 1e-4))
75 | 
76 |   print(end_points)
77 |   if not make_signature:
78 |     out = permutate_and_concat_batch_patches(out, perms)
79 |     out = fully_connected(out, num_outputs, is_training=is_training)
80 | 
81 |     out = tf.squeeze(out, [1, 2])
82 | 
83 |   if make_signature:
84 |     hub.add_signature(inputs={'image': images}, outputs=out)
85 |     hub.add_signature(
86 |         name='representation',
87 |         inputs={'image': images},
88 |         outputs=end_points)
89 |   return out
90 | 
91 | 
92 | def image_grid(images, ny, nx, padding=0):
93 |   """Creates a batch of image grids from a batch of images.
94 | 
95 |   Args:
96 |     images: A batch of patches (B,N,H,W,C)
97 |     ny: vertical number of images
98 |     nx: horizontal number of images
99 |     padding: number of zeros between images, if any.
100 | 
101 |   Returns:
102 |     A tensor batch of image grids shaped (B,H*ny,W*nx,C), although that is a
103 |     simplifying lie: if padding is used h/w will be different.
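
    For example (illustrative shapes): with ny=nx=3 and padding=0, a batch of
    jigsaw patches shaped (B, 9, H, W, C) becomes a batch of grids shaped
    (B, 3*H, 3*W, C).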
103 | """ 104 | with tf.name_scope('grid_image'): 105 | if padding: 106 | padding = [padding, padding] 107 | images = tf.pad(images, [[0, 0], [0, 0], padding, padding, [0, 0]]) 108 | 109 | return tf.concat([ 110 | tf.concat([images[:, y * nx + x] for x in range(nx)], axis=-2) 111 | for y in range(ny)], axis=-3) 112 | 113 | 114 | def creates_estimator_model(images, labels, perms, num_classes, mode): 115 | """Creates EstimatorSpec for the patch based self supervised models. 116 | 117 | Args: 118 | images: images 119 | labels: self supervised labels (class indices) 120 | perms: patch permutations 121 | num_classes: number of different permutations 122 | mode: model's mode: training, eval or prediction 123 | 124 | Returns: 125 | EstimatorSpec 126 | """ 127 | print(' +++ Mode: %s, images: %s, labels: %s' % (mode, images, labels)) 128 | 129 | images = tf.reshape(images, shape=[-1] + images.get_shape().as_list()[-3:]) 130 | if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: 131 | with tf.variable_scope('module'): 132 | image_fn = lambda: images 133 | logits = apply_model( 134 | image_fn=image_fn, 135 | is_training=(mode == tf.estimator.ModeKeys.TRAIN), 136 | num_outputs=num_classes, 137 | perms=perms, 138 | make_signature=False) 139 | else: 140 | input_shape = utils.str2intlist( 141 | FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3')) 142 | image_fn = lambda: tf.placeholder( # pylint: disable=g-long-lambda 143 | shape=input_shape, 144 | dtype=tf.float32) 145 | 146 | apply_model_function = functools.partial( 147 | apply_model, 148 | image_fn=image_fn, 149 | num_outputs=num_classes, 150 | perms=perms, 151 | make_signature=True) 152 | 153 | tf_hub_module_spec = hub.create_module_spec( 154 | apply_model_function, [(utils.TAGS_IS_TRAINING, { 155 | 'is_training': True 156 | }), (set(), { 157 | 'is_training': False 158 | })], 159 | drop_collections=['summaries']) 160 | tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set()) 161 | hub.register_module_for_export(tf_hub_module, export_name='module') 162 | logits = tf_hub_module(images) 163 | return make_estimator(mode, predictions=logits) 164 | 165 | # build loss and accuracy 166 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 167 | labels=labels, logits=logits) 168 | loss = tf.reduce_mean(loss) 169 | 170 | eval_metrics = ( 171 | lambda labels, logits: { # pylint: disable=g-long-lambda 172 | 'accuracy': tf.metrics.accuracy( 173 | labels=labels, predictions=tf.argmax(logits, axis=-1))}, 174 | [labels, logits]) 175 | return make_estimator(mode, loss, eval_metrics, logits) 176 | 177 | 178 | def fully_connected(inputs, 179 | num_classes=100, 180 | weight_decay=5e-4, 181 | keep_prob=0.5, 182 | is_training=True): 183 | """Two layers fully connected network copied from Alexnet fc7-fc8.""" 184 | net = inputs 185 | _, _, w, _ = net.get_shape().as_list() 186 | kernel_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay) 187 | net = tf.layers.conv2d( 188 | net, 189 | filters=4096, 190 | kernel_size=w, 191 | padding='same', 192 | kernel_initializer=tf.truncated_normal_initializer(0.0, 0.005), 193 | bias_initializer=tf.constant_initializer(0.1), 194 | kernel_regularizer=kernel_regularizer) 195 | net = tf.layers.batch_normalization( 196 | net, momentum=0.997, epsilon=1e-5, fused=None, training=is_training) 197 | net = tf.nn.relu(net) 198 | if is_training: 199 | net = tf.nn.dropout(net, keep_prob=keep_prob) 200 | net = tf.layers.conv2d( 201 | net, 202 | filters=num_classes, 203 | kernel_size=1, 
113 | 
114 | def creates_estimator_model(images, labels, perms, num_classes, mode):
115 |   """Creates EstimatorSpec for the patch-based self-supervised models.
116 | 
117 |   Args:
118 |     images: images
119 |     labels: self-supervised labels (class indices)
120 |     perms: patch permutations
121 |     num_classes: number of different permutations
122 |     mode: model's mode: training, eval or prediction
123 | 
124 |   Returns:
125 |     EstimatorSpec
126 |   """
127 |   print(' +++ Mode: %s, images: %s, labels: %s' % (mode, images, labels))
128 | 
129 |   images = tf.reshape(images, shape=[-1] + images.get_shape().as_list()[-3:])
130 |   if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
131 |     with tf.variable_scope('module'):
132 |       image_fn = lambda: images
133 |       logits = apply_model(
134 |           image_fn=image_fn,
135 |           is_training=(mode == tf.estimator.ModeKeys.TRAIN),
136 |           num_outputs=num_classes,
137 |           perms=perms,
138 |           make_signature=False)
139 |   else:
140 |     input_shape = utils.str2intlist(
141 |         FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
142 |     image_fn = lambda: tf.placeholder(  # pylint: disable=g-long-lambda
143 |         shape=input_shape,
144 |         dtype=tf.float32)
145 | 
146 |     apply_model_function = functools.partial(
147 |         apply_model,
148 |         image_fn=image_fn,
149 |         num_outputs=num_classes,
150 |         perms=perms,
151 |         make_signature=True)
152 | 
153 |     tf_hub_module_spec = hub.create_module_spec(
154 |         apply_model_function, [(utils.TAGS_IS_TRAINING, {
155 |             'is_training': True
156 |         }), (set(), {
157 |             'is_training': False
158 |         })],
159 |         drop_collections=['summaries'])
160 |     tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
161 |     hub.register_module_for_export(tf_hub_module, export_name='module')
162 |     logits = tf_hub_module(images)
163 |     return make_estimator(mode, predictions=logits)
164 | 
165 |   # Build loss and accuracy.
166 |   loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
167 |       labels=labels, logits=logits)
168 |   loss = tf.reduce_mean(loss)
169 | 
170 |   eval_metrics = (
171 |       lambda labels, logits: {  # pylint: disable=g-long-lambda
172 |           'accuracy': tf.metrics.accuracy(
173 |               labels=labels, predictions=tf.argmax(logits, axis=-1))},
174 |       [labels, logits])
175 |   return make_estimator(mode, loss, eval_metrics, logits)
176 | 
177 | 
178 | def fully_connected(inputs,
179 |                     num_classes=100,
180 |                     weight_decay=5e-4,
181 |                     keep_prob=0.5,
182 |                     is_training=True):
183 |   """Two-layer fully connected network copied from AlexNet fc7-fc8."""
184 |   net = inputs
185 |   _, _, w, _ = net.get_shape().as_list()
186 |   kernel_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
187 |   net = tf.layers.conv2d(
188 |       net,
189 |       filters=4096,
190 |       kernel_size=w,
191 |       padding='same',
192 |       kernel_initializer=tf.truncated_normal_initializer(0.0, 0.005),
193 |       bias_initializer=tf.constant_initializer(0.1),
194 |       kernel_regularizer=kernel_regularizer)
195 |   net = tf.layers.batch_normalization(
196 |       net, momentum=0.997, epsilon=1e-5, fused=None, training=is_training)
197 |   net = tf.nn.relu(net)
198 |   if is_training:
199 |     net = tf.nn.dropout(net, keep_prob=keep_prob)
200 |   net = tf.layers.conv2d(
201 |       net,
202 |       filters=num_classes,
203 |       kernel_size=1,
204 |       padding='same',
205 |       kernel_initializer=tf.truncated_normal_initializer(0.0, 0.005),
206 |       bias_initializer=tf.zeros_initializer(),
207 |       kernel_regularizer=kernel_regularizer)
208 | 
209 |   return net
210 | 
211 | 
212 | def generate_patch_locations():
213 |   """Generates relative patch locations."""
214 |   perms = np.array([(i, 4) for i in range(9) if i != 4])
215 |   return perms, len(perms)
216 | 
217 | 
218 | def load_permutations():
219 |   """Loads a set of pre-defined permutations."""
220 |   with tf.gfile.Open(PERMUTATION_PATH, 'rb') as f:
221 |     int32_size = 4
222 |     s = f.read(int32_size * 2)
223 |     [num_perms, c] = struct.unpack('<ll', s)
357 |     # [N*P, H, W, C] -> NPHWC
358 |     t = tf.reshape(t, [-1, num_of_patches] + t.get_shape().as_list()[-3:])
359 |     if combine_patches == 'concat':
360 |       # [N, P, H, W, C] -> [N, H, W, P*C]
361 |       _, p, h, w, c = t.get_shape().as_list()
362 |       out_tensors[k] = tf.reshape(
363 |           tf.transpose(t, perm=[0, 2, 3, 4, 1]), tf.stack([-1, h, w, p * c]))
364 |     elif combine_patches == 'max_pool':
365 |       # Reduce max on the P channel of NPHWC.
366 |       out_tensors[k] = tf.reduce_max(t, axis=1)
367 |     elif combine_patches == 'avg_pool':
368 |       # Reduce mean on the P channel of NPHWC.
369 |       out_tensors[k] = tf.reduce_mean(t, axis=1)
370 |     else:
371 |       raise ValueError(
372 |           'Unsupported combine patches method %s.' % combine_patches)
373 | 
374 |   return out_tensors
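# A sketch tying the two label spaces together (illustrative only, not part of
# the original file): relative patch location classifies one of the 8
# (neighbor, center) pairs produced by generate_patch_locations, while the
# jigsaw task classifies one of the permutations stored in PERMUTATION_PATH --
# a binary file whose header, read above with struct.unpack('<ll', ...), holds
# two little-endian int32s: the number of permutations and, presumably, the
# patch count per permutation.
def _sketch_relative_patch_location_labels():
  perms, num_classes = generate_patch_locations()
  assert num_classes == 8       # all 8 neighbors of the center patch
  assert perms.shape == (8, 2)  # (neighbor_index, 4) pairs; 4 is the center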
18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import functools 25 | 26 | import tensorflow as tf 27 | 28 | 29 | def get_shape_as_list(x): 30 | return x.get_shape().as_list() 31 | 32 | 33 | def fixed_padding(x, kernel_size): 34 | pad_total = kernel_size - 1 35 | pad_beg = pad_total // 2 36 | pad_end = pad_total - pad_beg 37 | 38 | x = tf.pad(x, [[0, 0], 39 | [pad_beg, pad_end], [pad_beg, pad_end], 40 | [0, 0]]) 41 | return x 42 | 43 | 44 | def batch_norm(x, training): 45 | return tf.layers.batch_normalization(x, fused=True, training=training) 46 | 47 | 48 | def identity_norm(x, training): 49 | del training 50 | return x 51 | 52 | 53 | def bottleneck_v1(x, filters, training, # pylint: disable=missing-docstring 54 | strides=1, 55 | activation_fn=tf.nn.relu, 56 | normalization_fn=batch_norm, 57 | kernel_regularizer=None): 58 | 59 | # Record input tensor, such that it can be used later in as skip-connection 60 | x_shortcut = x 61 | 62 | # Project input if necessary 63 | if (strides > 1) or (filters != x.shape[-1]): 64 | x_shortcut = tf.layers.conv2d(x_shortcut, filters=filters, kernel_size=1, 65 | strides=strides, 66 | kernel_regularizer=kernel_regularizer, 67 | use_bias=False, 68 | padding='SAME') 69 | x_shortcut = normalization_fn(x_shortcut, training=training) 70 | 71 | # First convolution 72 | # Note, that unlike original Resnet paper we never use stride in the first 73 | # convolution. Instead, we apply stride in the second convolution. The reason 74 | # is that the first convolution has kernel of size 1x1, which results in 75 | # information loss when combined with stride bigger than one. 76 | x = tf.layers.conv2d(x, filters=filters // 4, 77 | kernel_size=1, 78 | kernel_regularizer=kernel_regularizer, 79 | use_bias=False, 80 | padding='SAME') 81 | x = normalization_fn(x, training=training) 82 | x = activation_fn(x) 83 | 84 | # Second convolution 85 | x = fixed_padding(x, kernel_size=3) 86 | x = tf.layers.conv2d(x, filters=filters // 4, 87 | strides=strides, 88 | kernel_size=3, 89 | kernel_regularizer=kernel_regularizer, 90 | use_bias=False, 91 | padding='VALID') 92 | x = normalization_fn(x, training=training) 93 | x = activation_fn(x) 94 | 95 | # Third convolution 96 | x = tf.layers.conv2d(x, filters=filters, 97 | kernel_size=1, 98 | kernel_regularizer=kernel_regularizer, 99 | use_bias=False, 100 | padding='SAME') 101 | x = normalization_fn(x, training=training) 102 | 103 | # Skip connection 104 | x = x_shortcut + x 105 | x = activation_fn(x) 106 | 107 | return x 108 | 109 | 110 | def bottleneck_v2(x, filters, training, # pylint: disable=missing-docstring 111 | strides=1, 112 | activation_fn=tf.nn.relu, 113 | normalization_fn=batch_norm, 114 | kernel_regularizer=None, 115 | no_shortcut=False): 116 | 117 | # Record input tensor, such that it can be used later in as skip-connection 118 | x_shortcut = x 119 | 120 | x = normalization_fn(x, training=training) 121 | x = activation_fn(x) 122 | 123 | # Project input if necessary 124 | if (strides > 1) or (filters != x.shape[-1]): 125 | x_shortcut = tf.layers.conv2d(x, filters=filters, kernel_size=1, 126 | strides=strides, 127 | kernel_regularizer=kernel_regularizer, 128 | use_bias=False, 129 | padding='VALID') 130 | 131 | # First convolution 132 | # Note, that unlike original Resnet paper we never use stride in the first 133 | # convolution. Instead, we apply stride in the second convolution. 
110 | def bottleneck_v2(x, filters, training,  # pylint: disable=missing-docstring
111 |                   strides=1,
112 |                   activation_fn=tf.nn.relu,
113 |                   normalization_fn=batch_norm,
114 |                   kernel_regularizer=None,
115 |                   no_shortcut=False):
116 | 
117 |   # Record the input tensor for use later as the skip-connection.
118 |   x_shortcut = x
119 | 
120 |   x = normalization_fn(x, training=training)
121 |   x = activation_fn(x)
122 | 
123 |   # Project the input if necessary.
124 |   if (strides > 1) or (filters != x.shape[-1]):
125 |     x_shortcut = tf.layers.conv2d(x, filters=filters, kernel_size=1,
126 |                                   strides=strides,
127 |                                   kernel_regularizer=kernel_regularizer,
128 |                                   use_bias=False,
129 |                                   padding='VALID')
130 | 
131 |   # First convolution.
132 |   # Note that, unlike the original ResNet paper, we never use a stride in the
133 |   # first convolution. Instead, we apply the stride in the second convolution.
134 |   # The reason is that the first convolution has a 1x1 kernel, which loses
135 |   # information when combined with a stride bigger than one.
136 |   x = tf.layers.conv2d(x, filters=filters // 4,
137 |                        kernel_size=1,
138 |                        kernel_regularizer=kernel_regularizer,
139 |                        use_bias=False,
140 |                        padding='SAME')
141 | 
142 |   # Second convolution.
143 |   x = normalization_fn(x, training=training)
144 |   x = activation_fn(x)
145 |   # Note that the padding depends on the dilation rate.
146 |   x = fixed_padding(x, kernel_size=3)
147 |   x = tf.layers.conv2d(x, filters=filters // 4,
148 |                        strides=strides,
149 |                        kernel_size=3,
150 |                        kernel_regularizer=kernel_regularizer,
151 |                        use_bias=False,
152 |                        padding='VALID')
153 | 
154 |   # Third convolution.
155 |   x = normalization_fn(x, training=training)
156 |   x = activation_fn(x)
157 |   x = tf.layers.conv2d(x, filters=filters,
158 |                        kernel_size=1,
159 |                        kernel_regularizer=kernel_regularizer,
160 |                        use_bias=False,
161 |                        padding='SAME')
162 | 
163 |   if no_shortcut:
164 |     return x
165 |   else:
166 |     return x + x_shortcut
167 | 
168 | 
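# The difference between the two units in one line each (an editorial summary,
# not from the original file): v1 applies conv -> BN -> ReLU and puts a final
# ReLU after the residual add, while v2 uses the "pre-activation" ordering
# BN -> ReLU -> conv with a clean identity skip path
# (https://arxiv.org/abs/1603.05027):
#   v1: out = relu(shortcut + bn(conv3(...conv1(x))))
#   v2: out = shortcut + conv3(...relu(bn(x)))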
169 | def resnet(x,  # pylint: disable=missing-docstring
170 |            is_training,
171 |            num_layers,
172 |            strides=(2, 2, 2),
173 |            num_classes=1000,
174 |            filters_factor=4,
175 |            weight_decay=1e-4,
176 |            include_root_block=True,
177 |            root_conv_size=7, root_conv_stride=2,
178 |            root_pool_size=3, root_pool_stride=2,
179 |            activation_fn=tf.nn.relu,
180 |            last_relu=True,
181 |            normalization_fn=batch_norm,
182 |            global_pool=True,
183 |            mode='v2'):
184 | 
185 |   assert mode in ['v1', 'v2'], 'Unknown ResNet mode: {}'.format(mode)
186 |   unit = bottleneck_v2 if mode == 'v2' else bottleneck_v1
187 | 
188 |   end_points = {}
189 | 
190 |   filters = 16 * filters_factor
191 | 
192 |   kernel_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
193 | 
194 |   if include_root_block:
195 |     x = fixed_padding(x, kernel_size=root_conv_size)
196 |     x = tf.layers.conv2d(x, filters=filters,
197 |                          kernel_size=root_conv_size,
198 |                          strides=root_conv_stride,
199 |                          padding='VALID', use_bias=False,
200 |                          kernel_regularizer=kernel_regularizer)
201 | 
202 |     if mode == 'v1':
203 |       x = normalization_fn(x, training=is_training)
204 |       x = activation_fn(x)
205 | 
206 |     x = fixed_padding(x, kernel_size=root_pool_size)
207 |     x = tf.layers.max_pooling2d(x, pool_size=root_pool_size,
208 |                                 strides=root_pool_stride, padding='VALID')
209 |   end_points['after_root'] = x
210 | 
211 |   params = {'activation_fn': activation_fn,
212 |             'normalization_fn': normalization_fn,
213 |             'training': is_training,
214 |             'kernel_regularizer': kernel_regularizer,
215 |            }
216 | 
217 |   strides = list(strides)[::-1]
218 |   num_layers = list(num_layers)[::-1]
219 | 
220 |   filters *= 4
221 |   for _ in range(num_layers.pop()):
222 |     x = unit(x, filters, strides=1, **params)
223 |   end_points['block1'] = x
224 | 
225 |   filters *= 2
226 |   x = unit(x, filters, strides=strides.pop(), **params)
227 |   for _ in range(num_layers.pop() - 1):
228 |     x = unit(x, filters, strides=1, **params)
229 |   end_points['block2'] = x
230 | 
231 |   filters *= 2
232 |   x = unit(x, filters, strides=strides.pop(), **params)
233 |   for _ in range(num_layers.pop() - 1):
234 |     x = unit(x, filters, strides=1, **params)
235 |   end_points['block3'] = x
236 | 
237 |   filters *= 2
238 |   x = unit(x, filters, strides=strides.pop(), **params)
239 |   for _ in range(num_layers.pop() - 1):
240 |     x = unit(x, filters, strides=1, **params)
241 |   end_points['block4'] = x
242 | 
243 |   if (mode == 'v1') and (not last_relu):
244 |     raise ValueError('last_relu is always True (implicitly) in the v1 mode.')
245 | 
246 |   if mode == 'v2':
247 |     x = normalization_fn(x, training=is_training)
248 |     if last_relu:
249 |       x = activation_fn(x)
250 | 
251 |   if global_pool:
252 |     x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
253 |     end_points['pre_logits'] = tf.squeeze(x, [1, 2])
254 |   else:
255 |     end_points['pre_logits'] = x
256 | 
257 |   if num_classes:
258 |     logits = tf.layers.conv2d(x, filters=num_classes,
259 |                               kernel_size=1,
260 |                               kernel_regularizer=kernel_regularizer)
261 |     if global_pool:
262 |       logits = tf.squeeze(logits, [1, 2])
263 |     end_points['logits'] = logits
264 |     return logits, end_points
265 |   else:
266 |     return end_points['pre_logits'], end_points
267 | 
268 | resnet50 = functools.partial(resnet, num_layers=(3, 4, 6, 3))
269 | 
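# A minimal usage sketch (not part of the original file): building a vanilla
# ResNet50-v2 classifier graph and inspecting the endpoints collected above.
def _sketch_resnet50_usage():
  images = tf.placeholder(tf.float32, [None, 224, 224, 3])
  logits, end_points = resnet50(images, is_training=False)
  # logits: [None, 1000]; end_points maps 'after_root', 'block1'..'block4',
  # 'pre_logits' and 'logits' to intermediate tensors.
  return logits, end_points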
270 | # Experimental code ########################################
271 | # "Reversible" resnet ######################################
272 | 
273 | 
274 | # Invertible residual block as outlined in https://arxiv.org/abs/1707.04585
275 | def bottleneck_rev(x, training,  # pylint: disable=missing-docstring
276 |                    activation_fn=tf.nn.relu,
277 |                    normalization_fn=batch_norm,
278 |                    kernel_regularizer=None):
279 | 
280 |   unit = bottleneck_v2
281 | 
282 |   x1, x2 = tf.split(x, 2, 3)
283 | 
284 |   y1 = x1 + unit(x2, x2.shape[-1], training,
285 |                  strides=1,
286 |                  activation_fn=activation_fn,
287 |                  normalization_fn=normalization_fn,
288 |                  kernel_regularizer=kernel_regularizer,
289 |                  no_shortcut=True)
290 |   y2 = x2
291 | 
292 |   return tf.concat([y2, y1], axis=3)
293 | 
294 | 
295 | def pool_and_double_channels(x, pool_stride):
296 |   if pool_stride > 1:
297 |     x = tf.layers.average_pooling2d(x, pool_size=pool_stride,
298 |                                     strides=pool_stride,
299 |                                     padding='SAME')
300 |   return tf.pad(x, [[0, 0], [0, 0], [0, 0],
301 |                     [x.shape[3] // 2, x.shape[3] // 2]])
302 | 
303 | 
304 | def revnet(x,  # pylint: disable=missing-docstring
305 |            is_training,
306 |            num_layers,
307 |            strides=(2, 2, 2),
308 |            num_classes=1000,
309 |            filters_factor=4,
310 |            weight_decay=1e-4,
311 |            include_root_block=True,
312 |            root_conv_size=7, root_conv_stride=2,
313 |            root_pool_size=3, root_pool_stride=2,
314 |            global_pool=True,
315 |            activation_fn=tf.nn.relu,
316 |            normalization_fn=batch_norm,
317 |            last_relu=False,
318 |            mode='v2'):
319 | 
320 |   del mode  # Unused parameter; exists for compatibility with the resnet function.
321 | 
322 |   unit = bottleneck_rev
323 | 
324 |   end_points = {}
325 | 
326 |   filters = 16 * filters_factor
327 | 
328 |   kernel_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
329 | 
330 |   # The first convolution serves as a random projection that increases the
331 |   # number of channels; it cannot be skipped.
332 |   x = fixed_padding(x, kernel_size=root_conv_size)
333 |   x = tf.layers.conv2d(x, filters=4 * filters,
334 |                        kernel_size=root_conv_size,
335 |                        strides=root_conv_stride,
336 |                        padding='VALID', use_bias=False,
337 |                        kernel_regularizer=None)
338 | 
339 |   if include_root_block:
340 |     x = fixed_padding(x, kernel_size=root_pool_size)
341 |     x = tf.layers.max_pooling2d(
342 |         x, pool_size=root_pool_size, strides=root_pool_stride, padding='VALID')
343 | 
344 |   end_points['after_root'] = x
345 | 
346 |   params = {'activation_fn': activation_fn,
347 |             'normalization_fn': normalization_fn,
348 |             'training': is_training,
349 |             'kernel_regularizer': kernel_regularizer,
350 |            }
351 | 
352 |   num_layers = list(num_layers)[::-1]
353 |   strides = list(strides)[::-1]
354 | 
355 |   for _ in range(num_layers.pop()):
356 |     x = unit(x, **params)
357 |   end_points['block1'] = x
358 |   x = pool_and_double_channels(x, strides.pop())
359 | 
360 |   for _ in range(num_layers.pop()):
361 |     x = unit(x, **params)
362 |   end_points['block2'] = x
363 |   x = pool_and_double_channels(x, strides.pop())
364 | 
365 |   for _ in range(num_layers.pop()):
366 |     x = unit(x, **params)
367 |   end_points['block3'] = x
368 |   x = pool_and_double_channels(x, strides.pop())
369 | 
370 |   for _ in range(num_layers.pop()):
371 |     x = unit(x, **params)
372 |   end_points['block4'] = x
373 | 
374 |   x = normalization_fn(x, training=is_training)
375 | 
376 |   if last_relu:
377 |     x = activation_fn(x)
378 | 
379 |   if global_pool:
380 |     x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
381 |     end_points['pre_logits'] = tf.squeeze(x, [1, 2])
382 |   else:
383 |     end_points['pre_logits'] = x
384 | 
385 |   if num_classes:
386 |     logits = tf.layers.conv2d(x, filters=num_classes,
387 |                               kernel_size=1,
388 |                               kernel_regularizer=kernel_regularizer)
389 |     if global_pool:
390 |       logits = tf.squeeze(logits, [1, 2])
391 |     end_points['logits'] = logits
392 |     return logits, end_points
393 |   else:
394 |     return end_points['pre_logits'], end_points
395 | 
396 | 
397 | revnet50 = functools.partial(revnet, num_layers=(3, 4, 6, 3))
398 | 
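# A sketch of the additive-coupling property that makes bottleneck_rev
# "reversible" (https://arxiv.org/abs/1707.04585): since y2 == x2 and
# y1 == x1 + F(x2), the input can be recomputed from the output, so
# activations need not be stored for the backward pass. `residual_fn` is a
# hypothetical stand-in for re-evaluating the same bottleneck_v2 unit with
# identical parameters; this helper is not part of the original file.
def _sketch_invert_bottleneck_rev(y, residual_fn):
  y2, y1 = tf.split(y, 2, 3)       # bottleneck_rev returned concat([y2, y1])
  x2 = y2
  x1 = y1 - residual_fn(x2)        # undo the residual update
  return tf.concat([x1, x2], axis=3)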
--------------------------------------------------------------------------------
/train_and_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | # pylint: disable=line-too-long
18 | r"""The main script for starting training and evaluation.
19 | 
20 | How to run:
21 | blaze run -c opt --config=dmtf_cuda \
22 |   learning/brain/research/dune/experimental/representation/release/train_and_eval -- \
23 |   --workdir /tmp/test \
24 |   --config /google/src/cloud/akolesnikov/release/release/config/supervised/imagenet.py \
25 |   --nouse_tpu
26 | """
27 | # pylint: enable=line-too-long
28 | 
29 | from __future__ import absolute_import
30 | from __future__ import division
31 | 
32 | import functools
33 | import math
34 | import os
35 | 
36 | import absl.app as app
37 | import absl.flags as flags
38 | import absl.logging as logging
39 | 
40 | import tensorflow as tf
41 | import tensorflow_hub as hub
42 | 
43 | import datasets
44 | from self_supervision.self_supervision_lib import get_self_supervision_model
45 | import utils
46 | 
47 | from tensorflow.contrib.cluster_resolver import TPUClusterResolver
48 | 
49 | 
50 | FLAGS = flags.FLAGS
51 | 
52 | # General run setup flags.
53 | flags.DEFINE_string('workdir', None, 'Where to store files.')
54 | flags.mark_flag_as_required('workdir')
55 | 
56 | flags.DEFINE_integer('num_gpus', 1, 'Number of GPUs to use.')
57 | flags.DEFINE_bool('use_tpu', True, 'Whether running on TPU or not.')
58 | flags.DEFINE_bool('run_eval', False, 'Run eval mode.')
59 | 
60 | flags.DEFINE_string('tpu_worker_name', 'tpu_worker',
61 |                     'Name of a TPU worker.')
62 | 
63 | # More detailed experiment flags.
64 | flags.DEFINE_string('dataset', None, 'Which dataset to use, typically '
65 |                     '`imagenet`.')
66 | flags.mark_flag_as_required('dataset')
67 | 
68 | flags.DEFINE_string('dataset_dir', None, 'Location of the dataset files.')
69 | flags.mark_flag_as_required('dataset_dir')
70 | 
71 | flags.DEFINE_integer('eval_batch_size', None, 'Optional different batch-size '
72 |                      'for evaluation; defaults to the same as `batch_size`.')
73 | 
74 | flags.DEFINE_integer('keep_checkpoint_every_n_hours', None, 'Keep one '
75 |                      'checkpoint every this many hours. Otherwise, only the '
76 |                      'last few ones are kept. Defaults to 4h.')
77 | 
78 | flags.DEFINE_integer('random_seed', None, 'Seed to use. None is random.')
79 | 
80 | flags.DEFINE_integer('save_checkpoints_secs', None, 'Every how many seconds '
81 |                      'to save a checkpoint. Defaults to 600, i.e. every 10 minutes.')
82 | 
83 | flags.DEFINE_string('serving_input_key', None, 'The name of the input tensor '
84 |                     'in the generated hub module. Just leave it at the default.')
85 | 
86 | flags.DEFINE_string('serving_input_shape', None, 'The shape of the input tensor'
87 |                     ' in the stored hub module. Can contain `None`.')
88 | 
89 | flags.DEFINE_string('signature', None, 'The name of the tensor to use as '
90 |                     'representation for evaluation. Just leave it at the default.')
91 | 
92 | flags.DEFINE_string('task', None, 'Which pretext-task to learn from. Can be '
93 |                     'one of `rotation`, `exemplar`, `jigsaw`, '
94 |                     '`relative_patch_location`, `linear_eval`, `supervised`.')
95 | flags.mark_flag_as_required('task')
96 | 
97 | flags.DEFINE_string('train_split', None, 'Which dataset split to train on. '
98 |                     'Should only be `train` (default) or `trainval`.')
99 | flags.DEFINE_string('val_split', None, 'Which dataset split to eval on. '
100 |                     'Should only be `val` (default) or `test`.')
101 | 
102 | # Flags about the pretext tasks.
103 | 
104 | flags.DEFINE_integer('embed_dim', None, 'For most pretext tasks, the '
105 |                      'dimension of the embedding/hidden vector. '
106 |                      'Defaults to 1000.')
107 | 
108 | flags.DEFINE_float('margin', None, 'For the `exemplar` pretext task, '
109 |                    'how large the triplet loss margin should be.')
110 | 
111 | flags.DEFINE_integer('num_of_inception_patches', None, 'For the Exemplar '
112 |                      'pretext task, how many instances of an image to create.')
113 | 
114 | flags.DEFINE_integer('patch_jitter', None, 'For patch-based methods, by how '
115 |                      'many pixels to jitter the patches. Defaults to 0.')
116 | 
117 | flags.DEFINE_integer('perm_subset_size', None, 'Subset of permutations to '
118 |                      'sample per example in the `jigsaw` pretext task. '
119 |                      'Defaults to 8.')
120 | 
121 | flags.DEFINE_integer('splits_per_side', None, 'For the `crop_patches` '
122 |                      'preprocessor, how many times to split a side. '
123 |                      'For example, 3 will result in 3x3=9 patches.')
124 | 
125 | # Flags for evaluation.
126 | flags.DEFINE_string('eval_model', None, 'Whether to perform evaluation with a '
127 |                     '`linear` (default) model, or with an `mlp` model.')
128 | 
129 | flags.DEFINE_string('hub_module', None, 'Folder where the hub module that '
130 |                     'should be evaluated is stored.')
131 | 
132 | flags.DEFINE_string('pool_mode', None, 'When running evaluation on '
133 |                     'intermediate layers (not logits) of the network, it is '
134 |                     'commonplace to pool the features down to 9000 dimensions. '
135 |                     'This decides the pooling method to be used: `adaptive_max` '
136 |                     '(default), `adaptive_avg`, `max`, or `avg`.')
137 | 
138 | flags.DEFINE_string('combine_patches', None, 'When running evaluation on '
139 |                     'patch models, this is used to merge patch representations '
140 |                     'into the full-image representation. The value should be '
141 |                     'set to `avg_pool` (default) or `concat`.')
142 | 
143 | # Flags about the model.
144 | flags.DEFINE_string('architecture', None,
145 |                     help='Which basic network architecture to use. '
146 |                     'One of vgg19, resnet50, revnet50.')
147 | # flags.mark_flag_as_required('architecture')  # Not required in eval mode.
148 | 
149 | flags.DEFINE_integer('filters_factor', None, 'Widening factor for network '
150 |                      'filters. For ResNet, default = 4 = vanilla ResNet.')
151 | 
152 | flags.DEFINE_bool('last_relu', None, 'Whether to include (default) the final '
153 |                   'ReLU layer in ResNet/RevNet models or not.')
154 | 
155 | flags.DEFINE_string('mode', None, 'Which ResNet to use, `v1` or `v2`.')
156 | 
157 | # Flags about the optimization process.
158 | flags.DEFINE_integer('batch_size', None, 'The global batch-size to use.')
159 | flags.mark_flag_as_required('batch_size')
160 | 
161 | flags.DEFINE_string('decay_epochs', None, 'Optional list of epochs at which '
162 |                     'learning-rate decay should happen, such as `15,25`.')
163 | 
164 | flags.DEFINE_integer('epochs', None, 'Number of epochs to run training for.')
165 | flags.mark_flag_as_required('epochs')
166 | 
167 | flags.DEFINE_float('lr_decay_factor', None, 'Factor by which to decay the '
168 |                    'learning-rate at each decay step. Defaults to 0.1.')
169 | 
170 | flags.DEFINE_float('lr', None, 'The base learning-rate to use for training.')
171 | flags.mark_flag_as_required('lr')
172 | 
173 | flags.DEFINE_float('lr_scale_batch_size', None, 'The batch-size for which the '
174 |                    'base learning-rate `lr` is defined. For batch-sizes '
175 |                    'different from that, it is scaled linearly accordingly. '
176 |                    'For example: lr=0.1, batch_size=128, lr_scale_batch_size=32 '
177 |                    'gives an actual lr of 0.025.')
178 | flags.mark_flag_as_required('lr_scale_batch_size')
179 | 
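# A sketch of the scaling arithmetic in the help string above. The
# authoritative formula lives in utils.py, which is not shown in this section,
# so the direction of the scaling here is an assumption chosen to reproduce
# the flag's own worked example; the helper name is hypothetical.
def _sketch_scaled_lr(lr, batch_size, lr_scale_batch_size):
  return lr * lr_scale_batch_size / batch_size
# _sketch_scaled_lr(0.1, 128, 32) == 0.025, matching the example above.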
180 | flags.DEFINE_string('optimizer', None, 'Which optimizer to use. '
181 |                     'Only `sgd` (default) or `adam` are supported.')
182 | 
183 | flags.DEFINE_integer('warmup_epochs', None, 'Duration of the linear learning-'
184 |                      'rate warm-up (from 0 to the actual rate). Defaults to 0.')
185 | 
186 | flags.DEFINE_float('weight_decay', None, 'Strength of weight-decay. '
187 |                    'Defaults to 1e-4, and may be set to 0.')
188 | 
189 | # Flags about pre-processing/data augmentation.
190 | flags.DEFINE_string('crop_size', None, 'Size of the crop when using `crop` '
191 |                     'or `central_crop` preprocessing. Either a single '
192 |                     'integer like `32` or a pair like `32,24`.')
193 | 
194 | flags.DEFINE_float('grayscale_probability', None, 'When using `to_gray` '
195 |                    'preprocessing, probability of actually doing it. Defaults '
196 |                    'to 1.0, i.e. deterministically grayscaling the input.')
197 | 
198 | flags.DEFINE_string('preprocessing', None, 'A comma-separated list of '
199 |                     'pre-processing steps to perform, see preprocess.py.')
200 | flags.mark_flag_as_required('preprocessing')
201 | 
202 | flags.DEFINE_bool('randomize_resize_method', None, 'Whether to use a random '
203 |                   'interpolation method in the `resize` preprocessor; the '
204 |                   'default is not to.')
205 | 
206 | flags.DEFINE_string('resize_size', None, 'For the `resize`, '
207 |                     '`inception_preprocess`, and '
208 |                     '`crop_inception_preprocess_patches` preprocessors, the '
209 |                     'size in pixels to which to resize the input. Can be a '
210 |                     'single number for square, or a pair such as `128,64`.')
211 | 
212 | flags.DEFINE_integer('smaller_size', None, 'For the `resize_small` preprocessor'
213 |                      ', the desired size that the smaller side should have '
214 |                      'after resizing the image (keeping the aspect ratio).')
215 | 
216 | 
217 | # Number of iterations (=training steps) per TPU training loop. Use >100 for
218 | # good speed. This is the minimum number of steps between checkpoints.
219 | TPU_ITERATIONS_PER_LOOP = 500
220 | 
221 | 
222 | def train_and_eval():
223 |   """Trains a network on (self-)supervised data."""
224 |   checkpoint_dir = os.path.join(FLAGS.workdir)
225 | 
226 |   if FLAGS.use_tpu:
227 |     master = TPUClusterResolver(
228 |         tpu=[os.environ['TPU_NAME']]).get_master()
229 |   else:
230 |     master = ''
231 | 
232 |   config = tf.contrib.tpu.RunConfig(
233 |       model_dir=checkpoint_dir,
234 |       tf_random_seed=FLAGS.get_flag_value('random_seed', None),
235 |       master=master,
236 |       evaluation_master=master,
237 |       keep_checkpoint_every_n_hours=FLAGS.get_flag_value(
238 |           'keep_checkpoint_every_n_hours', 4),
239 |       save_checkpoints_secs=FLAGS.get_flag_value('save_checkpoints_secs', 600),
240 |       tpu_config=tf.contrib.tpu.TPUConfig(
241 |           iterations_per_loop=TPU_ITERATIONS_PER_LOOP,
242 |           tpu_job_name=FLAGS.tpu_worker_name))
243 | 
244 |   # The global batch-sizes are passed to the TPU estimator, and it will pass
245 |   # along the local batch size in the model_fn's `params` argument dict.
246 |   estimator = tf.contrib.tpu.TPUEstimator(
247 |       model_fn=get_self_supervision_model(FLAGS.task),
248 |       model_dir=checkpoint_dir,
249 |       config=config,
250 |       use_tpu=FLAGS.use_tpu,
251 |       train_batch_size=FLAGS.batch_size,
252 |       eval_batch_size=FLAGS.get_flag_value('eval_batch_size', FLAGS.batch_size))
253 | 
254 |   if FLAGS.run_eval:
255 |     data_fn = functools.partial(
256 |         datasets.get_data,
257 |         split_name=FLAGS.get_flag_value('val_split', 'val'),
258 |         is_training=False,
259 |         shuffle=False,
260 |         num_epochs=1,
261 |         drop_remainder=FLAGS.use_tpu)
262 | 
263 |     # Contrary to what the documentation claims, the `train` and the
264 |     # `evaluate` functions NEED to have `max_steps` and/or `steps` set and
265 |     # cannot make use of the iterator's end-of-input exception, so we need
266 |     # to do some math for that here.
267 |     num_samples = datasets.get_count(FLAGS.get_flag_value('val_split', 'val'))
268 |     num_steps = num_samples // FLAGS.get_flag_value('eval_batch_size',
269 |                                                     FLAGS.batch_size)
270 |     tf.logging.info('val_steps: %d', num_steps)
271 | 
272 |     for checkpoint in tf.contrib.training.checkpoints_iterator(
273 |         estimator.model_dir, timeout=10 * 60):
274 | 
275 |       estimator.evaluate(
276 |           checkpoint_path=checkpoint, input_fn=data_fn, steps=num_steps)
277 | 
278 |       hub_exporter = hub.LatestModuleExporter('hub', serving_input_fn)
279 |       hub_exporter.export(
280 |           estimator,
281 |           os.path.join(checkpoint_dir, 'export/hub'),
282 |           checkpoint)
283 | 
284 |       if tf.gfile.Exists(os.path.join(FLAGS.workdir, 'TRAINING_IS_DONE')):
285 |         break
286 | 
287 |     # Evaluate the latest checkpoint on the validation set.
288 |     result = estimator.evaluate(input_fn=data_fn, steps=num_steps)
289 |     return result
290 | 
291 |   else:
292 |     train_data_fn = functools.partial(
293 |         datasets.get_data,
294 |         split_name=FLAGS.get_flag_value('train_split', 'train'),
295 |         is_training=True,
296 |         num_epochs=int(math.ceil(FLAGS.epochs)),
297 |         drop_remainder=True)
298 | 
299 |     # We compute the number of steps and make use of Estimator's max_steps
300 |     # argument instead of relying on the Dataset's iterator to run out after
301 |     # a number of epochs, so that we can use 'fractional' epochs, which are
302 |     # used by regression tests. (And because TPUEstimator needs it anyway.)
303 |     num_samples = datasets.get_count(FLAGS.get_flag_value('train_split',
304 |                                                           'train'))
305 |     # Depending on whether we drop the last batch each epoch or only at the
306 |     # very end, this should be ordered differently for rounding.
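    # Worked example of the arithmetic below (comment added for clarity; the
    # counts are illustrative): the ImageNet train set has 1,281,167 images,
    # so batch_size=256 gives updates_per_epoch = 5004, and epochs=90 gives
    # num_steps = ceil(90 * 5004) = 450360.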
307 | updates_per_epoch = num_samples // FLAGS.batch_size 308 | num_steps = int(math.ceil(FLAGS.epochs * updates_per_epoch)) 309 | tf.logging.info('train_steps: %d', num_steps) 310 | 311 | estimator.train(train_data_fn, max_steps=num_steps) 312 | 313 | 314 | def serving_input_fn(): 315 | """A serving input fn.""" 316 | input_shape = utils.str2intlist( 317 | FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3')) 318 | image_features = { 319 | FLAGS.get_flag_value('serving_input_key', 'image'): 320 | tf.placeholder(dtype=tf.float32, shape=input_shape)} 321 | return tf.estimator.export.ServingInputReceiver( 322 | features=image_features, receiver_tensors=image_features) 323 | 324 | 325 | def main(unused_argv): 326 | # logging.info('config: %s', FLAGS) 327 | logging.info('workdir: %s', FLAGS.workdir) 328 | 329 | train_and_eval() 330 | 331 | logging.info('I\'m done with my work, ciao!') 332 | 333 | 334 | if __name__ == '__main__': 335 | app.run(main) 336 | --------------------------------------------------------------------------------
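# A sketch of consuming an exported module (illustrative only; the exact
# export path under `workdir` and the input shape are assumptions): the eval
# loop above writes TF-Hub modules to <workdir>/export/hub, and patch models
# additionally expose the 'representation' signature registered in
# patch_utils.apply_model.
import tensorflow as tf
import tensorflow_hub as hub

module = hub.Module('/tmp/test/export/hub/<timestamp>/module')  # hypothetical path
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits = module(images)  # default signature
end_points = module(images, signature='representation', as_dict=True)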