├── __init__.py
├── models
│   ├── __init__.py
│   ├── utils.py
│   ├── vggnet.py
│   └── resnet.py
├── self_supervision
│   ├── __init__.py
│   ├── relative_patch_location.py
│   ├── jigsaw.py
│   ├── self_supervision_lib.py
│   ├── patch_model_preprocess.py
│   ├── rotation.py
│   ├── supervised.py
│   ├── exemplar.py
│   ├── linear_eval.py
│   └── patch_utils.py
├── config
│   ├── supervised
│   │   └── imagenet.sh
│   ├── rotation
│   │   └── imagenet.sh
│   ├── evaluation
│   │   ├── rotation_or_exemplar.sh
│   │   └── jigsaw_or_relative_patch_location.sh
│   ├── exemplar
│   │   └── imagenet.sh
│   ├── jigsaw
│   │   └── imagenet.sh
│   └── relative_patch_location
│       └── imagenet.sh
├── CONTRIBUTING.md
├── setup.py
├── trainer.py
├── README.md
├── utils.py
├── preprocess.py
├── datasets.py
├── LICENSE
└── train_and_eval.py
/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
--------------------------------------------------------------------------------
/self_supervision/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
--------------------------------------------------------------------------------
/config/supervised/imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 |
19 | python train_and_eval.py \
20 | --task supervised \
21 | --architecture resnet50 \
22 | --filters_factor 4 \
23 | --weight_decay 1e-4 \
24 | --dataset imagenet \
25 | --train_split trainval \
26 | --val_split test \
27 | --batch_size 256 \
28 | --eval_batch_size 10 \
29 | \
30 | --preprocessing inception_preprocess \
31 | --resize_size 224 \
32 | \
33 | --lr 0.1 \
34 | --lr_scale_batch_size 256 \
35 | --decay_epochs 30,60,80 \
36 | --epochs 90 \
37 | --warmup_epochs 5 \
38 | "$@"
39 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to
12 | see your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/config/rotation/imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 |
19 | python train_and_eval.py \
20 | --task rotation \
21 | --dataset imagenet \
22 | --train_split train \
23 | --val_split val \
24 | --batch_size 64 \
25 | --eval_batch_size 16 \
26 | \
27 | --architecture revnet50 \
28 | --filters_factor 16 \
29 | --last_relu True \
30 | \
31 | --preprocessing inception_preprocess,rotate \
32 | --resize_size 224,224 \
33 | \
34 | --lr 0.1 \
35 | --lr_scale_batch_size 256 \
36 | --decay_epochs 15,25 \
37 | --epochs 35 \
38 | --warmup_epochs 5 \
39 | \
40 | --serving_input_shape None,224,224,3 \
41 | "$@"
42 |
--------------------------------------------------------------------------------
/config/evaluation/rotation_or_exemplar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 |
19 | # For the reported results, try the following parameters on 4x4 TPUs:
20 | # batch_size: 2048
21 | # decay_epochs: 480,500
22 | # epochs: 520
23 | python train_and_eval.py \
24 | --task linear_eval \
25 | --dataset imagenet \
26 | --train_split trainval \
27 | --val_split test \
28 | --batch_size 512 \
29 | --eval_batch_size 32 \
30 | \
31 | --preprocessing resize_small,crop,-1_to_1 \
32 | --crop_size 224,224 \
33 | --smaller_size 256 \
34 | --pool_mode max \
35 | \
36 | --lr 0.1 \
37 | --decay_epochs 30,50 \
38 | --epochs 70 \
39 | --lr_scale_batch_size 256 \
40 | --hub_module ~/Downloads/module \
41 | "$@"
42 |
--------------------------------------------------------------------------------
/config/exemplar/imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 |
19 | python train_and_eval.py \
20 | --task exemplar \
21 | --dataset imagenet \
22 | --train_split train \
23 | --val_split val \
24 | --batch_size 64 \
25 | --eval_batch_size 16 \
26 | \
27 | --architecture resnet50 \
28 | --filters_factor 12 \
29 | --last_relu True \
30 | --mode v1 \
31 | \
32 | --preprocessing to_gray,crop_inception_preprocess_patches,standardization \
33 | --resize_size 224,224 \
34 | --grayscale_probability 0.66 \
35 | --embed_dim 1000 \
36 | --margin 0.5 \
37 | --num_of_inception_patches 8 \
38 | \
39 | --lr 0.1 \
40 | --lr_scale_batch_size 256 \
41 | --decay_epochs 15,25 \
42 | --epochs 35 \
43 | --warmup_epochs 5 \
44 | \
45 | --serving_input_shape None,224,224,3 \
46 | "$@"
47 |
--------------------------------------------------------------------------------
/config/jigsaw/imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 |
19 | python train_and_eval.py \
20 | --task jigsaw \
21 | --dataset imagenet \
22 | --train_split train \
23 | --val_split val \
24 | --batch_size 128 \
25 | --eval_batch_size 8 \
26 | \
27 | --architecture resnet50 \
28 | --filters_factor 8 \
29 | --last_relu True \
30 | --mode v1 \
31 | \
32 | --preprocessing resize,to_gray,crop,crop_patches,standardization \
33 | --resize_size 292,292 \
34 | --crop_size 255 \
35 | --grayscale_probability 0.66 \
36 | --splits_per_side 3 \
37 | --patch_jitter 21 \
38 | --embed_dim 1000 \
39 | \
40 | --lr 0.1 \
41 | --lr_scale_batch_size 256 \
42 | --decay_epochs 15,25 \
43 | --epochs 35 \
44 | --warmup_epochs 5 \
45 | \
46 | --serving_input_shape None,64,64,3 \
47 | "$@"
48 |
--------------------------------------------------------------------------------
/config/relative_patch_location/imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 |
19 | python train_and_eval.py \
20 | --task relative_patch_location \
21 | --dataset imagenet \
22 | --train_split train \
23 | --val_split val \
24 | --batch_size 128 \
25 | --eval_batch_size 8 \
26 | \
27 | --architecture resnet50 \
28 | --filters_factor 8 \
29 | --last_relu True \
30 | --mode v1 \
31 | \
32 | --preprocessing resize,to_gray,crop,crop_patches,standardization \
33 | --resize_size 292,292 \
34 | --crop_size 255 \
35 | --grayscale_probability 0.66 \
36 | --splits_per_side 3 \
37 | --patch_jitter 21 \
38 | --embed_dim 1000 \
39 | \
40 | --lr 0.1 \
41 | --lr_scale_batch_size 256 \
42 | --decay_epochs 15,25 \
43 | --epochs 35 \
44 | --warmup_epochs 5 \
45 | \
46 | --serving_input_shape None,64,64,3 \
47 | "$@"
48 |
--------------------------------------------------------------------------------
/config/evaluation/jigsaw_or_relative_patch_location.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 |
19 | # For the reported results, try the following parameters on 4x4 TPUs:
20 | # batch_size: 2048
21 | # decay_epochs: 480,500
22 | # epochs: 520
23 | python train_and_eval.py \
24 | --task linear_eval \
25 | --dataset imagenet \
26 | --train_split trainval \
27 | --val_split test \
28 | --batch_size 512 \
29 | --eval_batch_size 32 \
30 | \
31 | --preprocessing resize_small,crop,crop_patches,standardization \
32 | --resize_size 256,256 \
33 | --crop_size 192,192 \
34 | --smaller_size 224 \
35 | --patch_jitter 0 \
36 | --splits_per_side 3 \
37 | --pool_mode max \
38 | --combine_patches avg_pool \
39 | \
40 | --lr 0.1 \
41 | --decay_epochs 30,50 \
42 | --epochs 70 \
43 | --lr_scale_batch_size 256 \
44 | --hub_module ~/Downloads/module \
45 | "$@"
46 |
--------------------------------------------------------------------------------
/self_supervision/relative_patch_location.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Produces ratations for input images.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import tensorflow as tf
25 |
26 | from self_supervision import patch_utils
27 |
28 |
29 | def model_fn(data, mode):
30 | """Produces a loss for the relative patch location task.
31 |
32 | Args:
33 | data: Dict of inputs ("image" being the image)
34 | mode: model's mode: training, eval or prediction
35 |
36 | Returns:
37 | EstimatorSpec
38 | """
39 | images = data['image']
40 |
41 | # Patch locations
42 | perms, num_classes = patch_utils.generate_patch_locations()
43 | labels = tf.tile(list(range(num_classes)), tf.shape(images)[:1])
44 |
45 | return patch_utils.creates_estimator_model(
46 | images, labels, perms, num_classes, mode)
47 |
--------------------------------------------------------------------------------
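
The `tf.tile` call above lays out one label per (image, relative location) pair. A minimal numpy sketch of the same arithmetic, assuming the usual 8 relative locations around the center patch and an illustrative batch of 2 (neither value comes from the file above):

```python
# Numpy sketch of the label layout produced by
# tf.tile(list(range(num_classes)), tf.shape(images)[:1]).
import numpy as np

num_classes = 8   # assumed: 3x3 grid of patches minus the center
batch_size = 2    # illustrative batch size

labels = np.tile(np.arange(num_classes), batch_size)
print(labels)  # [0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7]
```
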
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Install self supervised learning package."""
18 |
19 | from setuptools import find_packages
20 | from setuptools import setup
21 |
22 | setup(
23 | name='self_supervised_learning',
24 | version='1.0',
25 | description=('Self Supervised Learning - code from "Revisiting '
26 | 'Self-Supervised Visual Representation Learning" paper'),
27 | author='Google LLC',
28 | author_email='no-reply@google.com',
29 | url='http://github.com/TODO',
30 | license='Apache 2.0',
31 | packages=find_packages(),
32 | package_data={
33 | },
34 | scripts=[
35 | ],
36 | install_requires=[
37 | 'future',
38 | 'numpy',
39 | 'absl-py',
40 | 'tensorflow',
41 | 'tensorflow-hub',
42 | ],
43 | classifiers=[
44 | 'Development Status :: 4 - Beta',
45 | 'Intended Audience :: Developers',
46 | 'Intended Audience :: Science/Research',
47 | 'License :: OSI Approved :: Apache Software License',
48 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
49 | ],
50 | keywords='tensorflow self supervised learning',
51 | )
52 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Helper functions for NN models.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import functools
25 | import absl.flags as flags
26 |
27 | import models.resnet
28 | import models.vggnet
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | def get_net(num_classes=None): # pylint: disable=missing-docstring
34 | architecture = FLAGS.architecture
35 |
36 | if 'vgg19' in architecture:
37 | net = functools.partial(
38 | models.vggnet.vgg19,
39 | filters_factor=FLAGS.get_flag_value('filters_factor', 8))
40 | else:
41 | if 'resnet50' in architecture:
42 | net = models.resnet.resnet50
43 | elif 'revnet50' in architecture:
44 | net = models.resnet.revnet50
45 | else:
46 | raise ValueError('Unsupported architecture: %s' % architecture)
47 |
48 | net = functools.partial(
49 | net,
50 | filters_factor=FLAGS.get_flag_value('filters_factor', 4),
51 | last_relu=FLAGS.get_flag_value('last_relu', True),
52 | mode=FLAGS.get_flag_value('mode', 'v2'))
53 |
54 | if FLAGS.task in ('jigsaw', 'relative_patch_location'):
55 | net = functools.partial(net, root_conv_stride=1, strides=(2, 2, 1))
56 |
57 | # A few settings that are common across all models.
58 | net = functools.partial(
59 | net, num_classes=num_classes,
60 | weight_decay=FLAGS.get_flag_value('weight_decay', 1e-4))
61 |
62 | return net
63 |
--------------------------------------------------------------------------------
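
The chain of `functools.partial` calls in `get_net` is easiest to see with a stand-in constructor. A self-contained sketch, where `dummy_resnet50` merely stands in for `models.resnet.resnet50`:

```python
# Sketch of get_net's partial-application layering with a dummy network.
import functools

def dummy_resnet50(x, is_training, num_classes=None, filters_factor=4,
                   last_relu=True, mode='v2', weight_decay=1e-4):
    # Stand-in: just report which settings would reach the real constructor.
    return 'resnet50(ff=%d, last_relu=%s, mode=%s, classes=%s, wd=%g)' % (
        filters_factor, last_relu, mode, num_classes, weight_decay)

# Stage 1: architecture-specific defaults (as in get_net).
net = functools.partial(dummy_resnet50, filters_factor=4, last_relu=True,
                        mode='v2')
# Stage 2: settings common across all models.
net = functools.partial(net, num_classes=1000, weight_decay=1e-4)

print(net(x=None, is_training=True))
# resnet50(ff=4, last_relu=True, mode=v2, classes=1000, wd=0.0001)
```

Later keyword arguments override the ones already bound by `functools.partial`, which is why a call site such as exemplar.py can still pass `weight_decay` explicitly without raising an error.
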
/self_supervision/jigsaw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Produces ratations for input images.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import tensorflow as tf
25 |
26 | from self_supervision import patch_utils
27 |
28 | FLAGS = tf.flags.FLAGS
29 |
30 |
31 | def model_fn(data, mode):
32 | """Produces a loss for the jigsaw task.
33 |
34 | Args:
35 | data: Dict of inputs ("image" being the image)
36 | mode: model's mode: training, eval or prediction
37 |
38 | Returns:
39 | EstimatorSpec
40 | """
41 | images = data['image']
42 |
43 | # Patch locations
44 | perms, num_classes = patch_utils.load_permutations()
45 | labels = list(range(num_classes))
46 |
47 | # Selects a subset of permutations for training. There are two methods:
48 | # 1. For each image, select 16 permutations independently.
49 | # 2. For each batch of images, select the same 16 permutations.
50 | # Here we use method 2, for simplicity.
51 | if mode in [tf.estimator.ModeKeys.TRAIN]:
52 | perm_subset_size = FLAGS.get_flag_value('perm_subset_size', 8)
53 | indexs = list(range(num_classes))
54 | indexs = tf.random_shuffle(indexs)
55 | labels = indexs[:perm_subset_size]
56 | perms = tf.gather(perms, labels, axis=0)
57 | tf.logging.info('subsample %s' % perms)
58 |
59 | labels = tf.tile(labels, tf.shape(images)[:1])
60 |
61 | return patch_utils.creates_estimator_model(
62 | images, labels, perms, num_classes, mode)
63 |
--------------------------------------------------------------------------------
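
Method 2 above draws one shared subset of permutations per batch and reuses it for every image. A self-contained numpy sketch of that subsampling (the class count, subset size, and batch size here are illustrative; the real permutations come from permutations_100_max.bin):

```python
# Numpy analogue of the batch-level permutation subsampling (method 2).
import numpy as np

num_classes = 100       # illustrative: total number of stored permutations
perm_subset_size = 8
batch_size = 4

perms = np.stack([np.random.permutation(9) for _ in range(num_classes)])

indexes = np.random.permutation(num_classes)  # like tf.random_shuffle
labels = indexes[:perm_subset_size]           # subset indices become labels
subset = perms[labels]                        # like tf.gather(perms, labels)

# Every image in the batch reuses the same labels, as in tf.tile(labels, ...).
batch_labels = np.tile(labels, batch_size)
print(subset.shape, batch_labels.shape)  # (8, 9) (32,)
```
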
/self_supervision/self_supervision_lib.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Generates training data with self supervision.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import tensorflow as tf
25 |
26 | from self_supervision import exemplar
27 | from self_supervision import jigsaw
28 | from self_supervision import linear_eval
29 | from self_supervision import relative_patch_location
30 | from self_supervision import rotation
31 | from self_supervision import supervised
32 |
33 |
34 | def get_self_supervision_model(self_supervision):
35 | """Gets self supervised training data and labels."""
36 |
37 | mapping = {
38 | "linear_eval": linear_eval.model_fn,
39 | "supervised": supervised.model_fn,
40 |
41 | "rotation": rotation.model_fn,
42 | "jigsaw": jigsaw.model_fn,
43 | "relative_patch_location": relative_patch_location.model_fn,
44 | "exemplar": exemplar.model_fn,
45 | }
46 |
47 | model_fn = mapping.get(self_supervision)
48 | if model_fn is None:
49 | raise ValueError("Unknown self-supervision: %s" % self_supervision)
50 |
51 | def _model_fn(features, labels, mode, params):
52 | """Returns the EstimatorSpec to run the model.
53 |
54 | Args:
55 | features: Dict of inputs ("image" being the image).
56 | labels: unused but required by Estimator API.
57 | mode: model's mode: training, eval or prediction
58 | params: required by Estimator API, contains TPU local `batch_size`.
59 |
60 | Returns:
61 | EstimatorSpec
62 |
63 | Raises:
64 | ValueError when the self_supervision is unknown.
65 | """
66 | del labels, params # unused
67 | tf.logging.info("Calling model_fn in mode %s with data:", mode)
68 | tf.logging.info(features)
69 | return model_fn(features, mode)
70 |
71 | return _model_fn
72 |
--------------------------------------------------------------------------------
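
A hedged sketch of how the returned `model_fn` plugs into an Estimator; the run config, model directory, and batch sizes below are illustrative placeholders, not this repo's actual wiring (see train_and_eval.py for that):

```python
# Illustrative wiring of get_self_supervision_model into a TPUEstimator.
import tensorflow as tf
from self_supervision import self_supervision_lib

model_fn = self_supervision_lib.get_self_supervision_model('rotation')

run_config = tf.contrib.tpu.RunConfig(model_dir='/tmp/workdir')
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=False,          # TPUEstimator is (almost) a no-op without a TPU
    train_batch_size=64,
    eval_batch_size=16)
```
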
/models/vggnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements VGG model.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import functools
25 |
26 | import tensorflow as tf
27 |
28 |
29 | def convbnrelu(x, filters, size, training, **convkw):
30 | x = tf.layers.conv2d(x, kernel_size=size, filters=filters,
31 | use_bias=False, **convkw)
32 | x = tf.layers.batch_normalization(x, fused=True, training=training)
33 | x = tf.nn.relu(x)
34 | return x
35 |
36 |
37 | def vgg19(x, is_training, num_classes=1000, # pylint: disable=missing-docstring
38 | filters_factor=8,
39 | weight_decay=5e-4):
40 | # NOTE: default weight_decay here is as in the VGGNet paper, which is
41 | # different from the ResNet and RevNet models.
42 | # NOTE: Another difference is that we are using BatchNorm, and because of
43 | # that, we are not using Dropout in the final FC layers.
44 |
45 | regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
46 | conv3 = functools.partial(convbnrelu, size=3, training=is_training,
47 | kernel_regularizer=regularizer, padding='same')
48 | fc = functools.partial(convbnrelu, training=is_training,
49 | kernel_regularizer=regularizer, padding='valid')
50 |
51 | end_points = {}
52 |
53 | # After long discussion, we settled on filters_factor=8 being the default and
54 | # thus needing to match the vanilla VGGNet, which starts with 64.
55 | w = 8 * filters_factor # w stands for width (a la wide-resnet)
56 | x = conv3(x, w)
57 | x = conv3(x, w)
58 | end_points['block1'] = x
59 | x = tf.layers.max_pooling2d(x, pool_size=2, strides=2) # 112x112
60 | x = conv3(x, 2*w)
61 | x = conv3(x, 2*w)
62 | end_points['block2'] = x
63 | x = tf.layers.max_pooling2d(x, pool_size=2, strides=2) # 56x56
64 | x = conv3(x, 4*w)
65 | x = conv3(x, 4*w)
66 | x = conv3(x, 4*w)
67 | x = conv3(x, 4*w)
68 | end_points['block3'] = x
69 | x = tf.layers.max_pooling2d(x, pool_size=2, strides=2) # 28x28
70 | x = conv3(x, 8*w)
71 | x = conv3(x, 8*w)
72 | x = conv3(x, 8*w)
73 | x = conv3(x, 8*w)
74 | end_points['block4'] = x
75 | x = tf.layers.max_pooling2d(x, pool_size=2, strides=2) # 14x14
76 | x = conv3(x, 8*w)
77 | x = conv3(x, 8*w)
78 | x = conv3(x, 8*w)
79 | x = conv3(x, 8*w)
80 | end_points['block5'] = x
81 |
82 | x = tf.layers.max_pooling2d(x, pool_size=2, strides=2) # 7x7
83 | x = fc(x, 512*filters_factor, x.get_shape().as_list()[-3:-1])
84 | end_points['fc6'] = x
85 | x = fc(x, 512*filters_factor, (1, 1))
86 |
87 | end_points['pre_logits'] = x
88 | x = tf.layers.conv2d(x, kernel_size=1, filters=num_classes,
89 | use_bias=True, kernel_regularizer=regularizer)
90 | x = tf.squeeze(x, [1, 2])
91 | end_points['logits'] = x
92 |
93 | return x, end_points
94 |
--------------------------------------------------------------------------------
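
The `fc` layers above implement fully connected layers as convolutions: a 'valid' convolution whose kernel spans the entire spatial extent collapses the feature map to 1x1. A self-contained numpy check of the equivalence (the shapes match vanilla VGG fc6 with filters_factor=8):

```python
# A conv kernel covering the whole 7x7 input is a single tensor contraction
# over height, width and channels -- exactly a dense layer.
import numpy as np

x = np.random.rand(1, 7, 7, 512).astype(np.float32)
kernel = np.random.rand(7, 7, 512, 4096).astype(np.float32)

out = np.einsum('bhwc,hwco->bo', x, kernel)
print(out.shape)  # (1, 4096)
```
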
/self_supervision/patch_model_preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # pylint: disable=missing-docstring
18 | """Preprocessing methods for self supervised representation learning.
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import functools
26 | import tensorflow as tf
27 |
28 | import utils
29 |
30 |
31 | def crop(image, is_training, crop_size):
32 | h, w, c = crop_size[0], crop_size[1], image.shape[-1]
33 |
34 | if is_training:
35 | return tf.random_crop(image, [h, w, c])
36 | else:
37 | # Central crop for now. (See Table 5 in Appendix of
38 | # https://arxiv.org/pdf/1703.07737.pdf for why)
39 | dy = (tf.shape(image)[0] - h)//2
40 | dx = (tf.shape(image)[1] - w)//2
41 | return tf.image.crop_to_bounding_box(image, dy, dx, h, w)
42 |
43 |
44 | def image_to_patches(image, is_training, split_per_side, patch_jitter=0):
45 | """Crops split_per_side x split_per_side patches from input image.
46 |
47 | Args:
48 | image: input image tensor with shape [h, w, c].
49 | is_training: is training flag.
50 | split_per_side: split of patches per image side.
51 | patch_jitter: jitter of each patch from each grid.
52 |
53 | Returns:
54 | Patches tensor with shape [patch_count, hc, wc, c].
55 | """
56 | h, w, _ = image.get_shape().as_list()
57 |
58 | h_grid = h // split_per_side
59 | w_grid = w // split_per_side
60 | h_patch = h_grid - patch_jitter
61 | w_patch = w_grid - patch_jitter
62 |
63 | tf.logging.info(
64 | "Crop patches - image size: (%d, %d), split_per_side: %d, "
65 | "grid_size: (%d, %d), patch_size: (%d, %d), split_jitter: %d",
66 | h, w, split_per_side, h_grid, w_grid, h_patch, w_patch, patch_jitter)
67 |
68 | patches = []
69 | for i in range(split_per_side):
70 | for j in range(split_per_side):
71 |
72 | p = tf.image.crop_to_bounding_box(image, i * h_grid, j * w_grid, h_grid,
73 | w_grid)
74 | # Trick: crop a smaller tile from each grid cell to avoid edge continuity.
75 | if h_patch < h_grid or w_patch < w_grid:
76 | p = crop(p, is_training, [h_patch, w_patch])
77 |
78 | patches.append(p)
79 |
80 | return tf.stack(patches)
81 |
82 |
83 | def get_crop_patches_fn(is_training, split_per_side, patch_jitter=0):
84 | """Gets a function which crops split_per_side x split_per_side patches.
85 |
86 | Args:
87 | is_training: is training flag.
88 | split_per_side: split of patches per image side.
89 | patch_jitter: jitter of each patch from each grid. E.g. 255x255 input
90 | image with split_per_side=3 will be split into a 3x3 grid of 85x85 cells,
91 | and patches are cropped from each cell with size (grid_size-patch_jitter,
92 | grid_size-patch_jitter).
93 |
94 | Returns:
95 | A function that returns the input data dict with its "image" tensor
96 | replaced by the split_per_side x split_per_side cropped patches.
97 | """
98 |
99 | def _crop_patches_pp(data):
100 | image = data["image"]
101 |
102 | image_to_patches_fn = functools.partial(
103 | image_to_patches,
104 | is_training=is_training,
105 | split_per_side=split_per_side,
106 | patch_jitter=patch_jitter)
107 | image = utils.tf_apply_to_image_or_images(image_to_patches_fn, image)
108 |
109 | data["image"] = image
110 | return data
111 | return _crop_patches_pp
112 |
113 |
--------------------------------------------------------------------------------
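
The grid arithmetic in `image_to_patches` is worth tracing with the jigsaw config values (resize to 292, crop to 255, split_per_side=3, patch_jitter=21), since they explain the 64x64x3 serving input shape those models use:

```python
# Worked example of the patch-size arithmetic in image_to_patches.
h = w = 255            # crop_size from config/jigsaw/imagenet.sh
split_per_side = 3
patch_jitter = 21

h_grid = h // split_per_side       # 85
h_patch = h_grid - patch_jitter    # 64 -> serving_input_shape None,64,64,3
print(h_grid, h_patch)             # 85 64
```
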
/self_supervision/rotation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Produces ratations for input images.
18 |
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import functools
26 | import tensorflow as tf
27 | import tensorflow_hub as hub
28 |
29 | from models.utils import get_net
30 | import trainer
31 | import utils
32 |
33 |
34 | FLAGS = tf.flags.FLAGS
35 |
36 |
37 | def apply_model(image_fn, # pylint: disable=missing-docstring
38 | is_training,
39 | num_outputs,
40 | make_signature=False):
41 |
42 | # Image tensor needs to be created lazily in order to satisfy tf-hub
43 | # restriction: all tensors should be created inside tf-hub helper function.
44 | images = image_fn()
45 |
46 | net = get_net(num_classes=num_outputs)
47 |
48 | output, end_points = net(images, is_training)
49 |
50 | if make_signature:
51 | hub.add_signature(inputs={'image': images}, outputs=output)
52 | hub.add_signature(
53 | name='representation',
54 | inputs={'image': images},
55 | outputs=end_points)
56 | return output
57 |
58 |
59 | def model_fn(data, mode):
60 | """Produces a loss for the rotation task.
61 |
62 | Args:
63 | data: Dict of inputs containing, among others, "image" and "label."
64 | mode: model's mode: training, eval or prediction
65 |
66 | Returns:
67 | EstimatorSpec
68 | """
69 | num_angles = 4
70 | images = data['image']
71 |
72 | if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
73 | images = tf.reshape(images, [-1] + images.get_shape().as_list()[-3:])
74 | with tf.variable_scope('module'):
75 | image_fn = lambda: images
76 | logits = apply_model(
77 | image_fn=image_fn,
78 | is_training=(mode == tf.estimator.ModeKeys.TRAIN),
79 | num_outputs=num_angles,
80 | make_signature=False)
81 | else:
82 | input_shape = utils.str2intlist(
83 | FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
84 | image_fn = lambda: tf.placeholder(shape=input_shape, # pylint: disable=g-long-lambda
85 | dtype=tf.float32)
86 | apply_model_function = functools.partial(
87 | apply_model,
88 | image_fn=image_fn,
89 | num_outputs=num_angles,
90 | make_signature=True)
91 | tf_hub_module_spec = hub.create_module_spec(apply_model_function,
92 | [(utils.TAGS_IS_TRAINING, {
93 | 'is_training': True
94 | }),
95 | (set(), {
96 | 'is_training': False
97 | })])
98 | tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
99 | hub.register_module_for_export(tf_hub_module, export_name='module')
100 | logits = tf_hub_module(images)
101 |
102 | return trainer.make_estimator(mode, predictions=logits)
103 |
104 | labels = tf.reshape(data['label'], [-1])
105 |
106 | # build loss and accuracy
107 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
108 | labels=labels, logits=logits)
109 | loss = tf.reduce_mean(loss)
110 |
111 | eval_metrics = (
112 | lambda labels, logits: { # pylint: disable=g-long-lambda
113 | 'accuracy': tf.metrics.accuracy(
114 | labels=labels,
115 | predictions=tf.argmax(logits, axis=-1))},
116 | [labels, logits])
117 | return trainer.make_estimator(mode, loss, eval_metrics, logits)
118 |
--------------------------------------------------------------------------------
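
For context, the `rotate` preprocessing step configured in config/rotation/imagenet.sh turns each image into four copies rotated by 0/90/180/270 degrees with labels 0..3, which `model_fn` then flattens into a [4*batch] classification problem. A self-contained numpy sketch of that expansion (not the repo's preprocessing code, which is not shown in this dump):

```python
# Numpy sketch of the four-way rotation expansion assumed by model_fn.
import numpy as np

def rotate4(image):
    return np.stack([np.rot90(image, k) for k in range(4)]), np.arange(4)

image = np.zeros([224, 224, 3], np.float32)
rotated, labels = rotate4(image)
print(rotated.shape, labels)  # (4, 224, 224, 3) [0 1 2 3]
```
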
/self_supervision/supervised.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements fully-supervised model.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import functools
25 | import tensorflow as tf
26 | import tensorflow_hub as hub
27 |
28 | import datasets
29 | from models.utils import get_net
30 | import trainer
31 | import utils
32 |
33 | FLAGS = tf.flags.FLAGS
34 |
35 |
36 | def apply_model(image_fn, # pylint: disable=missing-docstring
37 | is_training,
38 | num_outputs,
39 | make_signature=False):
40 |
41 | # Image tensor needs to be created lazily in order to satisfy tf-hub
42 | # restriction: all tensors should be created inside tf-hub helper function.
43 | images = image_fn()
44 |
45 | net = get_net(num_classes=num_outputs)
46 |
47 | output, end_points = net(images, is_training)
48 |
49 | if make_signature:
50 | hub.add_signature(inputs={'image': images}, outputs=output)
51 | hub.add_signature(inputs={'image': images}, outputs=end_points,
52 | name='representation')
53 | return output
54 |
55 |
56 | def model_fn(data, mode):
57 | """Produces a loss for the fully-supervised task.
58 |
59 | Args:
60 | data: Dict of inputs containing, among others, "image" and "label."
61 | mode: model's mode: training, eval or prediction
62 |
63 | Returns:
64 | EstimatorSpec
65 | """
66 | images = data['image']
67 |
68 | # In predict mode (called once at the end of training), we only instantiate
69 | # the model in order to export a tf.hub module for it.
70 | # This will then make saving and loading much easier down the line.
71 | if mode == tf.estimator.ModeKeys.PREDICT:
72 | input_shape = utils.str2intlist(
73 | FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
74 | apply_model_function = functools.partial(
75 | apply_model,
76 | image_fn=lambda: tf.placeholder(shape=input_shape, dtype=tf.float32), # pylint: disable=g-long-lambda
77 | num_outputs=datasets.get_num_classes(),
78 | make_signature=True)
79 | tf_hub_module_spec = hub.create_module_spec(
80 | apply_model_function,
81 | [(utils.TAGS_IS_TRAINING, {'is_training': True}),
82 | (set(), {'is_training': False})])
83 | tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
84 | hub.register_module_for_export(tf_hub_module, export_name='module')
85 | logits = tf_hub_module(images)
86 |
87 | # There is no training happening anymore, only prediction and model export.
88 | return trainer.make_estimator(mode, predictions=logits)
89 |
90 | # From here on, we are either in train or eval modes.
91 | # Create the model in the 'module' name scope so it matches nicely with
92 | # tf.hub's requirements for import/export later.
93 | with tf.variable_scope('module'):
94 | logits = apply_model(
95 | image_fn=lambda: images,
96 | is_training=(mode == tf.estimator.ModeKeys.TRAIN),
97 | num_outputs=datasets.get_num_classes(),
98 | make_signature=False)
99 |
100 | labels = data['label']
101 |
102 | # build loss and accuracy
103 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
104 | labels=labels, logits=logits)
105 | loss = tf.reduce_mean(loss)
106 |
107 | # Gets a metric_fn which evaluates the "top1_accuracy" and "top5_accuracy".
108 | # The resulting metrics are named "top1_accuracy_{tensor_name}",
109 | # "top5_accuracy_{tensor_name}".
110 | metrics_fn = utils.get_classification_metrics(['logits'])
111 | # A tuple of metric_fn and a list of tensors to be evaluated by TPUEstimator.
112 | eval_metrics_tuple = (metrics_fn, [labels, logits])
113 |
114 | return trainer.make_estimator(mode, loss, eval_metrics_tuple)
115 |
--------------------------------------------------------------------------------
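
`get_classification_metrics` lives in utils.py (not shown in this dump); per the comments above, it reports top-1 and top-5 accuracy. A self-contained numpy sketch of a top-k accuracy metric, with made-up logits:

```python
# Numpy sketch of top-k accuracy as reported by the classification metrics.
import numpy as np

def topk_accuracy(labels, logits, k):
    topk = np.argsort(logits, axis=-1)[:, -k:]   # indices of the k largest
    return np.mean([l in row for l, row in zip(labels, topk)])

logits = np.array([[0.1, 0.5, 0.4],
                   [0.9, 0.05, 0.05]])
labels = np.array([2, 0])
print(topk_accuracy(labels, logits, k=1))  # 0.5
print(topk_accuracy(labels, logits, k=2))  # 1.0
```
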
/self_supervision/exemplar.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Exemplar implementation with triplet semihard loss."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import functools
24 | import tensorflow as tf
25 | import tensorflow_hub as hub
26 |
27 | import utils
28 | from models.utils import get_net
29 | from trainer import make_estimator
30 |
31 | FLAGS = tf.flags.FLAGS
32 |
33 |
34 | def apply_model(image_fn,
35 | is_training,
36 | num_outputs,
37 | make_signature=False):
38 | """Creates the patch based model output from patches representations.
39 |
40 | Args:
41 | image_fn: function returns image tensor.
42 | is_training: is training flag used for batch norm and drop out.
43 | num_outputs: number of output classes.
44 | make_signature: whether to create signature for hub module.
45 |
46 |
47 | Returns:
48 | out: output tensor with shape [n*m, 1, 1, num_outputs].
49 |
50 | Raises:
51 | ValueError: An error occurred when the architecture is unknown.
52 | """
53 | # Image tensor needs to be created lazily in order to satisfy tf-hub
54 | # restriction: all tensors should be created inside tf-hub helper function.
55 | images = image_fn()
56 |
57 | net = get_net(num_classes=num_outputs)
58 | out, end_points = net(images, is_training,
59 | weight_decay=FLAGS.get_flag_value('weight_decay', 1e-4))
60 |
61 | print(end_points)
62 |
63 | if len(out.get_shape().as_list()) == 4:
64 | out = tf.squeeze(out, [1, 2])
65 |
66 | if make_signature:
67 | hub.add_signature(inputs={'image': images}, outputs=out)
68 | hub.add_signature(
69 | name='representation',
70 | inputs={'image': images},
71 | outputs=end_points)
72 | return out
73 |
74 |
75 | def repeat(x, times):
76 | """Exactly like np.repeat."""
77 | return tf.reshape(tf.tile(tf.expand_dims(x, -1), [1, times]), [-1])
78 |
79 |
80 | def model_fn(data, mode):
81 | """Produces a loss for the exemplar task supervision.
82 |
83 | Args:
84 | data: Dict of inputs containing, among others, "image" and "label."
85 | mode: model's mode: training, eval or prediction
86 |
87 | Returns:
88 | EstimatorSpec
89 | """
90 | images = data['image']
91 | batch_size = tf.shape(images)[0]
92 | print(' +++ Mode: %s, data: %s' % (mode, data))
93 |
94 | embed_dim = FLAGS.embed_dim
95 | patch_count = images.get_shape().as_list()[1]
96 |
97 | images = tf.reshape(
98 | images, shape=[-1] + images.get_shape().as_list()[-3:])
99 |
100 | if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
101 | images = tf.reshape(images, [-1] + images.get_shape().as_list()[-3:])
102 | with tf.variable_scope('module'):
103 | image_fn = lambda: images
104 | logits = apply_model(
105 | image_fn=image_fn,
106 | is_training=(mode == tf.estimator.ModeKeys.TRAIN),
107 | num_outputs=embed_dim,
108 | make_signature=False)
109 | else:
110 | input_shape = utils.str2intlist(
111 | FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
112 | image_fn = lambda: tf.placeholder(shape=input_shape, # pylint: disable=g-long-lambda
113 | dtype=tf.float32)
114 | apply_model_function = functools.partial(
115 | apply_model,
116 | image_fn=image_fn,
117 | num_outputs=embed_dim,
118 | make_signature=True)
119 |
120 | tf_hub_module_spec = hub.create_module_spec(apply_model_function,
121 | [(utils.TAGS_IS_TRAINING, {
122 | 'is_training': True
123 | }),
124 | (set(), {
125 | 'is_training': False
126 | })],
127 | drop_collections=['summaries'])
128 | tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
129 | hub.register_module_for_export(tf_hub_module, export_name='module')
130 | logits = tf_hub_module(images)
131 | return make_estimator(mode, predictions=logits)
132 |
133 | labels = repeat(tf.range(batch_size), patch_count)
134 | norm_logits = tf.nn.l2_normalize(logits, axis=0)
135 | loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
136 | labels, norm_logits, margin=FLAGS.margin)
137 |
138 | return make_estimator(mode, loss, predictions=logits)
139 |
--------------------------------------------------------------------------------
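
`repeat` above is documented as behaving exactly like `np.repeat`, which also shows the label structure fed to the triplet loss: every patch of one source image shares a class, so the loss pulls an image's patches together and pushes different images apart. A quick check:

```python
# np.repeat produces the exemplar labels: one shared class per source image.
import numpy as np

batch_size, patch_count = 3, 4
labels = np.repeat(np.arange(batch_size), patch_count)
print(labels)  # [0 0 0 0 1 1 1 1 2 2 2 2]
```
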
/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Base trainer class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow as tf
24 |
25 | import datasets
26 | import utils
27 |
28 | FLAGS = tf.flags.FLAGS
29 |
30 |
31 | def get_lr(global_step, base_lr, steps_per_epoch, # pylint: disable=missing-docstring
32 | decay_epochs, lr_decay_factor, warmup_epochs):
33 |
34 | warmup_lr = 0.0
35 | if warmup_epochs > 0:
36 | warmup_lr = (tf.cast(global_step, tf.float32) *
37 | (base_lr / (warmup_epochs * steps_per_epoch)))
38 |
39 | normal_lr = tf.train.piecewise_constant(
40 | global_step,
41 | [e * steps_per_epoch for e in decay_epochs],
42 | [base_lr * (lr_decay_factor ** i) for i in range(len(decay_epochs) + 1)]
43 | )
44 |
45 | lr = tf.cond(tf.less(global_step, warmup_epochs * steps_per_epoch),
46 | lambda: warmup_lr,
47 | lambda: normal_lr)
48 |
49 | return lr
50 |
51 |
52 | # TODO(akolesnikov): add more logging
53 | class Trainer(object):
54 | """Base trainer class."""
55 |
56 | def __init__(self,
57 | update_batchnorm_params=True):
58 | self.update_batchnorm_params = update_batchnorm_params
59 |
60 | split = FLAGS.get_flag_value('train_split', 'train')
61 | num_samples = datasets.get_count(split)
62 | steps_per_epoch = num_samples // FLAGS.batch_size
63 |
64 | global_step = tf.train.get_or_create_global_step()
65 | self.global_step_inc = tf.assign_add(global_step, 1)
66 |
67 | # lr_scale_batch_size defines a canonical batch size that is coupled with
68 | # the initial learning rate. If the actual batch size differs from the
69 | # canonical one, the learning rate is scaled linearly. This is convenient,
70 | # as it allows varying the batch size without recomputing the learning rate.
71 | lr_factor = 1.0
72 | if FLAGS.get_flag_value('lr_scale_batch_size', 0):
73 | lr_factor = FLAGS.batch_size / float(FLAGS.lr_scale_batch_size)
74 |
75 | deps = FLAGS.get_flag_value('decay_epochs', None)
76 | decay_epochs = utils.str2intlist(deps) if deps else [FLAGS.epochs]
77 |
78 | self.lr = get_lr(
79 | global_step,
80 | base_lr=FLAGS.lr * lr_factor,
81 | steps_per_epoch=steps_per_epoch,
82 | decay_epochs=decay_epochs,
83 | lr_decay_factor=FLAGS.get_flag_value('lr_decay_factor', 0.1),
84 | warmup_epochs=FLAGS.get_flag_value('warmup_epochs', 0))
85 |
86 | # TODO(marvinritter): Re-enable summaries with support for TPU training.
87 | # tf.summary.scalar('learning_rate', self.lr)
88 |
89 | def get_train_op(self, loss, # pylint: disable=missing-docstring
90 | var_list=None,
91 | add_reg_loss=True,
92 | use_tpu=False):
93 |
94 | if add_reg_loss:
95 | l2_loss = tf.reduce_sum(tf.losses.get_regularization_losses())
96 | loss += l2_loss
97 |
98 | optimizer = FLAGS.get_flag_value('optimizer', 'sgd')
99 | if optimizer == 'sgd':
100 | optimizer = tf.train.MomentumOptimizer(learning_rate=self.lr,
101 | momentum=0.9)
102 | elif optimizer == 'adam':
103 | optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
104 | else:
105 | raise ValueError('Unknown optimizer: %s' % optimizer)
106 |
107 | if use_tpu:
108 | # Wrap optimizer in CrossShardOptimizer which takes care of
109 | # synchronizing the weight updates between TPU cores.
110 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
111 |
112 | opt_step = optimizer.minimize(loss, var_list=var_list,
113 | colocate_gradients_with_ops=True)
114 |
115 | if self.update_batchnorm_params:
116 | opt_step = tf.group([opt_step] +
117 | tf.get_collection(tf.GraphKeys.UPDATE_OPS))
118 |
119 | opt_step = tf.group([opt_step, self.global_step_inc])
120 |
121 | return opt_step
122 |
123 |
124 | def make_estimator(mode, loss=None, eval_metrics=None, predictions=None):
125 | """Returns an EstimatorSpec (maybe TPU) for all modes."""
126 |
127 | # Always use TPUEstimator; even when not using a TPU it is (almost) a no-op.
128 | spec_type = tf.contrib.tpu.TPUEstimatorSpec
129 |
130 | if mode == tf.estimator.ModeKeys.PREDICT:
131 | assert predictions is not None, 'Need to pass `predictions` arg.'
132 | return spec_type(mode=mode, predictions=predictions)
133 |
134 | if mode == tf.estimator.ModeKeys.EVAL:
135 | return spec_type(mode=mode, loss=loss, eval_metrics=eval_metrics)
136 |
137 | if mode == tf.estimator.ModeKeys.TRAIN:
138 | assert loss is not None, 'Need to pass `loss` arg.'
139 | trainer = Trainer(update_batchnorm_params=True)
140 | train_op = trainer.get_train_op(loss, use_tpu=FLAGS.use_tpu)
141 | return spec_type(mode=mode, loss=loss, train_op=train_op)
142 |
143 | raise ValueError('Unsupported mode %s' % mode)
144 |
--------------------------------------------------------------------------------
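
The schedule in `get_lr` is linear warmup followed by piecewise-constant decay. A plain-Python approximation with the rotation config's values (it ignores the boundary-inclusion edge case of `tf.train.piecewise_constant`; `steps_per_epoch=100` is illustrative, whereas `Trainer` derives it from the dataset size and batch size):

```python
# Plain-Python approximation of get_lr's warmup + piecewise decay schedule.
def lr_at(step, base_lr=0.1, steps_per_epoch=100,
          decay_epochs=(15, 25), lr_decay_factor=0.1, warmup_epochs=5):
    if step < warmup_epochs * steps_per_epoch:
        return step * base_lr / (warmup_epochs * steps_per_epoch)
    boundaries = [e * steps_per_epoch for e in decay_epochs]
    i = sum(step >= b for b in boundaries)
    return base_lr * (lr_decay_factor ** i)

print(lr_at(250))   # mid-warmup: 0.05
print(lr_at(1000))  # after warmup, before any decay: 0.1
print(lr_at(2000))  # past the first boundary (epoch 15): ~0.01
print(lr_at(3000))  # past the second boundary (epoch 25): ~0.001
```
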
/README.md:
--------------------------------------------------------------------------------
1 | # Revisiting self-supervised visual representation learning
2 |
3 | Tensorflow implementation of experiments from
4 | [our paper on unsupervised visual representation learning](http://arxiv.org/abs/1901.09005).
5 |
6 | If you find this repository useful in your research, please consider citing:
7 |
8 | ```
9 | @inproceedings{kolesnikov2019revisiting,
10 | title={Revisiting self-supervised visual representation learning},
11 | author={Kolesnikov, Alexander and Zhai, Xiaohua and Beyer, Lucas},
12 | booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
13 | month={June},
14 | year={2019}
15 | }
16 | ```
17 |
18 | ## Overview
19 |
20 | This codebase allows reproducing the core experiments from our paper. It contains
21 | our re-implementation of four self-supervised representation learning
22 | techniques, utility code for running training and evaluation loops (including on
23 | TPUs) and an implementation of standard CNN models, such as ResNet v1, ResNet v2
24 | and VGG19.
25 |
26 | Specifically, we provide a re-implementation of the following self-supervised
27 | representation learning techniques:
28 |
29 | 1. [Unsupervised Representation Learning by Predicting Image Rotations](https://arxiv.org/abs/1803.07728)
30 | 2. [Unsupervised Visual Representation Learning by Context Prediction](https://arxiv.org/abs/1505.05192)
31 | 3. [Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles](https://arxiv.org/abs/1603.09246)
32 | 4. [Discriminative Unsupervised Feature Learning with Exemplar Convolutional
33 | Neural Networks](https://arxiv.org/abs/1406.6909)
34 |
35 | ## Usage instructions
36 |
37 | In the paper we train self-supervised models using 32 or 128 TPU cores. We
38 | evaluate the resulting representations by training a logistic regression model
39 | on 32 TPU cores.
40 |
41 | In this codebase we provide configurations for training/evaluation of our models
42 | using an 8-TPU-core setup, as this is more affordable for public TPU users
43 | through the Google Cloud API. These configurations produce results close to those
44 | reported in the paper, which used more TPU chips.
45 |
46 | For debugging or running small experiments we also support training and
47 | evaluation using a single GPU device.
48 |
49 | ### Preparing data
50 |
51 | Please refer to the
52 | [instructions in the slim library](https://github.com/tensorflow/models/blob/master/research/inception/README.md#getting-started)
53 | for downloading and preprocessing ImageNet data.
54 |
55 | ### Clone the repository and install dependencies
56 |
57 | ```
58 | git clone https://github.com/google/revisiting-self-supervised
59 | cd revisiting-self-supervised
60 | python -m pip install -e . --user
61 | ```
62 |
63 | We depend on some external files that need to be downloaded and placed in the
64 | root repository folder. You can run the following commands to download them:
65 |
66 | ```
67 | wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/preprocessing/inception_preprocessing.py
68 | wget https://github.com/MehdiNoroozi/JigsawPuzzleSolver/raw/master/permutations_100_max.bin
69 | ```
70 |
71 | ### Running locally on a single GPU
72 |
73 | Run any experiment by running the corresponding shell script with the following
74 | options, here exemplified for the fully supervised experiment:
75 |
76 | ```
77 | ./config/supervised/imagenet.sh \
78 | --workdir <workdir> \
79 | --nouse_tpu \
80 | --master='' \
81 | --dataset_dir <dataset dir>
82 | ```
83 |
84 | ### Running on Google Cloud using TPUs
85 |
86 | #### Step 1:
87 |
88 | Create your own TPU cloud instance by following the
89 | [official documentation](https://cloud.google.com/tpu/).
90 |
91 | #### Step 2:
92 |
93 | Clone the repository and install dependencies as described above.
94 |
95 | #### Step 3:
96 |
97 | Run the self supervised model training script with TPUs. For example:
98 |
99 | ```
100 | gsutil mb gs://<bucket>
101 | export TPU_NAME=<tpu name>
102 | config/supervised/imagenet.sh --workdir gs://<bucket>/<workdir> --dataset_dir gs://<dataset dir>
103 | ```
104 |
105 | After/during training, run the self supervised model evaluation script with
106 | TPUs. It computes the loss and metrics on the validation set, and exports a hub
107 | module under the directory `gs://<workdir>/export/hub/<timestamp>/module`:
108 |
109 | ```
110 | config/supervised/imagenet.sh --workdir gs://<bucket>/<workdir> --dataset_dir gs://<dataset dir> --run_eval
111 | ```
112 |
113 | Note that `<tpu-name>` is set by the user when creating the Cloud TPU
114 | node. Moreover, the ImageNet data and the working directory must be placed in a
115 | Google Cloud Storage bucket.
116 |
117 | #### Step 4:
118 |
119 | Evaluate the self-supervised models with logistic regression. You need to pass
120 | the hub module exported in step 3 above as an additional argument:
121 |
122 | ```
123 | gsutil mb gs://<eval-workdir>
124 | export TPU_NAME=<tpu-name>
125 | config/evaluation/rotation_or_exemplar.sh --workdir gs://<eval-workdir> --dataset_dir gs://<dataset-bucket> --hub_module gs://<workdir>/export/hub/<timestamp>/module
126 |
127 | config/evaluation/rotation_or_exemplar.sh --workdir gs://<eval-workdir> --dataset_dir gs://<dataset-bucket> --hub_module gs://<workdir>/export/hub/<timestamp>/module --run_eval
128 | ```
129 |
130 | You can start TensorBoard to visualize the training and evaluation progress:
131 |
132 | ```
133 | tensorboard --port 2222 --logdir gs://<workdir>
134 | ```
135 |
136 | ## Pretrained models
137 |
138 | If you want to download and try our best self-supervised models, please see this
139 | [IPython notebook](https://colab.research.google.com/drive/1HdApkScZpulQrACrPKZiKYHhy7MeR3iN).
140 |
141 |
142 | ## Authors
143 |
144 | - [Alexander Kolesnikov](https://github.com/kolesman)
145 | - [Xiaohua Zhai](https://sites.google.com/site/xzhai89/)
146 | - [Lucas Beyer](http://lucasb.eyer.be/)
147 | - [Marvin Ritter](https://github.com/Marvin182)
148 |
149 | ### This is not an official Google product
150 |
--------------------------------------------------------------------------------
/self_supervision/linear_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements fully-supervised model.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import collections
25 | import functools
26 | import os
27 |
28 | import numpy as np
29 | import tensorflow as tf
30 | import tensorflow_hub as hub
31 |
32 | import datasets
33 | from self_supervision.patch_utils import get_patch_representation
34 | from trainer import make_estimator
35 | import utils
36 |
37 | FLAGS = tf.flags.FLAGS
38 |
39 |
40 | def apply_fractional_pooling(taps, target_features=9000, mode='adaptive_max'):
41 | """Applies fractional pooling to each of `taps`.
42 |
43 | Args:
44 | taps: A dict of names:tensors to which to attach the head.
45 | target_features: If the input tensor has more than this number of features,
46 | perform fractional pooling to reduce it to this amount.
47 | mode: one of {'adaptive_max', 'adaptive_avg', 'max', 'avg'}
48 |
49 | Returns:
50 |     tensors: An ordered dict of tensors, each with at most target_features dimensions.
51 |
52 | Raises:
53 | ValueError: mode is unexpected.
54 | """
55 | out_tensors = collections.OrderedDict()
56 | for k, t in sorted(taps.items()):
57 | if len(t.get_shape().as_list()) == 2:
58 | t = t[:, None, None, :]
59 | _, h, w, num_channels = t.get_shape().as_list()
60 | if h * w * num_channels > target_features:
61 | t = utils.adaptive_pool(t, target_features, mode)
62 | _, h, w, num_channels = t.get_shape().as_list()
63 | out_tensors[k] = t
64 |
65 | return out_tensors
66 |
67 |
68 | def add_linear_heads(rep_tensors, n_out):
69 | """Adds a linear head to each of rep_tensors.
70 |
71 | Args:
72 | rep_tensors: A dict of names:tensors to which to attach the head.
73 | n_out: The number of features the head should map to.
74 |
75 | Returns:
76 |     tensors: An ordered dict like `rep_tensors` but with the head's output as value.
77 | """
78 | for k in list(rep_tensors.keys()):
79 | t = rep_tensors[k]
80 | t = tf.reshape(
81 | t, tf.stack([-1, 1, 1, tf.reduce_prod(t.shape[1:])]))
82 | t = tf.layers.conv2d(
83 | t,
84 | filters=n_out,
85 | kernel_size=1,
86 | padding='valid',
87 | activation=None,
88 | kernel_regularizer=tf.contrib.layers.l2_regularizer(
89 | scale=FLAGS.get_flag_value('weight_decay', 0.)))
90 | rep_tensors[k] = tf.squeeze(t, [1, 2])
91 |
92 | return rep_tensors
93 |
94 |
95 | def add_mlp_heads(rep_tensors, n_out, is_training):
96 | """Adds a mlp head to each of rep_tensors.
97 |
98 | Args:
99 | rep_tensors: A dict of names:tensors to which to attach the head.
100 | n_out: The number of features the head should map to.
101 | is_training: whether in training mode.
102 |
103 | Returns:
104 |     tensors: An ordered dict like `rep_tensors` but with the head's output as value.
105 | """
106 | kernel_regularizer = tf.contrib.layers.l2_regularizer(
107 | scale=FLAGS.get_flag_value('weight_decay', 0.))
108 | channels_hidden = FLAGS.get_flag_value('channels_hidden', 1000)
109 |   for k, t in rep_tensors.items():
110 | t = tf.reshape(t, [-1, 1, 1, np.prod(t.shape[1:])])
111 | t = tf.layers.conv2d(
112 | t, channels_hidden, kernel_size=1, padding='VALID',
113 | activation=tf.nn.relu, kernel_regularizer=kernel_regularizer)
114 | t = tf.layers.dropout(t, rate=0.5, training=is_training)
115 | t = tf.layers.conv2d(
116 | t, n_out, kernel_size=1, padding='VALID',
117 | kernel_regularizer=kernel_regularizer)
118 | rep_tensors[k] = tf.squeeze(t, [1, 2])
119 |
120 | return rep_tensors
121 |
122 |
123 | def model_fn(data, mode):
124 | """Produces a loss for the fully-supervised task.
125 |
126 | Args:
127 | data: Dict of inputs containing, among others, "image" and "label."
128 | mode: model's mode: training, eval or prediction
129 |
130 | Returns:
131 | EstimatorSpec
132 |
133 | Raises:
134 | ValueError: Unexpected FLAGS.eval_model
135 | """
136 | images = data['image']
137 |   tf.logging.info('model_fn(): features=%s, mode=%s', images, mode)
138 |
139 | input_rank = len(images.get_shape().as_list())
140 | image_input_rank = 4 # NHWC
141 | patch_input_rank = 5 # NPHWC
142 | assert input_rank in [image_input_rank, patch_input_rank], (
143 | 'Unsupported input rank: %d' % input_rank)
144 |
145 | module = hub.Module(os.path.expanduser(str(FLAGS.hub_module)))
146 |
147 | if mode == tf.estimator.ModeKeys.PREDICT:
148 | return make_estimator(
149 | mode,
150 | predictions=module(
151 | images,
152 | signature=FLAGS.get_flag_value('signature', 'representation'),
153 | as_dict=True))
154 |
155 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
156 | pooling_fn = functools.partial(
157 | apply_fractional_pooling,
158 | mode=FLAGS.get_flag_value('pool_mode', 'adaptive_max'))
159 | target_features = 9000
160 | if input_rank == patch_input_rank:
161 | out_tensors = get_patch_representation(
162 | images,
163 | module,
164 | patch_preprocess=None,
165 | is_training=is_training,
166 | target_features=target_features,
167 | combine_patches=FLAGS.get_flag_value('combine_patches', 'avg_pool'),
168 | signature='representation',
169 | pooling_fn=pooling_fn)
170 | else:
171 | out_tensors = module(
172 | images,
173 | signature=FLAGS.get_flag_value('signature', 'representation'),
174 | as_dict=True)
175 | out_tensors = pooling_fn(out_tensors, target_features=target_features)
176 |
177 | eval_model = FLAGS.get_flag_value('eval_model', 'linear')
178 | if eval_model == 'linear':
179 | out_logits = add_linear_heads(out_tensors, datasets.get_num_classes())
180 | elif eval_model == 'mlp':
181 | out_logits = add_mlp_heads(out_tensors, datasets.get_num_classes(),
182 | is_training=is_training)
183 | else:
184 |     raise ValueError('Unsupported eval_model: %s.' % eval_model)
185 |
186 | # build loss and accuracy
187 | labels = data['label']
188 | losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
189 | logits=logits)
190 | for logits in out_logits.values()]
191 | loss = tf.add_n([tf.reduce_mean(loss) for loss in losses])
192 |
193 | metrics_fn = utils.get_classification_metrics(
194 | tensor_names=out_logits.keys())
195 | # A tuple of metric_fn and a list of tensors to be evaluated by TPUEstimator.
196 | eval_metrics_tuple = (metrics_fn, [labels] + list(out_logits.values()))
197 |
198 | return make_estimator(mode, loss, eval_metrics_tuple)
199 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Util functions for representation learning.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import collections
25 | import csv
26 | import re
27 |
28 | import numpy as np
29 | import tensorflow as tf
30 |
31 |
32 | INPUT_DATA_STR = "input_data"
33 | IS_TRAINING_STR = "is_training"
34 | REPR_PREFIX_STR = "representation_"
35 | TAGS_IS_TRAINING = ["is_training"]
36 |
37 |
38 | def adaptive_pool(inp, num_target_dimensions=9000, mode="adaptive_max"):
39 | """Adaptive pooling layer.
40 |
41 |   This layer performs adaptive pooling, such that the total
42 |   dimensionality of the output is not bigger than num_target_dimensions.
43 |
44 | Args:
45 | inp: input tensor
46 | num_target_dimensions: maximum number of output dimensions
47 | mode: one of {"adaptive_max", "adaptive_avg", "max", "avg"}
48 |
49 | Returns:
50 | Result of the pooling operation
51 |
52 | Raises:
53 | ValueError: mode is unexpected.
54 | """
55 |
56 | size, _, k = inp.get_shape().as_list()[1:]
57 | if mode in ["adaptive_max", "adaptive_avg"]:
58 | if mode == "adaptive_max":
59 | pool_fn = tf.nn.fractional_max_pool
60 | else:
61 | pool_fn = tf.nn.fractional_avg_pool
62 |
63 | # Find the optimal target output tensor size
64 | target_size = (num_target_dimensions / float(k)) ** 0.5
65 | if (abs(num_target_dimensions - k * np.floor(target_size) ** 2) <
66 | abs(num_target_dimensions - k * np.ceil(target_size) ** 2)):
67 | target_size = max(np.floor(target_size), 1.0)
68 | else:
69 | target_size = max(np.ceil(target_size), 1.0)
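70 |     # Worked example (hypothetical input, not from the original): for a
71 |     # [B, 7, 7, 2048] tensor and the default target of 9000 dimensions,
72 |     # target_size = sqrt(9000/2048) ~ 2.1 is rounded down to 2, since
73 |     # 2048*2*2 = 8192 is closer to 9000 than 2048*3*3 = 18432, so the
74 |     # pooled output is [B, 2, 2, 2048].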
70 |
71 | # Get optimal stride. Subtract epsilon to ensure correct rounding in
72 | # pool_fn.
73 | stride = size / target_size - 1.0e-5
74 |
75 | # Make sure that the stride is valid
76 | stride = max(stride, 1)
77 | stride = min(stride, size)
78 |
79 | result = pool_fn(inp, [1, stride, stride, 1])[0]
80 | elif mode in ["max", "avg"]:
81 | if mode == "max":
82 | pool_fn = tf.contrib.layers.max_pool2d
83 | else:
84 | pool_fn = tf.contrib.layers.avg_pool2d
85 | total_size = float(np.prod(inp.get_shape()[1:].as_list()))
86 | stride = int(np.ceil(np.sqrt(total_size / num_target_dimensions)))
87 | stride = min(max(1, stride), size)
88 |
89 | result = pool_fn(inp, kernel_size=stride, stride=stride)
90 | else:
91 | raise ValueError("Not supported %s pool." % mode)
92 |
93 | return result
94 |
95 |
96 | def append_multiple_rows_to_csv(dictionaries, csv_path):
97 | """Writes multiples rows to csv file from a list of dictionaries.
98 |
99 | Args:
100 | dictionaries: a list of dictionaries, mapping from csv header to value.
101 | csv_path: path to the result csv file.
102 | """
103 |
104 | keys = set([])
105 | for d in dictionaries:
106 | keys.update(d.keys())
107 |
108 | if not tf.gfile.Exists(csv_path):
109 | with tf.gfile.Open(csv_path, "w") as f:
110 | writer = csv.DictWriter(f, sorted(keys))
111 | writer.writeheader()
112 | f.flush()
113 |
114 | with tf.gfile.Open(csv_path, "a") as f:
115 | writer = csv.DictWriter(f, sorted(keys))
116 | writer.writerows(dictionaries)
117 | f.flush()
118 |
119 |
120 | def concat_dicts(dict_list):
121 | """Given a list of dicts merges them into a single dict.
122 |
123 | This function takes a list of dictionaries as an input and then merges all
124 | these dictionaries into a single dictionary by concatenating the values
125 | (along the first axis) that correspond to the same key.
126 |
127 | Args:
128 | dict_list: list of dictionaries
129 |
130 | Returns:
131 | d: merged dictionary
132 | """
133 | d = collections.defaultdict(list)
134 | for e in dict_list:
135 | for k, v in e.items():
136 | d[k].append(v)
137 | for k in d:
138 | d[k] = tf.concat(d[k], axis=0)
139 | return d
140 |
141 |
142 | def str2intlist(s, repeats_if_single=None):
143 | """Parse a config's "1,2,3"-style string into a list of ints.
144 |
145 | Args:
146 | s: The string to be parsed, or possibly already an int.
147 | repeats_if_single: If s is already an int or is a single element list,
148 | repeat it this many times to create the list.
149 |
150 | Returns:
151 | A list of integers based on `s`.
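152 |
153 |   Example: str2intlist("224,224") -> [224, 224], and
154 |   str2intlist(224, repeats_if_single=2) -> [224, 224].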
152 | """
153 | if isinstance(s, int):
154 | result = [s]
155 | else:
156 | result = [int(i.strip()) if i != "None" else None
157 | for i in s.split(",")]
158 | if repeats_if_single is not None and len(result) == 1:
159 | result *= repeats_if_single
160 | return result
161 |
162 |
163 | def tf_apply_to_image_or_images(fn, image_or_images):
164 | """Applies a function to a single image or each image in a batch of them.
165 |
166 | Args:
167 | fn: the function to apply, receives an image, returns an image.
168 | image_or_images: Either a single image, or a batch of images.
169 |
170 | Returns:
171 | The result of applying the function to the image or batch of images.
172 |
173 | Raises:
174 | ValueError: if the input is not of rank 3 or 4.
175 | """
176 | static_rank = len(image_or_images.get_shape().as_list())
177 | if static_rank == 3: # A single image: HWC
178 | return fn(image_or_images)
179 | elif static_rank == 4: # A batch of images: BHWC
180 | return tf.map_fn(fn, image_or_images)
181 | elif static_rank > 4: # A batch of images: ...HWC
182 | input_shape = tf.shape(image_or_images)
183 | h, w, c = image_or_images.get_shape().as_list()[-3:]
184 | image_or_images = tf.reshape(image_or_images, [-1, h, w, c])
185 | image_or_images = tf.map_fn(fn, image_or_images)
186 | return tf.reshape(image_or_images, input_shape)
187 | else:
188 | raise ValueError("Unsupported image rank: %d" % static_rank)
189 |
190 |
191 | def tf_apply_with_probability(p, fn, x):
192 | """Apply function `fn` to input `x` randomly `p` percent of the time."""
193 | return tf.cond(
194 | tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), p),
195 | lambda: fn(x),
196 | lambda: x)
197 |
198 |
199 | def expand_glob(glob_patterns):
200 | checkpoints = []
201 | for pattern in glob_patterns:
202 | checkpoints.extend(tf.gfile.Glob(pattern))
203 | assert checkpoints, "There are no checkpoints in " + str(glob_patterns)
204 | return checkpoints
205 |
206 |
207 | def get_latest_hub_per_task(hub_module_paths):
208 | """Get latest hub module for each task.
209 |
210 | The hub module path should match format ".*/hub/[0-9]*/module/.*".
211 | Example usage:
212 | get_latest_hub_per_task(expand_glob(["/cns/el-d/home/dune/representation/"
213 | "xzhai/1899361/*/export/hub/*/module/"]))
214 |   returns the latest hub module for each of the 4 tasks, respectively.
215 |
216 | Args:
217 | hub_module_paths: a list of hub module paths.
218 |
219 | Returns:
220 | A list of latest hub modules for each task.
221 |
222 | """
223 | task_to_path = {}
224 | for path in hub_module_paths:
225 | task_name, module_name = path.split("/hub/")
226 | timestamp = int(re.findall(r"([0-9]*)/module", module_name)[0])
227 | current_path = task_to_path.get(task_name, "0/module")
228 | current_timestamp = int(re.findall(r"([0-9]*)/module", current_path)[0])
229 | if current_timestamp < timestamp:
230 | task_to_path[task_name] = path
231 | return sorted(task_to_path.values())
232 |
233 |
234 | def get_classification_metrics(tensor_names):
235 | """Gets classification eval metric on input logits and labels.
236 |
237 | Args:
238 | tensor_names: a list of tensor names for _metrics input tensors.
239 |
240 | Returns:
241 | A function computes the metric result, from input logits and labels.
242 | """
243 |
244 | def _top_k_accuracy(k, labels, logits):
245 | in_top_k = tf.nn.in_top_k(predictions=logits, targets=labels, k=k)
246 | return tf.metrics.mean(tf.cast(in_top_k, tf.float32))
247 |
248 | def _metrics(labels, *tensors):
249 | """Computes the metric from logits and labels.
250 |
251 | Args:
252 | labels: ground truth labels.
253 | *tensors: tensors to be evaluated.
254 |
255 | Returns:
256 | Result dict mapping from the metric name to the list of result tensor and
257 | update_op used by tf.metrics.
258 | """
259 | metrics = {}
260 | assert len(tensor_names) == len(tensors), "Names must match tensors."
261 | for i in range(len(tensors)):
262 | tensor = tensors[i]
263 | name = tensor_names[i]
264 | for k in (1, 5):
265 | metrics["top%d_accuracy_%s" % (k, name)] = _top_k_accuracy(
266 | k, labels, tensor)
267 |
268 | return metrics
269 |
270 | return _metrics
271 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # pylint: disable=missing-docstring
18 | """Preprocessing methods.
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import functools
26 | import inception_preprocessing
27 | import tensorflow as tf
28 |
29 | import self_supervision.patch_model_preprocess as pp_lib
30 | import utils
31 |
32 |
33 | FLAGS = tf.flags.FLAGS
34 |
35 |
36 | def get_inception_preprocess(is_training, im_size):
37 | def _inception_preprocess(data):
38 | data["image"] = inception_preprocessing.preprocess_image(
39 | data["image"], im_size[0], im_size[1], is_training,
40 | add_image_summaries=False)
41 | return data
42 | return _inception_preprocess
43 |
44 |
45 | def get_resize_small(smaller_size):
46 | """Resizes the smaller side to `smaller_size` keeping aspect ratio."""
47 | def _resize_small_pp(data):
48 | image = data["image"]
49 | # A single image: HWC
50 | # A batch of images: BHWC
51 | h, w = tf.shape(image)[-3], tf.shape(image)[-2]
52 |
53 | # Figure out the necessary h/w.
54 | ratio = tf.to_float(smaller_size) / tf.to_float(tf.minimum(h, w))
55 | h = tf.to_int32(tf.round(tf.to_float(h) * ratio))
56 | w = tf.to_int32(tf.round(tf.to_float(w) * ratio))
57 |
58 | # NOTE: use align_corners=False for AREA resize, but True for Bilinear.
59 | # See also https://github.com/tensorflow/tensorflow/issues/6720
60 | static_rank = len(image.get_shape().as_list())
61 | if static_rank == 3: # A single image: HWC
62 | data["image"] = tf.image.resize_area(image[None], [h, w])[0]
63 | elif static_rank == 4: # A batch of images: BHWC
64 | data["image"] = tf.image.resize_area(image, [h, w])
65 | return data
66 | return _resize_small_pp
67 |
68 |
69 | def get_crop(is_training, crop_size):
70 | """Returns a random (or central at test-time) crop of `crop_size`."""
71 | def _crop_pp(data):
72 | crop_fn = functools.partial(
73 | pp_lib.crop, is_training=is_training, crop_size=crop_size)
74 | data["image"] = utils.tf_apply_to_image_or_images(crop_fn, data["image"])
75 |
76 | return data
77 | return _crop_pp
78 |
79 |
80 | def get_inception_crop(is_training, **kw):
81 | # kw of interest are: aspect_ratio_range, area_range.
82 | # Note that image is not resized yet here.
83 | def _inception_crop_pp(data):
84 | if is_training:
85 | image = data["image"]
86 | begin, size, _ = tf.image.sample_distorted_bounding_box(
87 | tf.shape(image), tf.zeros([0, 0, 4], tf.float32),
88 | use_image_if_no_bounding_boxes=True, **kw)
89 | data["image"] = tf.slice(image, begin, size)
90 |       # Unfortunately, the above operation loses the depth dimension, so we
91 |       # need to restore it manually.
92 | data["image"].set_shape([None, None, image.shape[-1]])
93 | return data
94 | return _inception_crop_pp
95 |
96 |
97 | def get_random_flip_lr(is_training):
98 | def _random_flip_lr_pp(data):
99 | if is_training:
100 | data["image"] = utils.tf_apply_to_image_or_images(
101 | tf.image.random_flip_left_right, data["image"])
102 | return data
103 | return _random_flip_lr_pp
104 |
105 |
106 | def get_resize_preprocess(im_size, randomize_resize_method=False):
107 |
108 | def _resize(image, method, align_corners):
109 |
110 | def _process():
111 | # The resized_images are of type float32 and might fall outside of range
112 | # [0, 255].
113 | resized = tf.cast(
114 | tf.image.resize_images(
115 | image, im_size, method, align_corners=align_corners),
116 | dtype=tf.float32)
117 | return resized
118 |
119 | return _process
120 |
121 | def _resize_pp(data):
122 | im = data["image"]
123 |
124 | if randomize_resize_method:
125 | # pick random resizing method
126 | r = tf.random_uniform([], 0, 3, dtype=tf.int32)
127 | im = tf.case({
128 | tf.equal(r, tf.cast(0, r.dtype)):
129 | _resize(im, tf.image.ResizeMethod.BILINEAR, True),
130 | tf.equal(r, tf.cast(1, r.dtype)):
131 | _resize(im, tf.image.ResizeMethod.NEAREST_NEIGHBOR, True),
132 | tf.equal(r, tf.cast(2, r.dtype)):
133 | _resize(im, tf.image.ResizeMethod.BICUBIC, True),
134 | # NOTE: use align_corners=False for AREA resize, but True for the
135 | # others. See https://github.com/tensorflow/tensorflow/issues/6720
136 | tf.equal(r, tf.cast(3, r.dtype)):
137 | _resize(im, tf.image.ResizeMethod.AREA, False),
138 | })
139 | else:
140 | im = tf.image.resize_images(im, im_size)
141 | data["image"] = im
142 | return data
143 |
144 | return _resize_pp
145 |
146 |
147 | def get_rotate_preprocess():
148 | """Returns a function that does 90deg rotations and sets according labels."""
149 |
150 | def _rotate_pp(data):
151 | data["label"] = tf.constant([0, 1, 2, 3])
152 | # We use our own instead of tf.image.rot90 because that one broke
153 | # internally shortly before deadline...
154 | data["image"] = tf.stack([
155 | data["image"],
156 | tf.transpose(tf.reverse_v2(data["image"], [1]), [1, 0, 2]),
157 | tf.reverse_v2(data["image"], [0, 1]),
158 | tf.reverse_v2(tf.transpose(data["image"], [1, 0, 2]), [1]),
159 | ])
160 | return data
161 |
162 | return _rotate_pp
163 |
164 |
165 | def get_value_range_preprocess(vmin=-1, vmax=1, dtype=tf.float32):
166 | """Returns a function that sends [0,255] image to [vmin,vmax]."""
167 |
168 | def _value_range_pp(data):
169 | img = tf.cast(data["image"], dtype)
170 | img = vmin + (img / tf.constant(255.0, dtype)) * (vmax - vmin)
171 | data["image"] = img
172 | return data
173 | return _value_range_pp
174 |
175 |
176 | def get_standardization_preprocess():
177 | def _standardization_pp(data):
178 | # Trick: normalize each patch to avoid low level statistics.
179 | data["image"] = utils.tf_apply_to_image_or_images(
180 | tf.image.per_image_standardization, data["image"])
181 | return data
182 | return _standardization_pp
183 |
184 |
185 | def get_inception_preprocess_patches(is_training, resize_size, num_of_patches):
186 |
187 | def _inception_preprocess_patches(data):
188 | patches = []
189 | for _ in range(num_of_patches):
190 | patches.append(
191 | inception_preprocessing.preprocess_image(
192 | data["image"],
193 | resize_size[0],
194 | resize_size[1],
195 | is_training,
196 | add_image_summaries=False))
197 | patches = tf.stack(patches)
198 | data["image"] = patches
199 | return data
200 |
201 | return _inception_preprocess_patches
202 |
203 |
204 | def get_to_gray_preprocess(grayscale_probability):
205 |
206 | def _to_gray(image):
207 | # Transform to grayscale by taking the mean of RGB.
208 | return tf.tile(tf.reduce_mean(image, axis=2, keepdims=True), [1, 1, 3])
209 |
210 | def _to_gray_pp(data):
211 | data["image"] = utils.tf_apply_to_image_or_images(
212 | lambda img: utils.tf_apply_with_probability( # pylint:disable=g-long-lambda
213 | grayscale_probability, _to_gray, img),
214 | data["image"])
215 | return data
216 |
217 | return _to_gray_pp
218 |
219 |
220 | def get_preprocess_fn(fn_names, is_training):
221 | """Returns preprocessing function.
222 |
223 | Args:
224 |     fn_names: comma-separated string of preprocessing function names.
225 | is_training: Whether this should be run in train or eval mode.
226 | Returns:
227 | preprocessing function
228 |
229 | Raises:
230 | ValueError: if preprocessing function name is unknown
231 | """
232 |
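233 |   # Example (illustrative): get_preprocess_fn("resize,flip_lr,-1_to_1", True)
234 |   # returns a function that resizes, randomly flips, and rescales each
235 |   # example's "image", in that order.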
233 | def _fn(data):
234 | def expand(fn_name):
235 | if fn_name == "plain_preprocess":
236 | yield lambda x: x
237 | elif fn_name == "0_to_1":
238 | yield get_value_range_preprocess(0, 1)
239 | elif fn_name == "-1_to_1":
240 | yield get_value_range_preprocess(-1, 1)
241 | elif fn_name == "resize":
242 | yield get_resize_preprocess(
243 | utils.str2intlist(FLAGS.resize_size, 2),
244 | is_training and FLAGS.get_flag_value("randomize_resize_method",
245 | False))
246 | elif fn_name == "resize_small":
247 | yield get_resize_small(FLAGS.smaller_size)
248 | elif fn_name == "crop":
249 | yield get_crop(is_training,
250 | utils.str2intlist(FLAGS.crop_size, 2))
251 | elif fn_name == "central_crop":
252 | yield get_crop(False, utils.str2intlist(FLAGS.crop_size, 2))
253 | elif fn_name == "inception_crop":
254 | yield get_inception_crop(is_training)
255 | elif fn_name == "flip_lr":
256 | yield get_random_flip_lr(is_training)
257 | elif fn_name == "crop_inception_preprocess_patches":
258 | yield get_inception_preprocess_patches(
259 | is_training, utils.str2intlist(FLAGS.resize_size, 2),
260 | FLAGS.num_of_inception_patches)
261 | elif fn_name == "to_gray":
262 | yield get_to_gray_preprocess(
263 | FLAGS.get_flag_value("grayscale_probability", 1.0))
264 | elif fn_name == "crop_patches":
265 | yield pp_lib.get_crop_patches_fn(
266 | is_training,
267 | split_per_side=FLAGS.splits_per_side,
268 | patch_jitter=FLAGS.get_flag_value("patch_jitter", 0))
269 | elif fn_name == "standardization":
270 | yield get_standardization_preprocess()
271 | elif fn_name == "rotate":
272 | yield get_rotate_preprocess()
273 |
274 | # Below this line specific combos decomposed.
275 | # It would be nice to move them to the configs at some point.
276 |
277 | elif fn_name == "inception_preprocess":
278 | yield get_inception_preprocess(
279 | is_training, utils.str2intlist(FLAGS.resize_size, 2))
280 | else:
281 | raise ValueError("Not supported preprocessing %s" % fn_name)
282 |
283 | # Apply all the individual steps in sequence.
284 | tf.logging.info("Data before pre-processing:\n%s", data)
285 | for fn_name in fn_names.split(","):
286 | for p in expand(fn_name.strip()):
287 | data = p(data)
288 | tf.logging.info("Data after `%s`:\n%s", p, data)
289 | return data
290 |
291 | return _fn
292 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Datasets class to provide images and labels in tf batch.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import abc
25 | import os
26 |
27 | import tensorflow as tf
28 |
29 | from preprocess import get_preprocess_fn
30 |
31 | FLAGS = tf.flags.FLAGS
32 |
33 |
34 | class AbstractDataset(object):
35 | """Base class for datasets using the simplied input pipeline."""
36 |
37 | def __init__(self,
38 | filenames,
39 | reader,
40 | num_epochs,
41 | shuffle,
42 | shuffle_buffer_size=10000,
43 | random_seed=None,
44 | num_reader_threads=64,
45 | drop_remainder=True):
46 | """Creates a new dataset. Sub-classes have to implement _parse_fn().
47 |
48 | Args:
49 | filenames: A list of filenames.
50 | reader: A dataset reader, e.g. `tf.data.TFRecordDataset`.
51 | `tf.data.TextLineDataset` and `tf.data.FixedLengthRecordDataset`.
52 | num_epochs: An int, defaults to `None`. Number of epochs to cycle
53 | through the dataset before stopping. If set to `None` this will read
54 | samples indefinitely.
55 | shuffle: A boolean, defaults to `False`. Whether output data are
56 | shuffled.
57 | shuffle_buffer_size: `int`, number of examples in the buffer for
58 | shuffling.
59 | random_seed: Optional int. Random seed for shuffle operation.
60 |       num_reader_threads: An int, defaults to 64. Number of threads reading
61 |         from files in parallel; reads are sloppy (out of order) only when
62 |         shuffling without a fixed random seed.
63 | drop_remainder: If true, then the last incomplete batch is dropped.
64 | """
65 | self.filenames = filenames
66 | self.reader = reader
67 | self.num_reader_threads = num_reader_threads
68 | self.num_epochs = num_epochs
69 | self.shuffle = shuffle
70 | self.shuffle_buffer_size = shuffle_buffer_size
71 | self.random_seed = random_seed
72 | self.drop_remainder = drop_remainder
73 |
74 | # Additional options for optimizing TPU input pipelines.
75 | self.num_parallel_batches = 8
76 |
77 | def _make_source_dataset(self):
78 | """Reads the files in self.filenames and returns a `tf.data.Dataset`.
79 |
80 | This does not parse the examples!
81 |
82 | Returns:
83 | `tf.data.Dataset` repeated for self.num_epochs and shuffled if
84 | self.shuffle is `True`. Files are always read in parallel and sloppy.
85 | """
86 | # Shuffle the filenames to ensure better randomization.
87 | dataset = tf.data.Dataset.list_files(self.filenames, shuffle=self.shuffle,
88 | seed=self.random_seed)
89 |
90 | dataset = dataset.repeat(self.num_epochs)
91 |
92 | def fetch_dataset(filename):
93 | buffer_size = 8 * 1024 * 1024 # 8 MiB per file
94 | dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size)
95 | return dataset
96 |
97 | # Read the data from disk in parallel
98 | dataset = dataset.apply(
99 | tf.data.experimental.parallel_interleave(
100 | fetch_dataset,
101 | cycle_length=self.num_reader_threads,
102 | sloppy=self.shuffle and self.random_seed is None))
103 |
104 | if self.shuffle:
105 | dataset = dataset.shuffle(self.shuffle_buffer_size, seed=self.random_seed)
106 | return dataset
107 |
108 | @abc.abstractmethod
109 | def _parse_fn(self, value):
110 | """Parses an image and its label from a serialized TFExample.
111 |
112 | Args:
113 | value: serialized string containing an TFExample.
114 |
115 | Returns:
116 | Returns a tuple of (image, label) from the TFExample.
117 | """
118 | raise NotImplementedError
119 |
120 | def input_fn(self, params):
121 | """Input function which provides a single batch for train or eval.
122 |
123 | Args:
124 | params: `dict` of parameters passed from the `TPUEstimator`.
125 | `params['batch_size']` is provided and should be used as the effective
126 | batch size.
127 |
128 | Returns:
129 | A `tf.data.Dataset` object.
130 | """
131 | # Retrieves the batch size for the current shard. The # of shards is
132 | # computed according to the input pipeline deployment. See
133 | # tf.contrib.tpu.RunConfig for details.
134 | batch_size = params['batch_size']
135 |
136 | dataset = self._make_source_dataset()
137 |
138 | # Use the fused map-and-batch operation.
139 | #
140 |     # For XLA, we must use fixed shapes. Because we repeat the source training
141 | # dataset indefinitely, we can use `drop_remainder=True` to get fixed-size
142 | # batches without dropping any training examples.
143 | #
144 | # When evaluating, `drop_remainder=True` prevents accidentally evaluating
145 | # the same image twice by dropping the final batch if it is less than a full
146 | # batch size. As long as this validation is done with consistent batch size,
147 | # exactly the same images will be used.
148 | dataset = dataset.apply(
149 | tf.data.experimental.map_and_batch(
150 | self._parse_fn,
151 | batch_size=batch_size,
152 | num_parallel_batches=self.num_parallel_batches,
153 | drop_remainder=self.drop_remainder))
154 |
155 | # Prefetch overlaps in-feed with training
156 | dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
157 | return dataset
158 |
159 |
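160 | # Example: generate_sharded_filenames('train@4') returns
161 | # ['train-00000-of-00004', ..., 'train-00003-of-00004'].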
160 | def generate_sharded_filenames(filename):
161 | base, count = filename.split('@')
162 | count = int(count)
163 | return ['{}-{:05d}-of-{:05d}'.format(base, i, count)
164 | for i in range(count)]
165 |
166 |
167 | class DatasetImagenet(AbstractDataset):
168 | """Provides train/val/trainval/test splits for Imagenet data.
169 |
170 | -> trainval split represents official Imagenet train split.
171 |   -> train split is derived by taking the first 984 of 1024 shards of
172 |      the official training data.
173 |   -> val split is derived by taking the last 40 shards of the official
174 |      training data.
175 | -> test split represents official Imagenet test split.
176 | """
177 |
178 | COUNTS = {'train': 1231121,
179 | 'val': 50046,
180 | 'trainval': 1281167,
181 | 'test': 50000}
182 |
183 | NUM_CLASSES = 1000
184 | IMAGE_KEY = 'image/encoded'
185 | LABEL_KEY = 'image/class/label'
186 |
187 | FEATURE_MAP = {
188 | IMAGE_KEY: tf.FixedLenFeature(shape=[], dtype=tf.string),
189 | LABEL_KEY: tf.FixedLenFeature(shape=[], dtype=tf.int64)
190 | }
191 |
192 | LABEL_OFFSET = 1
193 |
194 | def __init__(self,
195 | split_name,
196 | preprocess_fn,
197 | num_epochs,
198 | shuffle,
199 | random_seed=None,
200 | drop_remainder=True):
201 | """Initialize the dataset object.
202 |
203 | Args:
204 | split_name: A string split name, to load from the dataset.
205 | preprocess_fn: Preprocess a single example. The example is already
206 | parsed into a dictionary.
207 | num_epochs: An int, defaults to `None`. Number of epochs to cycle
208 | through the dataset before stopping. If set to `None` this will read
209 | samples indefinitely.
210 | shuffle: A boolean, defaults to `False`. Whether output data are
211 | shuffled.
212 | random_seed: Optional int. Random seed for shuffle operation.
213 | drop_remainder: If true, then the last incomplete batch is dropped.
214 | """
215 | # This is an instance-variable instead of a class-variable because it
216 | # depends on FLAGS, which is not parsed yet at class-parse-time.
217 | files = os.path.join(os.path.expanduser(FLAGS.dataset_dir),
218 | '%s@%i')
219 | filenames = {
220 | 'train': generate_sharded_filenames(files % ('train', 1024))[:-40],
221 | 'val': generate_sharded_filenames(files % ('train', 1024))[-40:],
222 | 'trainval': generate_sharded_filenames(files % ('train', 1024)),
223 | 'test': generate_sharded_filenames(files % ('validation', 128))
224 | }
225 |
226 | super(DatasetImagenet, self).__init__(
227 | filenames=filenames[split_name],
228 | reader=tf.data.TFRecordDataset,
229 | num_epochs=num_epochs,
230 | shuffle=shuffle,
231 | random_seed=random_seed,
232 | drop_remainder=drop_remainder)
233 | self.split_name = split_name
234 | self.preprocess_fn = preprocess_fn
235 |
236 | def _parse_fn(self, value):
237 | """Parses an image and its label from a serialized TFExample.
238 |
239 | Args:
240 | value: serialized string containing an TFExample.
241 |
242 | Returns:
243 | Returns a tuple of (image, label) from the TFExample.
244 | """
245 | example = tf.parse_single_example(value, self.FEATURE_MAP)
246 | image = tf.image.decode_jpeg(example[self.IMAGE_KEY], channels=3)
247 | # Subtract LABEL_OFFSET so that labels are in [0, 1000).
248 | label = tf.cast(example[self.LABEL_KEY], tf.int32) - self.LABEL_OFFSET
249 |
250 | return self.preprocess_fn({'image': image, 'label': label})
251 |
252 |
253 | DATASET_MAP = {
254 | 'imagenet': DatasetImagenet,
255 | }
256 |
257 |
258 | def get_data(params,
259 | split_name,
260 | is_training,
261 | shuffle=True,
262 | num_epochs=None,
263 | drop_remainder=False):
264 | """Produces image/label tensors for a given dataset.
265 |
266 | Args:
267 | params: dictionary with `batch_size` entry (thanks TPU...).
268 | split_name: data split, e.g. train, val, test
269 | is_training: whether to run pre-processing in train or test mode.
270 | shuffle: if True, shuffles the data
271 |     num_epochs: number of epochs. If None, proceeds indefinitely
272 | drop_remainder: Drop remaining examples in the last dataset batch. It is
273 | useful for third party checkpoints with fixed batch size.
274 |
275 | Returns:
276 | image, label, example counts
277 | """
278 | dataset = DATASET_MAP[FLAGS.dataset]
279 | preprocess_fn = get_preprocess_fn(FLAGS.preprocessing, is_training)
280 |
281 | return dataset(
282 | split_name=split_name,
283 | preprocess_fn=preprocess_fn,
284 | num_epochs=num_epochs,
285 | shuffle=shuffle,
286 | random_seed=FLAGS.get_flag_value('random_seed', None),
287 | drop_remainder=drop_remainder).input_fn(params)
288 |
289 |
290 | def get_count(split_name):
291 | return DATASET_MAP[FLAGS.dataset].COUNTS[split_name]
292 |
293 |
294 | def get_num_classes():
295 | return DATASET_MAP[FLAGS.dataset].NUM_CLASSES
296 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright 2019 Google LLC
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/self_supervision/patch_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Utils for patch based image processing."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import functools
24 | import struct
25 | import numpy as np
26 | import tensorflow as tf
27 | import tensorflow_hub as hub
28 |
29 | import preprocess
30 | import utils
31 | from models.utils import get_net
32 | from trainer import make_estimator
33 |
34 | FLAGS = tf.flags.FLAGS
35 |
36 | PATCH_H_COUNT = 3
37 | PATCH_W_COUNT = 3
38 | PATCH_COUNT = PATCH_H_COUNT * PATCH_W_COUNT
39 |
40 |
41 | # It's supposed to be in the root folder, which is also pwd when running, if the
42 | # instructions in the README are followed. Hence not a flag.
43 | PERMUTATION_PATH = 'permutations_100_max.bin'
44 |
45 |
46 | def apply_model(image_fn,
47 | is_training,
48 | num_outputs,
49 | perms,
50 | make_signature=False):
51 | """Creates the patch based model output from patches representations.
52 |
53 | Args:
54 | image_fn: function returns image tensor.
55 | is_training: is training flag used for batch norm and drop out.
56 | num_outputs: number of output classes.
57 |     perms: numpy array with shape [m, k], with elements in [0, PATCH_COUNT).
58 |       k stands for the number of patches used in a permutation. m stands for
59 |       the number of permutations. Each permutation is used to concat the patch inputs
60 | [n*PATCH_COUNT, h, w, c] into tensor with shape [n*m, h, w, c*k].
61 | make_signature: whether to create signature for hub module.
62 |
63 | Returns:
64 | out: output tensor with shape [n*m, 1, 1, num_outputs].
65 |
66 | Raises:
67 | ValueError: An error occurred when the architecture is unknown.
68 | """
69 | images = image_fn()
70 |
71 | net = get_net(num_classes=FLAGS.get_flag_value('embed_dim', 1000))
72 | out, end_points = net(images, is_training,
73 | weight_decay=FLAGS.get_flag_value('weight_decay', 1e-4))
74 |
75 | print(end_points)
76 | if not make_signature:
77 | out = permutate_and_concat_batch_patches(out, perms)
78 | out = fully_connected(out, num_outputs, is_training=is_training)
79 |
80 | out = tf.squeeze(out, [1, 2])
81 |
82 | if make_signature:
83 | hub.add_signature(inputs={'image': images}, outputs=out)
84 | hub.add_signature(
85 | name='representation',
86 | inputs={'image': images},
87 | outputs=end_points)
88 | return out
89 |
90 |
91 | def image_grid(images, ny, nx, padding=0):
92 | """Create a batch of image grids from a batch of images.
93 |
94 | Args:
95 | images: A batch of patches (B,N,H,W,C)
96 | ny: vertical number of images
97 | nx: horizontal number of images
98 | padding: number of zeros between images, if any.
99 |
100 | Returns:
101 | A tensor batch of image grids shaped (B,H*ny,W*nx,C), although that is a
102 | simplifying lie: if padding is used h/w will be different.
103 | """
104 | with tf.name_scope('grid_image'):
105 | if padding:
106 | padding = [padding, padding]
107 | images = tf.pad(images, [[0, 0], [0, 0], padding, padding, [0, 0]])
108 |
109 | return tf.concat([
110 | tf.concat([images[:, y * nx + x] for x in range(nx)], axis=-2)
111 | for y in range(ny)], axis=-3)
112 |
113 |
114 | def creates_estimator_model(images, labels, perms, num_classes, mode):
115 | """Creates EstimatorSpec for the patch based self supervised models.
116 |
117 | Args:
118 | images: images
119 | labels: self supervised labels (class indices)
120 | perms: patch permutations
121 | num_classes: number of different permutations
122 | mode: model's mode: training, eval or prediction
123 |
124 | Returns:
125 | EstimatorSpec
126 | """
127 | print(' +++ Mode: %s, images: %s, labels: %s' % (mode, images, labels))
128 |
129 | images = tf.reshape(images, shape=[-1] + images.get_shape().as_list()[-3:])
130 | if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
131 | with tf.variable_scope('module'):
132 | image_fn = lambda: images
133 | logits = apply_model(
134 | image_fn=image_fn,
135 | is_training=(mode == tf.estimator.ModeKeys.TRAIN),
136 | num_outputs=num_classes,
137 | perms=perms,
138 | make_signature=False)
139 | else:
140 | input_shape = utils.str2intlist(
141 | FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
142 | image_fn = lambda: tf.placeholder( # pylint: disable=g-long-lambda
143 | shape=input_shape,
144 | dtype=tf.float32)
145 |
146 | apply_model_function = functools.partial(
147 | apply_model,
148 | image_fn=image_fn,
149 | num_outputs=num_classes,
150 | perms=perms,
151 | make_signature=True)
152 |
153 | tf_hub_module_spec = hub.create_module_spec(
154 | apply_model_function, [(utils.TAGS_IS_TRAINING, {
155 | 'is_training': True
156 | }), (set(), {
157 | 'is_training': False
158 | })],
159 | drop_collections=['summaries'])
160 | tf_hub_module = hub.Module(tf_hub_module_spec, trainable=False, tags=set())
161 | hub.register_module_for_export(tf_hub_module, export_name='module')
162 | logits = tf_hub_module(images)
163 | return make_estimator(mode, predictions=logits)
164 |
165 | # build loss and accuracy
166 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
167 | labels=labels, logits=logits)
168 | loss = tf.reduce_mean(loss)
169 |
170 | eval_metrics = (
171 | lambda labels, logits: { # pylint: disable=g-long-lambda
172 | 'accuracy': tf.metrics.accuracy(
173 | labels=labels, predictions=tf.argmax(logits, axis=-1))},
174 | [labels, logits])
175 | return make_estimator(mode, loss, eval_metrics, logits)
176 |
177 |
178 | def fully_connected(inputs,
179 | num_classes=100,
180 | weight_decay=5e-4,
181 | keep_prob=0.5,
182 | is_training=True):
183 | """Two layers fully connected network copied from Alexnet fc7-fc8."""
184 | net = inputs
185 | _, _, w, _ = net.get_shape().as_list()
186 | kernel_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
187 | net = tf.layers.conv2d(
188 | net,
189 | filters=4096,
190 | kernel_size=w,
191 | padding='same',
192 | kernel_initializer=tf.truncated_normal_initializer(0.0, 0.005),
193 | bias_initializer=tf.constant_initializer(0.1),
194 | kernel_regularizer=kernel_regularizer)
195 | net = tf.layers.batch_normalization(
196 | net, momentum=0.997, epsilon=1e-5, fused=None, training=is_training)
197 | net = tf.nn.relu(net)
198 | if is_training:
199 | net = tf.nn.dropout(net, keep_prob=keep_prob)
200 | net = tf.layers.conv2d(
201 | net,
202 | filters=num_classes,
203 | kernel_size=1,
204 | padding='same',
205 | kernel_initializer=tf.truncated_normal_initializer(0.0, 0.005),
206 | bias_initializer=tf.zeros_initializer(),
207 | kernel_regularizer=kernel_regularizer)
208 |
209 | return net
210 |
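# Hedged usage sketch: because both layers are convolutions, the network
# runs on any (B, H, W, C) feature map; with padding='same' the spatial
# dims are preserved:
#
#   features = tf.zeros([4, 6, 6, 256])
#   logits = fully_connected(features, num_classes=100, is_training=False)
#   # logits: (4, 6, 6, 100); pool or squeeze the spatial dims afterwards.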
211 |
212 | def generate_patch_locations():
213 | """Generates relative patch locations."""
214 | perms = np.array([(i, 4) for i in range(9) if i != 4])
215 | return perms, len(perms)
216 |
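# The eight pairs generated above couple each non-center patch of a 3x3
# grid with the center patch (index 4):
#
#   perms = [(0, 4), (1, 4), (2, 4), (3, 4),
#            (5, 4), (6, 4), (7, 4), (8, 4)]   # num_classes = 8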
217 |
218 | def load_permutations():
219 | """Loads a set of pre-defined permutations."""
220 | with tf.gfile.Open(PERMUTATION_PATH, 'rb') as f:
221 | int32_size = 4
222 | s = f.read(int32_size * 2)
223 | [num_perms, c] = struct.unpack('<ll', s)
357 | # [N*P, H, W, C] -> NPHWC
358 | t = tf.reshape(t, [-1, num_of_patches] + t.get_shape().as_list()[-3:])
359 | if combine_patches == 'concat':
360 | # [N, P, H, W, C] -> [N, H, W, P*C]
361 | _, p, h, w, c = t.get_shape().as_list()
362 | out_tensors[k] = tf.reshape(
363 | tf.transpose(t, perm=[0, 2, 3, 4, 1]), tf.stack([-1, h, w, p * c]))
364 | elif combine_patches == 'max_pool':
365 | # Reduce max over the P axis of NPHWC.
366 | out_tensors[k] = tf.reduce_max(t, axis=1)
367 | elif combine_patches == 'avg_pool':
368 | # Reduce mean over the P axis of NPHWC.
369 | out_tensors[k] = tf.reduce_mean(t, axis=1)
370 | else:
371 | raise ValueError(
372 | 'Unsupported combine patches method %s.' % combine_patches)
373 |
374 | return out_tensors
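# Hedged shape sketch for the three combine_patches modes above, with
# t of shape (N, P, H, W, C) = (2, 9, 3, 3, 8):
#
#   'concat'   -> (2, 3, 3, 72)  # per-patch channels stacked side by side
#   'max_pool' -> (2, 3, 3, 8)   # elementwise max over the 9 patches
#   'avg_pool' -> (2, 3, 3, 8)   # elementwise mean over the 9 patches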
375 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements Resnet model.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import functools
25 |
26 | import tensorflow as tf
27 |
28 |
29 | def get_shape_as_list(x):
30 | return x.get_shape().as_list()
31 |
32 |
33 | def fixed_padding(x, kernel_size):
34 | pad_total = kernel_size - 1
35 | pad_beg = pad_total // 2
36 | pad_end = pad_total - pad_beg
37 |
38 | x = tf.pad(x, [[0, 0],
39 | [pad_beg, pad_end], [pad_beg, pad_end],
40 | [0, 0]])
41 | return x
42 |
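# For kernel_size=3, fixed_padding pads one pixel on each side, so a
# following conv with padding='VALID' reproduces 'SAME' geometry
# independently of the input size, e.g.:
#
#   x = tf.zeros([1, 224, 224, 3])
#   y = fixed_padding(x, kernel_size=3)   # -> (1, 226, 226, 3)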
43 |
44 | def batch_norm(x, training):
45 | return tf.layers.batch_normalization(x, fused=True, training=training)
46 |
47 |
48 | def identity_norm(x, training):
49 | del training
50 | return x
51 |
52 |
53 | def bottleneck_v1(x, filters, training, # pylint: disable=missing-docstring
54 | strides=1,
55 | activation_fn=tf.nn.relu,
56 | normalization_fn=batch_norm,
57 | kernel_regularizer=None):
58 |
59 | # Record the input tensor so that it can be used later as the skip-connection.
60 | x_shortcut = x
61 |
62 | # Project input if necessary
63 | if (strides > 1) or (filters != x.shape[-1]):
64 | x_shortcut = tf.layers.conv2d(x_shortcut, filters=filters, kernel_size=1,
65 | strides=strides,
66 | kernel_regularizer=kernel_regularizer,
67 | use_bias=False,
68 | padding='SAME')
69 | x_shortcut = normalization_fn(x_shortcut, training=training)
70 |
71 | # First convolution
72 | # Note that, unlike the original ResNet paper, we never use a stride in the
73 | # first convolution. Instead, we apply the stride in the second convolution.
74 | # The reason is that the first convolution has a 1x1 kernel, which loses
75 | # information when combined with a stride bigger than one.
76 | x = tf.layers.conv2d(x, filters=filters // 4,
77 | kernel_size=1,
78 | kernel_regularizer=kernel_regularizer,
79 | use_bias=False,
80 | padding='SAME')
81 | x = normalization_fn(x, training=training)
82 | x = activation_fn(x)
83 |
84 | # Second convolution
85 | x = fixed_padding(x, kernel_size=3)
86 | x = tf.layers.conv2d(x, filters=filters // 4,
87 | strides=strides,
88 | kernel_size=3,
89 | kernel_regularizer=kernel_regularizer,
90 | use_bias=False,
91 | padding='VALID')
92 | x = normalization_fn(x, training=training)
93 | x = activation_fn(x)
94 |
95 | # Third convolution
96 | x = tf.layers.conv2d(x, filters=filters,
97 | kernel_size=1,
98 | kernel_regularizer=kernel_regularizer,
99 | use_bias=False,
100 | padding='SAME')
101 | x = normalization_fn(x, training=training)
102 |
103 | # Skip connection
104 | x = x_shortcut + x
105 | x = activation_fn(x)
106 |
107 | return x
108 |
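# Hedged usage sketch: a v1 bottleneck with stride 2 halves the spatial
# dims and projects the shortcut to the new channel count:
#
#   x = tf.zeros([2, 56, 56, 256])
#   y = bottleneck_v1(x, filters=512, training=False, strides=2)
#   # y: (2, 28, 28, 512)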
109 |
110 | def bottleneck_v2(x, filters, training, # pylint: disable=missing-docstring
111 | strides=1,
112 | activation_fn=tf.nn.relu,
113 | normalization_fn=batch_norm,
114 | kernel_regularizer=None,
115 | no_shortcut=False):
116 |
117 | # Record the input tensor so that it can be used later as the skip-connection.
118 | x_shortcut = x
119 |
120 | x = normalization_fn(x, training=training)
121 | x = activation_fn(x)
122 |
123 | # Project input if necessary
124 | if (strides > 1) or (filters != x.shape[-1]):
125 | x_shortcut = tf.layers.conv2d(x, filters=filters, kernel_size=1,
126 | strides=strides,
127 | kernel_regularizer=kernel_regularizer,
128 | use_bias=False,
129 | padding='VALID')
130 |
131 | # First convolution
132 | # Note that, unlike the original ResNet paper, we never use a stride in the
133 | # first convolution. Instead, we apply the stride in the second convolution.
134 | # The reason is that the first convolution has a 1x1 kernel, which loses
135 | # information when combined with a stride bigger than one.
136 | x = tf.layers.conv2d(x, filters=filters // 4,
137 | kernel_size=1,
138 | kernel_regularizer=kernel_regularizer,
139 | use_bias=False,
140 | padding='SAME')
141 |
142 | # Second convolution
143 | x = normalization_fn(x, training=training)
144 | x = activation_fn(x)
145 | # Note that the padding depends on the dilation rate.
146 | x = fixed_padding(x, kernel_size=3)
147 | x = tf.layers.conv2d(x, filters=filters // 4,
148 | strides=strides,
149 | kernel_size=3,
150 | kernel_regularizer=kernel_regularizer,
151 | use_bias=False,
152 | padding='VALID')
153 |
154 | # Third convolution
155 | x = normalization_fn(x, training=training)
156 | x = activation_fn(x)
157 | x = tf.layers.conv2d(x, filters=filters,
158 | kernel_size=1,
159 | kernel_regularizer=kernel_regularizer,
160 | use_bias=False,
161 | padding='SAME')
162 |
163 | if no_shortcut:
164 | return x
165 | else:
166 | return x + x_shortcut
167 |
168 |
169 | def resnet(x, # pylint: disable=missing-docstring
170 | is_training,
171 | num_layers,
172 | strides=(2, 2, 2),
173 | num_classes=1000,
174 | filters_factor=4,
175 | weight_decay=1e-4,
176 | include_root_block=True,
177 | root_conv_size=7, root_conv_stride=2,
178 | root_pool_size=3, root_pool_stride=2,
179 | activation_fn=tf.nn.relu,
180 | last_relu=True,
181 | normalization_fn=batch_norm,
182 | global_pool=True,
183 | mode='v2'):
184 |
185 | assert mode in ['v1', 'v2'], 'Unknown Resnet mode: {}'.format(mode)
186 | unit = bottleneck_v2 if mode == 'v2' else bottleneck_v1
187 |
188 | end_points = {}
189 |
190 | filters = 16 * filters_factor
191 |
192 | kernel_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
193 |
194 | if include_root_block:
195 | x = fixed_padding(x, kernel_size=root_conv_size)
196 | x = tf.layers.conv2d(x, filters=filters,
197 | kernel_size=root_conv_size,
198 | strides=root_conv_stride,
199 | padding='VALID', use_bias=False,
200 | kernel_regularizer=kernel_regularizer)
201 |
202 | if mode == 'v1':
203 | x = normalization_fn(x, training=is_training)
204 | x = activation_fn(x)
205 |
206 | x = fixed_padding(x, kernel_size=root_pool_size)
207 | x = tf.layers.max_pooling2d(x, pool_size=root_pool_size,
208 | strides=root_pool_stride, padding='VALID')
209 | end_points['after_root'] = x
210 |
211 | params = {'activation_fn': activation_fn,
212 | 'normalization_fn': normalization_fn,
213 | 'training': is_training,
214 | 'kernel_regularizer': kernel_regularizer,
215 | }
216 |
217 | strides = list(strides)[::-1]
218 | num_layers = list(num_layers)[::-1]
219 |
220 | filters *= 4
221 | for _ in range(num_layers.pop()):
222 | x = unit(x, filters, strides=1, **params)
223 | end_points['block1'] = x
224 |
225 | filters *= 2
226 | x = unit(x, filters, strides=strides.pop(), **params)
227 | for _ in range(num_layers.pop() - 1):
228 | x = unit(x, filters, strides=1, **params)
229 | end_points['block2'] = x
230 |
231 | filters *= 2
232 | x = unit(x, filters, strides=strides.pop(), **params)
233 | for _ in range(num_layers.pop() - 1):
234 | x = unit(x, filters, strides=1, **params)
235 | end_points['block3'] = x
236 |
237 | filters *= 2
238 | x = unit(x, filters, strides=strides.pop(), **params)
239 | for _ in range(num_layers.pop() - 1):
240 | x = unit(x, filters, strides=1, **params)
241 | end_points['block4'] = x
242 |
243 | if (mode == 'v1') and (not last_relu):
244 | raise ValueError('last_relu is always True (implicitly) in the v1 mode.')
245 |
246 | if mode == 'v2':
247 | x = normalization_fn(x, training=is_training)
248 | if last_relu:
249 | x = activation_fn(x)
250 |
251 | if global_pool:
252 | x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
253 | end_points['pre_logits'] = tf.squeeze(x, [1, 2])
254 | else:
255 | end_points['pre_logits'] = x
256 |
257 | if num_classes:
258 | logits = tf.layers.conv2d(x, filters=num_classes,
259 | kernel_size=1,
260 | kernel_regularizer=kernel_regularizer)
261 | if global_pool:
262 | logits = tf.squeeze(logits, [1, 2])
263 | end_points['logits'] = logits
264 | return logits, end_points
265 | else:
266 | return end_points['pre_logits'], end_points
267 |
268 | resnet50 = functools.partial(resnet, num_layers=(3, 4, 6, 3))
269 |
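# Hedged end-to-end sketch (TF1 graph mode): with the default
# filters_factor=4 the channel progression is 64 -> 256/512/1024/2048
# across the four blocks, so the pooled representation is 2048-d:
#
#   images = tf.placeholder(tf.float32, [None, 224, 224, 3])
#   logits, end_points = resnet50(images, is_training=False)
#   # logits: (None, 1000); end_points['pre_logits']: (None, 2048)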
270 | # Experimental code ########################################
271 | # "Reversible" resnet ######################################
272 |
273 |
274 | # Invertible residual block as outlined in https://arxiv.org/abs/1707.04585
275 | def bottleneck_rev(x, training, # pylint: disable=missing-docstring
276 | activation_fn=tf.nn.relu,
277 | normalization_fn=batch_norm,
278 | kernel_regularizer=None):
279 |
280 | unit = bottleneck_v2
281 |
282 | x1, x2 = tf.split(x, 2, 3)
283 |
284 | y1 = x1 + unit(x2, x2.shape[-1], training,
285 | strides=1,
286 | activation_fn=activation_fn,
287 | normalization_fn=normalization_fn,
288 | kernel_regularizer=kernel_regularizer,
289 | no_shortcut=True)
290 | y2 = x2
291 |
292 | return tf.concat([y2, y1], axis=3)
293 |
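# The block above is invertible up to numerics: with y2 = x2 and
# y1 = x1 + F(x2), the inputs are recovered as x2 = y2 and
# x1 = y1 - F(x2), so intermediate activations need not be stored.
# Hedged shape sketch (channels must be even for the split):
#
#   x = tf.zeros([2, 28, 28, 512])
#   y = bottleneck_rev(x, training=False)   # -> (2, 28, 28, 512)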
294 |
295 | def pool_and_double_channels(x, pool_stride):
296 | if pool_stride > 1:
297 | x = tf.layers.average_pooling2d(x, pool_size=pool_stride,
298 | strides=pool_stride,
299 | padding='SAME')
300 | return tf.pad(x, [[0, 0], [0, 0], [0, 0],
301 | [x.shape[3] // 2, x.shape[3] // 2]])
302 |
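# Hedged shape sketch: spatial average-pooling by the stride, then
# zero-padding the channel axis to double it, keeps the tensor volume
# roughly constant between revnet stages:
#
#   x = tf.zeros([2, 56, 56, 128])
#   y = pool_and_double_channels(x, pool_stride=2)   # -> (2, 28, 28, 256)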
303 |
304 | def revnet(x, # pylint: disable=missing-docstring
305 | is_training,
306 | num_layers,
307 | strides=(2, 2, 2),
308 | num_classes=1000,
309 | filters_factor=4,
310 | weight_decay=1e-4,
311 | include_root_block=True,
312 | root_conv_size=7, root_conv_stride=2,
313 | root_pool_size=3, root_pool_stride=2,
314 | global_pool=True,
315 | activation_fn=tf.nn.relu,
316 | normalization_fn=batch_norm,
317 | last_relu=False,
318 | mode='v2'):
319 |
320 | del mode # unused parameter, exists for compatibility with resnet function
321 |
322 | unit = bottleneck_rev
323 |
324 | end_points = {}
325 |
326 | filters = 16 * filters_factor
327 |
328 | kernel_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
329 |
330 | # The first convolution serves as a random projection that increases the
331 | # number of channels; it cannot be skipped.
332 | x = fixed_padding(x, kernel_size=root_conv_size)
333 | x = tf.layers.conv2d(x, filters=4 * filters,
334 | kernel_size=root_conv_size,
335 | strides=root_conv_stride,
336 | padding='VALID', use_bias=False,
337 | kernel_regularizer=None)
338 |
339 | if include_root_block:
340 | x = fixed_padding(x, kernel_size=root_pool_size)
341 | x = tf.layers.max_pooling2d(
342 | x, pool_size=root_pool_size, strides=root_pool_stride, padding='VALID')
343 |
344 | end_points['after_root'] = x
345 |
346 | params = {'activation_fn': activation_fn,
347 | 'normalization_fn': normalization_fn,
348 | 'training': is_training,
349 | 'kernel_regularizer': kernel_regularizer,
350 | }
351 |
352 | num_layers = list(num_layers)[::-1]
353 | strides = list(strides)[::-1]
354 |
355 | for _ in range(num_layers.pop()):
356 | x = unit(x, **params)
357 | end_points['block1'] = x
358 | x = pool_and_double_channels(x, strides.pop())
359 |
360 | for _ in range(num_layers.pop()):
361 | x = unit(x, **params)
362 | end_points['block2'] = x
363 | x = pool_and_double_channels(x, strides.pop())
364 |
365 | for _ in range(num_layers.pop()):
366 | x = unit(x, **params)
367 | end_points['block3'] = x
368 | x = pool_and_double_channels(x, strides.pop())
369 |
370 | for _ in range(num_layers.pop()):
371 | x = unit(x, **params)
372 | end_points['block4'] = x
373 |
374 | x = normalization_fn(x, training=is_training)
375 |
376 | if last_relu:
377 | x = activation_fn(x)
378 |
379 | if global_pool:
380 | x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
381 | end_points['pre_logits'] = tf.squeeze(x, [1, 2])
382 | else:
383 | end_points['pre_logits'] = x
384 |
385 | if num_classes:
386 | logits = tf.layers.conv2d(x, filters=num_classes,
387 | kernel_size=1,
388 | kernel_regularizer=kernel_regularizer)
389 | if global_pool:
390 | logits = tf.squeeze(logits, [1, 2])
391 | end_points['logits'] = logits
392 | return logits, end_points
393 | else:
394 | return end_points['pre_logits'], end_points
395 |
396 |
397 | revnet50 = functools.partial(revnet, num_layers=(3, 4, 6, 3))
398 |
--------------------------------------------------------------------------------
/train_and_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2019 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # pylint: disable=line-too-long
18 | r"""The main script for starting training and evaluation.
19 |
20 | How to run:
21 | blaze run -c opt --config=dmtf_cuda \
22 | learning/brain/research/dune/experimental/representation/release/train_and_eval -- \
23 | --workdir /tmp/test \
24 | --config /google/src/cloud/akolesnikov/release/release/config/supervised/imagenet.py \
25 | --nouse_tpu
26 | """
27 | # pylint: enable=line-too-long
28 |
29 | from __future__ import absolute_import
30 | from __future__ import division
31 |
32 | import functools
33 | import math
34 | import os
35 |
36 | import absl.app as app
37 | import absl.flags as flags
38 | import absl.logging as logging
39 |
40 | import tensorflow as tf
41 | import tensorflow_hub as hub
42 |
43 | import datasets
44 | from self_supervision.self_supervision_lib import get_self_supervision_model
45 | import utils
46 |
47 | from tensorflow.contrib.cluster_resolver import TPUClusterResolver
48 |
49 |
50 | FLAGS = flags.FLAGS
51 |
52 | # General run setup flags.
53 | flags.DEFINE_string('workdir', None, 'Where to store files.')
54 | flags.mark_flag_as_required('workdir')
55 |
56 | flags.DEFINE_integer('num_gpus', 1, 'Number of GPUs to use.')
57 | flags.DEFINE_bool('use_tpu', True, 'Whether running on TPU or not.')
58 | flags.DEFINE_bool('run_eval', False, 'Run eval mode')
59 |
60 | flags.DEFINE_string('tpu_worker_name', 'tpu_worker',
61 | 'Name of a TPU worker.')
62 |
63 | # More detailed experiment flags
64 | flags.DEFINE_string('dataset', None, 'Which dataset to use, typically '
65 | '`imagenet`.')
66 | flags.mark_flag_as_required('dataset')
67 |
68 | flags.DEFINE_string('dataset_dir', None, 'Location of the dataset files.')
69 | flags.mark_flag_as_required('dataset_dir')
70 |
71 | flags.DEFINE_integer('eval_batch_size', None, 'Optional different batch-size '
72 | 'for evaluation; defaults to the same as `batch_size`.')
73 |
74 | flags.DEFINE_integer('keep_checkpoint_every_n_hours', None, 'Keep one '
75 | 'checkpoint every this many hours. Otherwise, only the '
76 | 'last few ones are kept. Defaults to 4h.')
77 |
78 | flags.DEFINE_integer('random_seed', None, 'Seed to use. None is random.')
79 |
80 | flags.DEFINE_integer('save_checkpoints_secs', None, 'Every how many seconds '
81 | 'to save a checkpoint. Defaults to 600, i.e. every 10 mins.')
82 |
83 | flags.DEFINE_string('serving_input_key', None, 'The name of the input tensor '
84 | 'in the generated hub module. Just leave it at default.')
85 |
86 | flags.DEFINE_string('serving_input_shape', None, 'The shape of the input tensor'
87 | ' in the stored hub module. Can contain `None`.')
88 |
89 | flags.DEFINE_string('signature', None, 'The name of the tensor to use as '
90 | 'representation for evaluation. Just leave to default.')
91 |
92 | flags.DEFINE_string('task', None, 'Which pretext-task to learn from. Can be '
93 | 'one of `rotation`, `exemplar`, `jigsaw`, '
94 | '`relative_patch_location`, `linear_eval`, `supervised`.')
95 | flags.mark_flag_as_required('task')
96 |
97 | flags.DEFINE_string('train_split', None, 'Which dataset split to train on. '
98 | 'Should only be `train` (default) or `trainval`.')
99 | flags.DEFINE_string('val_split', None, 'Which dataset split to eval on. '
100 | 'Should only be `val` (default) or `test`.')
101 |
102 | # Flags about the pretext tasks
103 |
104 | flags.DEFINE_integer('embed_dim', None, 'For most pretext tasks, which '
105 | 'dimension the embedding/hidden vector should be. '
106 | 'Defaults to 1000.')
107 |
108 | flags.DEFINE_float('margin', None, 'For the `exemplar` pretext task, '
109 | 'how large the triplet loss margin should be.')
110 |
111 | flags.DEFINE_integer('num_of_inception_patches', None, 'For the Exemplar '
112 | 'pretext task, how many instances of an image to create.')
113 |
114 | flags.DEFINE_integer('patch_jitter', None, 'For patch-based methods, by how '
115 | 'many pixels to jitter the patches. Defaults to 0.')
116 |
117 | flags.DEFINE_integer('perm_subset_size', None, 'Subset of permutations to '
118 | 'sample per example in the `jigsaw` pretext task. '
119 | 'Defaults to 8.')
120 |
121 | flags.DEFINE_integer('splits_per_side', None, 'For the `crop_patches` '
122 | 'preprocessor, how many times to split a side. '
123 | 'For example, 3 will result in 3x3=9 patches.')
124 |
125 | # Flags for evaluation.
126 | flags.DEFINE_string('eval_model', None, 'Whether to perform evaluation with a '
127 | '`linear` (default) model, or with an `mlp` model.')
128 |
129 | flags.DEFINE_string('hub_module', None, 'Folder where the hub module that '
130 | 'should be evaluated is stored.')
131 |
132 | flags.DEFINE_string('pool_mode', None, 'When running evaluation on '
133 | 'intermediate layers (not logits) of the network, it is '
134 | 'commonplace to pool the features down to ~9000 dimensions. This '
135 | 'decides the pooling method to be used: `adaptive_max` '
136 | '(default), `adaptive_avg`, `max`, or `avg`.')
137 |
138 | flags.DEFINE_string('combine_patches', None, 'When running evaluation on '
139 | 'patch models, this decides how to merge the patch representations '
140 | 'into the full-image representation. Should be set '
141 | 'to `avg_pool` (default) or `concat`.')
142 |
143 | # Flags about the model.
144 | flags.DEFINE_string('architecture', None,
145 | help='Which basic network architecture to use. '
146 | 'One of vgg19, resnet50, revnet50.')
147 | # flags.mark_flag_as_required('architecture') # Not required in eval mode.
148 |
149 | flags.DEFINE_integer('filters_factor', None, 'Widening factor for network '
150 | 'filters. For ResNet, default = 4 = vanilla ResNet.')
151 |
152 | flags.DEFINE_bool('last_relu', None, 'Whether to include (default) the final '
153 | 'ReLU layer in ResNet/RevNet models or not.')
154 |
155 | flags.DEFINE_string('mode', None, 'Which ResNet to use, `v1` or `v2`.')
156 |
157 | # Flags about the optimization process.
158 | flags.DEFINE_integer('batch_size', None, 'The global batch-size to use.')
159 | flags.mark_flag_as_required('batch_size')
160 |
161 | flags.DEFINE_string('decay_epochs', None, 'Optional list of epochs at which '
162 | 'learning-rate decay should happen, such as `15,25`.')
163 |
164 | flags.DEFINE_integer('epochs', None, 'Number of epochs to run training.')
165 | flags.mark_flag_as_required('epochs')
166 |
167 | flags.DEFINE_float('lr_decay_factor', None, 'Factor by which to decay the '
168 | 'learning-rate at each decay step. Default 0.1.')
169 |
170 | flags.DEFINE_float('lr', None, 'The base learning-rate to use for training.')
171 | flags.mark_flag_as_required('lr')
172 |
173 | flags.DEFINE_float('lr_scale_batch_size', None, 'The batch-size for which the '
174 | 'base learning-rate `lr` is defined. For batch-sizes '
175 | 'different from that, it is scaled linearly accordingly. '
176 | 'For example: lr=0.1, batch_size=128, lr_scale_batch_size=32 '
177 | 'results in an actual lr of 0.025.')
178 | flags.mark_flag_as_required('lr_scale_batch_size')
179 |
180 | flags.DEFINE_string('optimizer', None, 'Which optimizer to use. '
181 | 'Only `sgd` (default) or `adam` are supported.')
182 |
183 | flags.DEFINE_integer('warmup_epochs', None, 'Duration of the linear learning-'
184 | 'rate warm-up (from 0 to actual). Defaults to 0.')
185 |
186 | flags.DEFINE_float('weight_decay', None, 'Strength of weight-decay. '
187 | 'Defaults to 1e-4, and may be set to 0.')
188 |
189 | # Flags about pre-processing/data augmentation.
190 | flags.DEFINE_string('crop_size', None, 'Size of the crop when using `crop` '
191 | 'or `central_crop` preprocessing. Either a single '
192 | 'integer like `32` or a pair like `32,24`.')
193 |
194 | flags.DEFINE_float('grayscale_probability', None, 'When using `to_gray` '
195 | 'preprocessing, probability of actually doing it. Defaults '
196 | 'to 1.0, i.e. deterministically grayscaling the input.')
197 |
198 | flags.DEFINE_string('preprocessing', None, 'A comma-separated list of '
199 | 'pre-processing steps to perform, see preprocess.py.')
200 | flags.mark_flag_as_required('preprocessing')
201 |
202 | flags.DEFINE_bool('randomize_resize_method', None, 'Whether or not (default) '
203 | 'to use a random interpolation method in the `resize` '
204 | 'preprocessor.')
205 |
206 | flags.DEFINE_string('resize_size', None, 'For the `resize`, '
207 | '`inception_preprocess`, and '
208 | '`crop_inception_preprocess_patches` preprocessors, the '
209 | 'size in pixels to which to resize the input. Can be a '
210 | 'single number for square, or a pair as `128,64`.')
211 |
212 | flags.DEFINE_integer('smaller_size', None, 'For the `resize_small` preprocessor'
213 | ', the desired size that the smaller side should have '
214 | 'after resizing the image (keeping aspect ratio).')
215 |
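# Hedged example invocation (hypothetical values; flag names are the ones
# defined above, and the preprocessing list is illustrative only):
#
#   python train_and_eval.py --task supervised --dataset imagenet \
#     --dataset_dir /tmp/data --workdir /tmp/workdir \
#     --batch_size 128 --epochs 90 --lr 0.1 --lr_scale_batch_size 256 \
#     --preprocessing resize_small,crop,flip_lr \
#     --architecture resnet50 --nouse_tpu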
216 |
217 | # Number of iterations (=training steps) per TPU training loop. Use >100 for
218 | # good speed. This is the minimum number of steps between checkpoints.
219 | TPU_ITERATIONS_PER_LOOP = 500
220 |
221 |
222 | def train_and_eval():
223 | """Trains a network on (self) supervised data."""
224 | checkpoint_dir = os.path.join(FLAGS.workdir)
225 |
226 | if FLAGS.use_tpu:
227 | master = TPUClusterResolver(
228 | tpu=[os.environ['TPU_NAME']]).get_master()
229 | else:
230 | master = ''
231 |
232 | config = tf.contrib.tpu.RunConfig(
233 | model_dir=checkpoint_dir,
234 | tf_random_seed=FLAGS.get_flag_value('random_seed', None),
235 | master=master,
236 | evaluation_master=master,
237 | keep_checkpoint_every_n_hours=FLAGS.get_flag_value(
238 | 'keep_checkpoint_every_n_hours', 4),
239 | save_checkpoints_secs=FLAGS.get_flag_value('save_checkpoints_secs', 600),
240 | tpu_config=tf.contrib.tpu.TPUConfig(
241 | iterations_per_loop=TPU_ITERATIONS_PER_LOOP,
242 | tpu_job_name=FLAGS.tpu_worker_name))
243 |
244 | # The global batch-sizes are passed to the TPU estimator, and it will pass
245 | # along the local batch size in the model_fn's `params` argument dict.
246 | estimator = tf.contrib.tpu.TPUEstimator(
247 | model_fn=get_self_supervision_model(FLAGS.task),
248 | model_dir=checkpoint_dir,
249 | config=config,
250 | use_tpu=FLAGS.use_tpu,
251 | train_batch_size=FLAGS.batch_size,
252 | eval_batch_size=FLAGS.get_flag_value('eval_batch_size', FLAGS.batch_size))
253 |
254 | if FLAGS.run_eval:
255 | data_fn = functools.partial(
256 | datasets.get_data,
257 | split_name=FLAGS.get_flag_value('val_split', 'val'),
258 | is_training=False,
259 | shuffle=False,
260 | num_epochs=1,
261 | drop_remainder=FLAGS.use_tpu)
262 |
263 | # Contrary to what the documentation claims, the `train` and the
264 | # `evaluate` functions NEED to have `max_steps` and/or `steps` set and
265 | # cannot make use of the iterator's end-of-input exception, so we need
266 | # to do some math for that here.
267 | num_samples = datasets.get_count(FLAGS.get_flag_value('val_split', 'val'))
268 | num_steps = num_samples // FLAGS.get_flag_value('eval_batch_size',
269 | FLAGS.batch_size)
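# E.g. (hypothetical numbers): a 50000-image val split evaluated with
# eval_batch_size=100 gives num_steps=500.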
270 | tf.logging.info('val_steps: %d', num_steps)
271 |
272 | for checkpoint in tf.contrib.training.checkpoints_iterator(
273 | estimator.model_dir, timeout=10 * 60):
274 |
275 | estimator.evaluate(
276 | checkpoint_path=checkpoint, input_fn=data_fn, steps=num_steps)
277 |
278 | hub_exporter = hub.LatestModuleExporter('hub', serving_input_fn)
279 | hub_exporter.export(
280 | estimator,
281 | os.path.join(checkpoint_dir, 'export/hub'),
282 | checkpoint)
283 |
284 | if tf.gfile.Exists(os.path.join(FLAGS.workdir, 'TRAINING_IS_DONE')):
285 | break
286 |
287 | # Evaluates the latest checkpoint on validation set.
288 | result = estimator.evaluate(input_fn=data_fn, steps=num_steps)
289 | return result
290 |
291 | else:
292 | train_data_fn = functools.partial(
293 | datasets.get_data,
294 | split_name=FLAGS.get_flag_value('train_split', 'train'),
295 | is_training=True,
296 | num_epochs=int(math.ceil(FLAGS.epochs)),
297 | drop_remainder=True)
298 |
299 | # We compute the number of steps and make use of Estimator's max_steps
300 | # arguments instead of relying on the Dataset's iterator to run out after
301 | # a number of epochs so that we can use 'fractional' epochs, which are
302 | # used by regression tests. (And because TPUEstimator needs it anyway.)
303 | num_samples = datasets.get_count(FLAGS.get_flag_value('train_split',
304 | 'train'))
305 | # Depending on whether we drop the last batch each epoch or only at the
306 | # very end, this should be ordered differently for rounding.
307 | updates_per_epoch = num_samples // FLAGS.batch_size
308 | num_steps = int(math.ceil(FLAGS.epochs * updates_per_epoch))
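# E.g. (hypothetical numbers): num_samples=1281167 and batch_size=1024
# give updates_per_epoch=1251, so epochs=35 yields num_steps=43785.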
309 | tf.logging.info('train_steps: %d', num_steps)
310 |
311 | estimator.train(train_data_fn, max_steps=num_steps)
312 |
313 |
314 | def serving_input_fn():
315 | """A serving input fn."""
316 | input_shape = utils.str2intlist(
317 | FLAGS.get_flag_value('serving_input_shape', 'None,None,None,3'))
318 | image_features = {
319 | FLAGS.get_flag_value('serving_input_key', 'image'):
320 | tf.placeholder(dtype=tf.float32, shape=input_shape)}
321 | return tf.estimator.export.ServingInputReceiver(
322 | features=image_features, receiver_tensors=image_features)
323 |
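# With the default flag values, the receiver above is equivalent to this
# sketch:
#
#   features = {'image': tf.placeholder(tf.float32, [None, None, None, 3])}
#   tf.estimator.export.ServingInputReceiver(features, features)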
324 |
325 | def main(unused_argv):
326 | # logging.info('config: %s', FLAGS)
327 | logging.info('workdir: %s', FLAGS.workdir)
328 |
329 | train_and_eval()
330 |
331 | logging.info('I\'m done with my work, ciao!')
332 |
333 |
334 | if __name__ == '__main__':
335 | app.run(main)
336 |
--------------------------------------------------------------------------------