├── .github
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── .pylintrc
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── efficientdet
│   ├── Det-AdvProp.md
│   ├── README.md
│   ├── __init__.py
│   ├── aug
│   │   ├── __init__.py
│   │   ├── autoaugment.py
│   │   ├── autoaugment_test.py
│   │   ├── gridmask.py
│   │   ├── gridmask_test.py
│   │   ├── mosaic.py
│   │   └── mosaic_test.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── backbone_factory.py
│   │   ├── efficientnet_builder.py
│   │   ├── efficientnet_builder_test.py
│   │   ├── efficientnet_lite_builder.py
│   │   ├── efficientnet_lite_builder_test.py
│   │   ├── efficientnet_model.py
│   │   └── efficientnet_model_test.py
│   ├── coco_metric.py
│   ├── coco_metric_test.py
│   ├── dataloader.py
│   ├── dataloader_test.py
│   ├── dataset
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── create_coco_tfrecord.py
│   │   ├── create_coco_tfrecord_test.py
│   │   ├── create_pascal_tfrecord.py
│   │   ├── create_pascal_tfrecord_test.py
│   │   ├── inspect_tfrecords.py
│   │   ├── label_map_util.py
│   │   └── tfrecord_util.py
│   ├── det_advprop_tutorial.ipynb
│   ├── det_model_fn.py
│   ├── det_model_fn_test.py
│   ├── efficientdet_arch.py
│   ├── efficientdet_arch_test.py
│   ├── g3doc
│   │   ├── Det-AdvProp.png
│   │   ├── coco_ids.yaml
│   │   ├── faq.md
│   │   ├── flops.png
│   │   ├── network.png
│   │   ├── params.png
│   │   └── street.jpg
│   ├── hparams_config.py
│   ├── hparams_config_test.py
│   ├── inference.py
│   ├── install_deps.sh
│   ├── iou_utils.py
│   ├── iou_utils_test.py
│   ├── main.py
│   ├── model_inspect.py
│   ├── model_inspect_test.py
│   ├── nms_np.py
│   ├── object_detection
│   │   ├── __init__.py
│   │   ├── argmax_matcher.py
│   │   ├── box_coder.py
│   │   ├── box_list.py
│   │   ├── faster_rcnn_box_coder.py
│   │   ├── matcher.py
│   │   ├── preprocessor.py
│   │   ├── region_similarity_calculator.py
│   │   ├── shape_utils.py
│   │   ├── target_assigner.py
│   │   └── tf_example_decoder.py
│   ├── requirements.txt
│   ├── run_tflite.py
│   ├── tensorrt.py
│   ├── test.sh
│   ├── test_util.py
│   ├── testdata
│   │   ├── img1-d1.jpg
│   │   └── img1.jpg
│   ├── tf2
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── anchors.py
│   │   ├── efficientdet_keras.py
│   │   ├── efficientdet_keras_test.py
│   │   ├── eval.py
│   │   ├── eval_tflite.py
│   │   ├── fpn_configs.py
│   │   ├── fpn_configs_test.py
│   │   ├── infer.py
│   │   ├── infer_lib.py
│   │   ├── infer_lib_test.py
│   │   ├── inspector.py
│   │   ├── inspector_test.py
│   │   ├── label_util.py
│   │   ├── postprocess.py
│   │   ├── postprocess_test.py
│   │   ├── segmentation.py
│   │   ├── tfmot.py
│   │   ├── train.py
│   │   ├── train_lib.py
│   │   ├── train_lib_test.py
│   │   ├── tutorial.ipynb
│   │   ├── util_keras.py
│   │   ├── util_keras_test.py
│   │   ├── wbf.py
│   │   └── wbf_test.py
│   ├── tutorial.ipynb
│   ├── utils.py
│   ├── utils_test.py
│   └── visualize
│       ├── __init__.py
│       ├── shape_utils.py
│       ├── standard_fields.py
│       ├── static_shape.py
│       ├── vis_utils.py
│       └── vis_utils_test.py
├── efficientnetv2
│   ├── README.md
│   ├── autoaugment.py
│   ├── autoaugment_test.py
│   ├── cflags.py
│   ├── datasets.py
│   ├── datasets_test.py
│   ├── effnetv2_configs.py
│   ├── effnetv2_configs_test.py
│   ├── effnetv2_model.py
│   ├── effnetv2_model_test.py
│   ├── g3doc
│   │   ├── effnetv2-l-gpu.png
│   │   ├── effnetv2-m-gpu.png
│   │   ├── effnetv2-s-gpu.png
│   │   ├── effnetv2-s-relu6-gpu.png
│   │   ├── imagenet1k_labels.txt
│   │   ├── imagenet21k_labels.txt
│   │   ├── param_flops.png
│   │   └── train_params.png
│   ├── hparams.py
│   ├── infer.py
│   ├── main.py
│   ├── main_tf2.py
│   ├── mlir.py
│   ├── preprocess_legacy.py
│   ├── preprocessing.py
│   ├── preprocessing_test.py
│   ├── smoke_test.py
│   ├── tfhub.ipynb
│   ├── tutorial.ipynb
│   ├── utils.py
│   └── utils_test.py
├── hero
│   ├── LICENSE
│   ├── config_lib.py
│   ├── core.py
│   ├── core_test.py
│   ├── data_lib.py
│   ├── example_commands.sh
│   ├── fn_lib.py
│   ├── main.py
│   ├── model_lib.py
│   ├── requirements.txt
│   ├── vb100864_openmix_v1.model
│   └── vb32000_t5_cc.model
└── lion
    ├── README.md
    ├── fig
    │   ├── ablation.png
    │   ├── alg.png
    │   ├── basic.png
    │   ├── diffusion.png
    │   ├── ft.png
    │   ├── i1k.png
    │   ├── imagen.png
    │   ├── jft-ft.png
    │   ├── jft.png
    │   ├── lit.png
    │   ├── llm.png
    │   ├── lm.png
    │   └── retrieval.png
    ├── lion_optax.py
    ├── lion_pytorch.py
    ├── lion_tf1.py
    └── lion_tf2.py
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Efficientdet
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-18.04
16 | strategy:
17 | matrix:
18 | python-version: [3.7, 3.8, 3.9]
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install pylint
29 | bash efficientdet/install_deps.sh
30 | - name: Test with pytest
31 | run: |
32 | bash efficientdet/test.sh
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .DS_Store
3 | .vscode
4 | tmp
5 | .ropeproject
6 | *.pyc
7 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Brain AutoML
2 |
3 | This repository contains a list of AutoML related models and libraries.
4 |
--------------------------------------------------------------------------------
/efficientdet/Det-AdvProp.md:
--------------------------------------------------------------------------------
1 | # Det-AdvProp
2 | [Open the tutorial in Colab](https://colab.sandbox.google.com/github/google/automl/blob/master/efficientdet/det_advprop_tutorial.ipynb)
3 |
4 | [1] Xiangning Chen, Cihang Xie, Mingxing Tan, Li Zhang, Cho-Jui Hsieh, Boqing
5 | Gong. CVPR 2021. Arxiv link: https://arxiv.org/abs/2103.13886
6 |
7 | Det-AdvProp is a data augmentation technique specifically designed for the
8 | fine-tuning process of object detectors. It can consistently and substantially
9 | outperform the vanilla training and AutoAugment under various settings. The
10 | obtained detector is not only more accurate on clean images, but also more
11 | robust to image distortions and domain shift.
12 |
13 |
14 |
15 |
16 |
17 | ## 1. Accurate on Clean Images
18 |
19 | The following table includes a list of models trained with Det-AdvProp +
20 | AutoAugment (AA):
21 |
22 | Model | AP<sup>test</sup> | AP<sub>50</sub> | AP<sub>75</sub> | AP<sub>S</sub> | AP<sub>M</sub> | AP<sub>L</sub> | AP<sup>val</sup> | | #params | #FLOPs
23 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------- | --------------- | --------------- | -------------- | -------------- | -------------- | ---------------- | --- | ------- | :----:
24 | EfficientDet-D0 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d0.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d0.txt)) | 35.3 | 54.1 | 37.8 | 12.7 | 39.9 | 53.2 | 35.1 | | 3.9M | 2.54B
25 | EfficientDet-D1 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d1.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d1.txt)) | 40.9 | 60.0 | 44.1 | 19.1 | 45.6 | 57.2 | 40.8 | | 6.6M | 6.10B
26 | EfficientDet-D2 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d2.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d2.txt)) | 44.3 | 63.5 | 47.9 | 23.5 | 48.5 | 59.9 | 44.3 | | 8.1M | 11.0B
27 | EfficientDet-D3 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d3.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d3.txt)) | 48.0 | 67.1 | 52.2 | 28.1 | 51.8 | 62.8 | 47.7 | | 12.0M | 24.9B
28 | EfficientDet-D4 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d4.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d4.txt)) | 50.4 | 69.5 | 54.9 | 30.9 | 54.3 | 64.4 | 50.4 | | 20.7M | 55.2B
29 | EfficientDet-D5 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d5.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d5.txt)) | 52.5 | 71.8 | 57.2 | 34.6 | 55.9 | 65.2 | 52.2 | | 33.7M | 130B
30 |
31 | Unlike the vanilla EfficientDet, which normalizes the image with a per-channel mean and
32 | std, here we scale the input to the range [-1, 1] to make it easier to perform
33 | adversarial attacks. Please see [this Colab](https://github.com/google/automl/blob/master/efficientdet/det_advprop_tutorial.ipynb) to reproduce the
34 | results.
35 |
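A minimal sketch of the two input-scaling schemes mentioned above, assuming a float32 image tensor in the range [0, 255]; the function names here are illustrative, not the repo's API:

```python
import tensorflow as tf


def scale_with_mean_std(image):
  # Vanilla EfficientDet style: normalize each channel with a fixed mean/std.
  mean = tf.constant([0.485, 0.456, 0.406]) * 255.0
  std = tf.constant([0.229, 0.224, 0.225]) * 255.0
  return (image - mean) / std


def scale_to_unit_range(image):
  # Det-AdvProp checkpoints: map [0, 255] linearly to [-1, 1].
  return image / 127.5 - 1.0
```
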
36 | ## 2. Robust Against Common Corruptions
37 |
38 | We test the detectors' robustness against common image corruptions (e.g.,
39 | Gaussian Noise, Snow, etc.) based on the COCO-C dataset in
40 | [this paper](https://arxiv.org/abs/1907.07484). The table below shows the
41 | comparison between vanilla training and Det-AdvProp + AutoAugment (AA):
42 |
43 | Model | mAP
44 | ---------------------- | ---------------
45 | EfficientDet-D0 | 21.4
46 | **+ Det-AdvProp + AA** | **22.7 (+1.3)**
47 | EfficientDet-D1 | 24.4
48 | **+ Det-AdvProp + AA** | **26.7 (+2.3)**
49 | EfficientDet-D2 | 26.7
50 | **+ Det-AdvProp + AA** | **28.9 (+2.2)**
51 | EfficientDet-D3 | 28.8
52 | **+ Det-AdvProp + AA** | **32.0 (+3.2)**
53 | EfficientDet-D4 | 30.1
54 | **+ Det-AdvProp + AA** | **33.9 (+3.8)**
55 | EfficientDet-D5 | 31.4
56 | **+ Det-AdvProp + AA** | **35.0 (+3.6)**
57 |
58 | ## 3. Robust Against Domain Shift
59 |
60 | PASCAL VOC 2012 contains only 20 classes, far fewer than the 80 labeled classes
61 | in COCO. The underlying distributions of the two datasets also differ in image
62 | content and in bounding box sizes and locations. We use the trained detectors to
63 | run inference directly on the VOC dataset to test their transferability. We keep
64 | the COCO evaluation metrics in this
65 | experiment:
66 |
67 | Model | mAP | AP50 | AP75
68 | ---------------------- | --------------- | --------------- | ---------------
69 | EfficientDet-D0 | 55.6 | 77.6 | 61.4
70 | **+ Det-AdvProp + AA** | **56.2 (+0.6)** | **78.3 (+0.7)** | **62.3 (+0.9)**
71 | EfficientDet-D1 | 60.8 | 82.0 | 66.7
72 | **+ Det-AdvProp + AA** | **61.3 (+0.5)** | **82.5 (+0.5)** | **67.6 (+0.9)**
73 | EfficientDet-D2 | 63.3 | 83.6 | 69.3
74 | **+ Det-AdvProp + AA** | **63.6 (+0.3)** | **84.0 (+0.4)** | **70.0 (+0.7)**
75 | EfficientDet-D3 | 65.7 | 85.3 | 71.8
76 | **+ Det-AdvProp + AA** | **66.4 (+0.7)** | **85.9 (+0.6)** | **72.8 (+1.0)**
77 | EfficientDet-D4 | 67.0 | 86.0 | 73.0
78 | **+ Det-AdvProp + AA** | **67.8 (+0.8)** | **87.0 (+1.0)** | **74.3 (+1.3)**
79 | EfficientDet-D5 | 67.4 | 86.9 | 73.8
80 | **+ Det-AdvProp + AA** | **68.7 (+1.3)** | **88.0 (+1.1)** | **75.4 (+1.6)**
81 |
--------------------------------------------------------------------------------
/efficientdet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/efficientdet/aug/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/efficientdet/aug/autoaugment_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for Autoaugment."""
16 | from absl import logging
17 | import tensorflow.compat.v1 as tf
18 |
19 | from aug import autoaugment
20 |
21 |
22 | class AutoaugmentTest(tf.test.TestCase):
23 |
24 | def test_autoaugment_policy(self):
25 | # A very simple test to verify no syntax error.
26 | image = tf.placeholder(tf.uint8, shape=[640, 640, 3])
27 | bboxes = tf.placeholder(tf.float32, shape=[4, 4])
28 | autoaugment.distort_image_with_autoaugment(image, bboxes, 'test')
29 |
30 | def test_randaugment_policy(self):
31 | image = tf.placeholder(tf.uint8, shape=[320, 320, 3])
32 | bboxes = tf.placeholder(tf.float32, shape=[4, 4])
33 | autoaugment.distort_image_with_randaugment(image, bboxes, 1, 15)
34 |
35 |
36 | if __name__ == '__main__':
37 | logging.set_verbosity(logging.WARNING)
38 | tf.disable_eager_execution()
39 | tf.test.main()
40 |
41 |
--------------------------------------------------------------------------------
/efficientdet/aug/gridmask.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Grid Masking Augmentation Reference: https://arxiv.org/abs/2001.04086."""
16 | import math
17 |
18 | import tensorflow as tf
19 | from tensorflow_addons import image as image_ops
20 |
21 |
22 | class GridMask(object):
23 | """GridMask class for grid masking augmentation."""
24 |
25 | def __init__(self,
26 | prob=0.6,
27 | ratio=0.6,
28 | rotate=10,
29 | gridmask_size_ratio=0.5,
30 | fill=1,
31 | interpolation="BILINEAR"):
32 | """initialization.
33 |
34 | Args:
35 |       prob: probability of occurrence.
36 | ratio: grid mask ratio i.e if 0.5 grid and spacing will be equal.
37 | rotate: Rotation of grid mesh.
38 | gridmask_size_ratio: Grid mask size, grid to image size ratio.
39 | fill: Fill value for grids.
40 | interpolation: Interpolation method for rotation.
41 | """
42 | self.prob = prob
43 | self.ratio = ratio
44 | self.rotate = rotate
45 | self.gridmask_size_ratio = gridmask_size_ratio
46 | self.fill = fill
47 | self.interpolation = interpolation
48 |
49 | @tf.function
50 | def random_rotate(self, mask):
51 | """Randomly rotates mask on given range."""
52 |
53 | angle = self.rotate * tf.random.normal([], -1, 1)
54 | angle = math.pi * angle / 180
55 | return image_ops.rotate(mask, angle, interpolation=self.interpolation)
56 |
57 | @staticmethod
58 | def crop(mask, h, w):
59 |     """Crops the center h x w region of the mask."""
60 | ww = hh = tf.shape(mask)[0]
61 | mask = mask[(hh - h) // 2:(hh - h) // 2 + h,
62 | (ww - w) // 2:(ww - w) // 2 + w,]
63 | return mask
64 |
65 | @tf.function
66 | def mask(self, h, w):
67 | """mask helper function for initializing grid mask of required size."""
68 | h = tf.cast(h, tf.float32)
69 | w = tf.cast(w, tf.float32)
70 | mask_w = mask_h = tf.cast(
71 | tf.cast(
72 | (self.gridmask_size_ratio + 1), tf.float32) * tf.math.maximum(h, w),
73 | tf.int32)
74 | self.mask_w = mask_w
75 | mask = tf.zeros(shape=[mask_h, mask_w], dtype=tf.int32)
76 | gridblock = tf.random.uniform(
77 | shape=[],
78 | minval=int(tf.math.minimum(h * 0.5, w * 0.3)),
79 | maxval=int(tf.math.maximum(h * 0.5, w * 0.3)) + 1,
80 | dtype=tf.int32)
81 |
82 | if self.ratio == 1:
83 | length = tf.random.uniform(
84 | shape=[], minval=1, maxval=gridblock + 1, dtype=tf.int32)
85 | else:
86 | length = tf.cast(
87 | tf.math.minimum(
88 | tf.math.maximum(
89 | int(tf.cast(gridblock, tf.float32) * self.ratio + 0.5), 1),
90 | gridblock - 1), tf.int32)
91 |
92 | for _ in range(2):
93 | start_w = tf.random.uniform(
94 | shape=[], minval=0, maxval=gridblock + 1, dtype=tf.int32)
95 | for i in range(mask_w // gridblock):
96 | start = gridblock * i + start_w
97 | end = tf.math.minimum(start + length, mask_w)
98 | indices = tf.reshape(tf.range(start, end), [end - start, 1])
99 | updates = (
100 | tf.ones(shape=[end - start, mask_w], dtype=tf.int32) * self.fill)
101 | mask = tf.tensor_scatter_nd_update(mask, indices, updates)
102 | mask = tf.transpose(mask)
103 |
104 | return mask
105 |
106 | def __call__(self, image, label):
107 | """Masks input image tensor with random grid mask."""
108 | h = tf.shape(image)[0]
109 | w = tf.shape(image)[1]
110 | grid = self.mask(h, w)
111 | grid = self.random_rotate(grid)
112 | mask = self.crop(grid, h, w)
113 | mask = tf.cast(mask, image.dtype)
114 | mask = tf.reshape(mask, (h, w))
115 | mask = (tf.expand_dims(mask, -1) if image._rank() != mask._rank() else mask)
116 | occur = tf.random.normal([], 0, 1) < self.prob
117 | image = tf.cond(occur, lambda: image * mask, lambda: image)
118 | return image, label
119 |
120 |
121 | def gridmask(image,
122 | boxes,
123 | prob=0.5,
124 | ratio=0.6,
125 | rotate=10,
126 | gridmask_size_ratio=0.5,
127 | fill=1):
128 | """Callable instance of GridMask and transforms input image."""
129 | gridmask_obj = GridMask(
130 | prob=prob,
131 | ratio=ratio,
132 | rotate=rotate,
133 | gridmask_size_ratio=gridmask_size_ratio,
134 | fill=fill)
135 | image, boxes = gridmask_obj(image, boxes)
136 | return image, boxes
137 |
--------------------------------------------------------------------------------
/efficientdet/aug/gridmask_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """GridMask Augmentation simple test."""
16 | from absl import logging
17 | import tensorflow.compat.v1 as tf
18 |
19 | from aug import gridmask
20 |
21 |
22 | class GridMaskTest(tf.test.TestCase):
23 |
24 | def setUp(self):
25 | super().setUp()
26 | tf.random.set_random_seed(111111)
27 |
28 | def test_gridmask_images(self):
29 | """Verify transformed image shape is valid and syntax check."""
30 | images = tf.random.uniform(
31 | shape=(512, 512, 3), minval=0, maxval=255, dtype=tf.float32)
32 | bboxes = tf.random.uniform(
33 | shape=(2, 4), minval=1, maxval=511, dtype=tf.int32)
34 | transform_images, _ = gridmask.gridmask(images, bboxes)
35 | self.assertEqual(images.shape[1], transform_images.shape[1])
36 |
37 | def test_gridmask_tiny_images(self):
38 | """Verify transform image shape on very tiny image."""
39 | images = tf.zeros(shape=(4, 4, 3))
40 | bboxes = tf.random.uniform(
41 | shape=(2, 4), minval=1, maxval=511, dtype=tf.int32)
42 | transform_images, _ = gridmask.gridmask(images, bboxes)
43 | self.assertEqual(images.shape[1], transform_images.shape[1])
44 |
45 | def test_rectangle_image_shape(self):
46 | """Verify transform image shape on rectangle image."""
47 | images = tf.zeros(shape=(1028, 512, 3))
48 | bboxes = tf.random.uniform(
49 | shape=(2, 4), minval=1, maxval=511, dtype=tf.int32)
50 | transform_images, _ = gridmask.gridmask(images, bboxes)
51 | self.assertEqual(images.shape[1], transform_images.shape[1])
52 |
53 |
54 | if __name__ == "__main__":
55 | logging.set_verbosity(logging.WARNING)
56 | tf.disable_eager_execution()
57 | tf.test.main()
58 |
--------------------------------------------------------------------------------
/efficientdet/aug/mosaic_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Mosaic Augmentation simple test."""
16 | from absl import logging
17 | import tensorflow.compat.v1 as tf
18 |
19 | from aug import mosaic
20 |
21 |
22 | class MosaicTest(tf.test.TestCase):
23 |
24 | def __init__(self, *args, **kwargs):
25 | super().__init__(*args, **kwargs)
26 | self.output_size = (512, 512)
27 | self.mosaic = mosaic.Mosaic(out_size=self.output_size)
28 | tf.random.set_random_seed(111111)
29 |
30 | def test_mosaic_boxes(self):
31 | """Verify num of boxes are valid and syntax check random four images."""
32 | images = tf.random.uniform(
33 | shape=(4, 512, 512, 3), minval=0, maxval=255, dtype=tf.float32)
34 | bboxes = tf.random.uniform(
35 | shape=(4, 2, 4), minval=1, maxval=511, dtype=tf.int32)
36 | _, mosaic_boxes = self.mosaic(images, bboxes)
37 | self.assertEqual(bboxes.shape[0], len(mosaic_boxes))
38 |
39 | def test_mosaic_tiny_images(self):
40 | images = tf.zeros(shape=(4, 4, 4, 3))
41 | bboxes = tf.random.uniform(
42 | shape=(4, 2, 4), minval=1, maxval=511, dtype=tf.int32)
43 | _, mosaic_boxes = self.mosaic(images, bboxes)
44 | self.assertEqual(bboxes.shape[0], len(mosaic_boxes))
45 |
46 |
47 | if __name__ == "__main__":
48 | logging.set_verbosity(logging.WARNING)
49 | tf.disable_eager_execution()
50 | tf.test.main()
51 |
--------------------------------------------------------------------------------
/efficientdet/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/efficientdet/backbone/backbone_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Backbone network factory."""
16 | import os
17 | from absl import logging
18 | import tensorflow as tf
19 |
20 | from backbone import efficientnet_builder
21 | from backbone import efficientnet_lite_builder
22 | from backbone import efficientnet_model
23 |
24 |
25 | def get_model_builder(model_name):
26 | """Get the model_builder module for a given model name."""
27 | if model_name.startswith('efficientnet-lite'):
28 | return efficientnet_lite_builder
29 | elif model_name.startswith('efficientnet-'):
30 | return efficientnet_builder
31 | else:
32 | raise ValueError('Unknown model name {}'.format(model_name))
33 |
34 |
35 | def get_model(model_name, override_params=None, model_dir=None):
36 | """A helper function to create and return model.
37 |
38 | Args:
39 | model_name: string, the predefined model name.
40 | override_params: A dictionary of params for overriding. Fields must exist in
41 | efficientnet_model.GlobalParams.
42 | model_dir: string, optional model dir for saving configs.
43 |
44 | Returns:
45 | created model
46 |
47 | Raises:
48 | When model_name specified an undefined model, raises NotImplementedError.
49 | When override_params has invalid fields, raises ValueError.
50 | """
51 |
52 | # For backward compatibility.
53 | if override_params and override_params.get('drop_connect_rate', None):
54 | override_params['survival_prob'] = 1 - override_params['drop_connect_rate']
55 |
56 | if not override_params:
57 | override_params = {}
58 |
59 | if model_name.startswith('efficientnet-lite'):
60 | builder = efficientnet_lite_builder
61 | elif model_name.startswith('efficientnet-'):
62 | builder = efficientnet_builder
63 | else:
64 | raise ValueError('Unknown model name {}'.format(model_name))
65 |
66 | blocks_args, global_params = builder.get_model_params(model_name,
67 | override_params)
68 |
69 | if model_dir:
70 | param_file = os.path.join(model_dir, 'model_params.txt')
71 | if not tf.io.gfile.exists(param_file):
72 | if not tf.io.gfile.exists(model_dir):
73 | tf.io.gfile.mkdir(model_dir)
74 | with tf.io.gfile.GFile(param_file, 'w') as f:
75 | logging.info('writing to %s', param_file)
76 | f.write('model_name= %s\n\n' % model_name)
77 | f.write('global_params= %s\n\n' % str(global_params))
78 | f.write('blocks_args= %s\n\n' % str(blocks_args))
79 |
80 | return efficientnet_model.Model(blocks_args, global_params, model_name)
81 |
--------------------------------------------------------------------------------
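A hypothetical usage sketch for `backbone_factory.get_model` above (not taken from the repo's tests; it assumes the `efficientdet/` directory is on `PYTHONPATH` and that any overridden field exists in `efficientnet_model.GlobalParams`):

```python
from backbone import backbone_factory

# Plain EfficientNet-B0 backbone.
model = backbone_factory.get_model('efficientnet-b0')

# Same backbone with an overridden survival probability; passing model_dir
# also writes the resolved configs to <model_dir>/model_params.txt.
model = backbone_factory.get_model(
    'efficientnet-b0',
    override_params={'survival_prob': 0.8},
    model_dir='/tmp/effnet-b0')
```
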
/efficientdet/backbone/efficientnet_builder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for efficientnet_builder."""
16 | from absl import logging
17 | import numpy as np
18 | import tensorflow.compat.v1 as tf
19 |
20 | from backbone import efficientnet_builder
21 |
22 |
23 | class EfficientnetBuilderTest(tf.test.TestCase):
24 |
25 | def _test_model_params(self,
26 | model_name,
27 | input_size,
28 | expected_params,
29 | override_params=None,
30 | features_only=False,
31 | pooled_features_only=False):
32 | images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32)
33 | efficientnet_builder.build_model(
34 | images,
35 | model_name=model_name,
36 | override_params=override_params,
37 | training=False,
38 | features_only=features_only,
39 | pooled_features_only=pooled_features_only)
40 | num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
41 | self.assertEqual(num_params, expected_params)
42 |
43 | def test_efficientnet_b0(self):
44 | self._test_model_params('efficientnet-b0', 224, expected_params=5288548)
45 |
46 | def test_efficientnet_b1(self):
47 | self._test_model_params('efficientnet-b1', 240, expected_params=7794184)
48 |
49 | def test_efficientnet_b2(self):
50 | self._test_model_params('efficientnet-b2', 260, expected_params=9109994)
51 |
52 | def test_efficientnet_b3(self):
53 | self._test_model_params('efficientnet-b3', 300, expected_params=12233232)
54 |
55 | def test_efficientnet_b4(self):
56 | self._test_model_params('efficientnet-b4', 380, expected_params=19341616)
57 |
58 | def test_efficientnet_b5(self):
59 | self._test_model_params('efficientnet-b5', 456, expected_params=30389784)
60 |
61 | def test_efficientnet_b6(self):
62 | self._test_model_params('efficientnet-b6', 528, expected_params=43040704)
63 |
64 | def test_efficientnet_b7(self):
65 | self._test_model_params('efficientnet-b7', 600, expected_params=66347960)
66 |
67 | def test_efficientnet_b0_with_customized_num_classes(self):
68 | self._test_model_params(
69 | 'efficientnet-b0',
70 | 224,
71 | expected_params=4135648,
72 | override_params={'num_classes': 100})
73 |
74 | def test_efficientnet_b0_with_features_only(self):
75 | self._test_model_params(
76 | 'efficientnet-b0', 224, features_only=True, expected_params=3595388)
77 |
78 | def test_efficientnet_b0_with_pooled_features_only(self):
79 | self._test_model_params(
80 | 'efficientnet-b0',
81 | 224,
82 | pooled_features_only=True,
83 | expected_params=4007548)
84 |
85 | def test_efficientnet_b0_fails_if_both_features_requested(self):
86 | with self.assertRaises(AssertionError):
87 | efficientnet_builder.build_model(
88 | None,
89 | model_name='efficientnet-b0',
90 | training=False,
91 | features_only=True,
92 | pooled_features_only=True)
93 |
94 | def test_efficientnet_b0_base(self):
95 | # Creates a base model using the model configuration.
96 | images = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
97 | _, endpoints = efficientnet_builder.build_model_base(
98 | images, model_name='efficientnet-b0', training=False)
99 |
100 | # reduction_1 to reduction_5 should be in endpoints
101 | self.assertEqual(len(endpoints), 5)
102 |
103 |
104 | if __name__ == '__main__':
105 | logging.set_verbosity(logging.WARNING)
106 |   # Disable eager so that tf.profile works for #params/#flops.
107 | tf.disable_eager_execution()
108 | tf.test.main()
109 |
--------------------------------------------------------------------------------
/efficientdet/backbone/efficientnet_lite_builder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for efficientnet_lite_builder."""
16 | from absl import logging
17 | import numpy as np
18 | import tensorflow.compat.v1 as tf
19 |
20 | from backbone import efficientnet_lite_builder
21 |
22 |
23 | class EfficientnetBuilderTest(tf.test.TestCase):
24 |
25 | def _test_model_params(self,
26 | model_name,
27 | input_size,
28 | expected_params,
29 | override_params=None,
30 | features_only=False,
31 | pooled_features_only=False):
32 | images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32)
33 | efficientnet_lite_builder.build_model(
34 | images,
35 | model_name=model_name,
36 | override_params=override_params,
37 | training=False,
38 | features_only=features_only,
39 | pooled_features_only=pooled_features_only)
40 | num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
41 |
42 | self.assertEqual(num_params, expected_params)
43 |
44 | def test_efficientnet_b0(self):
45 | self._test_model_params(
46 | 'efficientnet-lite0', 224, expected_params=4652008)
47 |
48 | def test_efficientnet_b1(self):
49 | self._test_model_params(
50 | 'efficientnet-lite1', 240, expected_params=5416680)
51 |
52 | def test_efficientnet_b2(self):
53 | self._test_model_params(
54 | 'efficientnet-lite2', 260, expected_params=6092072)
55 |
56 | def test_efficientnet_b3(self):
57 | self._test_model_params(
58 | 'efficientnet-lite3', 280, expected_params=8197096)
59 |
60 | def test_efficientnet_b4(self):
61 | self._test_model_params(
62 | 'efficientnet-lite4', 300, expected_params=13006568)
63 |
64 |
65 | if __name__ == '__main__':
66 | logging.set_verbosity(logging.WARNING)
67 |   # Disable eager so that tf.profile works for #params/#flops.
68 | tf.disable_eager_execution()
69 | tf.test.main()
70 |
--------------------------------------------------------------------------------
/efficientdet/coco_metric_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for coco_metric."""
16 |
17 | from absl import logging
18 | import tensorflow.compat.v1 as tf
19 | import coco_metric
20 |
21 |
22 | class CocoMetricTest(tf.test.TestCase):
23 |
24 | def setUp(self):
25 | super(CocoMetricTest, self).setUp()
26 | # [y1, x1, y2, x2, is_crowd, area, class], in image coords.
27 | self.groundtruth_data = tf.constant([[
28 | [10.0, 10.0, 20.0, 20.0, 0.0, 100.0, 1],
29 | [10.0, 10.0, 30.0, 15.0, 0.0, 100.0, 2],
30 | [30.0, 30.0, 40.0, 50.0, 0.0, 100.0, 3]
31 | ]], dtype=tf.float32)
32 | # [image_id, x, y, width, height, score, class]
33 | self.detections = tf.constant([[
34 | [1.0, 10.0, 10.0, 10.0, 10.0, 0.6, 1],
35 | [1.0, 10.0, 10.0, 5.0, 20.0, 0.5, 2]
36 | ]], dtype=tf.float32)
37 | self.class_labels = {1: 'car', 2: 'truck', 3: 'bicycle'}
38 |
39 | def test_mAP(self):
40 |
41 | eval_metric = coco_metric.EvaluationMetric(label_map=self.class_labels)
42 | coco_metrics = eval_metric.estimator_metric_fn(self.detections,
43 | self.groundtruth_data)
44 | self.assertEqual(len(coco_metrics.keys()), 15)
45 | self.assertAllClose(coco_metrics['AP'][0], 2.0/3.0)
46 | self.assertAllClose(coco_metrics['AP_/car'][0], 1.0)
47 | self.assertAllClose(coco_metrics['AP_/truck'][0], 1.0)
48 | self.assertAllClose(coco_metrics['AP_/bicycle'][0], 0.0)
49 |
50 |
51 | if __name__ == '__main__':
52 | logging.set_verbosity(logging.WARNING)
53 | tf.test.main()
54 |
--------------------------------------------------------------------------------
/efficientdet/dataloader_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the 'License');
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an 'AS IS' BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Data loader and processing test cases."""
16 |
17 | import tensorflow as tf
18 |
19 | import dataloader
20 | import hparams_config
21 | import test_util
22 |
23 | from tf2 import anchors
24 | from object_detection import tf_example_decoder
25 |
26 |
27 | class DataloaderTest(tf.test.TestCase):
28 |
29 | def test_parser(self):
30 | tf.random.set_seed(111111)
31 | params = hparams_config.get_detection_config('efficientdet-d0').as_dict()
32 | input_anchors = anchors.Anchors(params['min_level'], params['max_level'],
33 | params['num_scales'],
34 | params['aspect_ratios'],
35 | params['anchor_scale'],
36 | params['image_size'])
37 | anchor_labeler = anchors.AnchorLabeler(input_anchors, params['num_classes'])
38 | example_decoder = tf_example_decoder.TfExampleDecoder(
39 | regenerate_source_id=params['regenerate_source_id'])
40 | tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir())
41 | dataset = tf.data.TFRecordDataset([tfrecord_path])
42 | value = next(iter(dataset))
43 | reader = dataloader.InputReader(tfrecord_path, True)
44 | result = reader.dataset_parser(value, example_decoder, anchor_labeler,
45 | params)
46 | self.assertEqual(len(result), 11)
47 |
48 |
49 | if __name__ == '__main__':
50 | tf.test.main()
51 |
--------------------------------------------------------------------------------
/efficientdet/dataset/README.md:
--------------------------------------------------------------------------------
1 | This folder provides tools for converting raw coco/pascal data to tfrecord.
2 |
3 | ### 1. Convert COCO validation set to tfrecord:
4 |
5 | # Download coco data.
6 | !wget http://images.cocodataset.org/zips/val2017.zip
7 | !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
8 | !unzip val2017.zip
9 | !unzip annotations_trainval2017.zip
10 |
11 | # convert coco data to tfrecord.
12 | !mkdir tfrecord
13 | !PYTHONPATH=".:$PYTHONPATH" python dataset/create_coco_tfrecord.py \
14 | --image_dir=val2017 \
15 | --object_annotations_file=annotations/instances_val2017.json \
16 | --output_file_prefix=tfrecord/val \
17 | --num_shards=32
18 |
19 | ### 2. Convert Pascal VOC 2012 to tfrecord:
20 |
21 | # Download and convert pascal data.
22 | !wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
23 | !tar xf VOCtrainval_11-May-2012.tar
24 | !mkdir tfrecord
25 | !PYTHONPATH=".:$PYTHONPATH" python dataset/create_pascal_tfrecord.py \
26 | --data_dir=VOCdevkit --year=VOC2012 --output_path=tfrecord/pascal
27 |
28 | Attention: source_id (or image_id) needs to be an integer due to the official COCO library requirements.
29 |
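As an optional sanity check (a sketch that assumes the commands above were run as shown, with shards written under `tfrecord/`), the shard and record counts can be verified from Python; `dataset/inspect_tfrecords.py` can additionally save sample images with their boxes drawn:

```python
import tensorflow as tf

# Count the shards and records produced by create_coco_tfrecord.py.
files = tf.io.gfile.glob('tfrecord/val*')
num_records = sum(1 for _ in tf.data.TFRecordDataset(files))
print(f'{len(files)} shards, {num_records} records')  # expect 32 shards here
```
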
--------------------------------------------------------------------------------
/efficientdet/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # This library is mostly based on tensorflow object detection API
16 | # https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_coco_tf_record.py
17 |
--------------------------------------------------------------------------------
/efficientdet/dataset/create_pascal_tfrecord_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Test for create_pascal_tfrecord.py."""
16 |
17 | import os
18 |
19 | from absl import logging
20 | import numpy as np
21 | import PIL.Image
22 | import six
23 | import tensorflow as tf
24 |
25 | from dataset import create_pascal_tfrecord
26 |
27 |
28 | class CreatePascalTFRecordTest(tf.test.TestCase):
29 |
30 | def _assertProtoEqual(self, proto_field, expectation):
31 | """Helper function to assert if a proto field equals some value.
32 |
33 | Args:
34 | proto_field: The protobuf field to compare.
35 | expectation: The expected value of the protobuf field.
36 | """
37 | proto_list = [p for p in proto_field]
38 | self.assertListEqual(proto_list, expectation)
39 |
40 | def test_dict_to_tf_example(self):
41 | image_file_name = '2012_12.jpg'
42 | image_data = np.random.rand(256, 256, 3)
43 | save_path = os.path.join(self.get_temp_dir(), image_file_name)
44 | image = PIL.Image.fromarray(image_data, 'RGB')
45 | image.save(save_path)
46 |
47 | data = {
48 | 'folder':
49 | '',
50 | 'filename':
51 | image_file_name,
52 | 'size': {
53 | 'height': 256,
54 | 'width': 256,
55 | },
56 | 'object': [
57 | {
58 | 'difficult': 1,
59 | 'bndbox': {
60 | 'xmin': 64,
61 | 'ymin': 64,
62 | 'xmax': 192,
63 | 'ymax': 192,
64 | },
65 | 'name': 'person',
66 | 'truncated': 0,
67 | 'pose': '',
68 | },
69 | {
70 | 'difficult': 0,
71 | 'bndbox': {
72 | 'xmin': 128,
73 | 'ymin': 128,
74 | 'xmax': 256,
75 | 'ymax': 256,
76 | },
77 | 'name': 'notperson',
78 | 'truncated': 0,
79 | 'pose': '',
80 | },
81 | ],
82 | }
83 |
84 | label_map_dict = {
85 | 'background': 0,
86 | 'person': 1,
87 | 'notperson': 2,
88 | }
89 |
90 | ann_json_dict = {'images': [], 'annotations': [], 'categories': []}
91 | unique_id = create_pascal_tfrecord.UniqueId()
92 | example = create_pascal_tfrecord.dict_to_tf_example(
93 | data,
94 | self.get_temp_dir(),
95 | label_map_dict,
96 | unique_id,
97 | ann_json_dict=ann_json_dict)
98 | self.assertEqual(unique_id.image_id, 1)
99 | self.assertEqual(unique_id.ann_id, 2)
100 |
101 | self._assertProtoEqual(
102 | example.features.feature['image/height'].int64_list.value, [256])
103 | self._assertProtoEqual(
104 | example.features.feature['image/width'].int64_list.value, [256])
105 | self._assertProtoEqual(
106 | example.features.feature['image/filename'].bytes_list.value,
107 | [six.b(image_file_name)])
108 | self._assertProtoEqual(
109 | example.features.feature['image/source_id'].bytes_list.value,
110 | [six.b(str(1))])
111 | self._assertProtoEqual(
112 | example.features.feature['image/format'].bytes_list.value,
113 | [six.b('jpeg')])
114 | self._assertProtoEqual(
115 | example.features.feature['image/object/bbox/xmin'].float_list.value,
116 | [0.25, 0.5])
117 | self._assertProtoEqual(
118 | example.features.feature['image/object/bbox/ymin'].float_list.value,
119 | [0.25, 0.5])
120 | self._assertProtoEqual(
121 | example.features.feature['image/object/bbox/xmax'].float_list.value,
122 | [0.75, 1.0])
123 | self._assertProtoEqual(
124 | example.features.feature['image/object/bbox/ymax'].float_list.value,
125 | [0.75, 1.0])
126 | self._assertProtoEqual(
127 | example.features.feature['image/object/class/text'].bytes_list.value,
128 | [six.b('person'), six.b('notperson')])
129 | self._assertProtoEqual(
130 | example.features.feature['image/object/class/label'].int64_list.value,
131 | [1, 2])
132 | self._assertProtoEqual(
133 | example.features.feature['image/object/difficult'].int64_list.value,
134 | [1, 0])
135 | self._assertProtoEqual(
136 | example.features.feature['image/object/truncated'].int64_list.value,
137 | [0, 0])
138 | self._assertProtoEqual(
139 | example.features.feature['image/object/view'].bytes_list.value,
140 | [six.b(''), six.b('')])
141 |
142 | expected_ann_json_dict = {
143 | 'annotations': [{
144 | 'area': 16384,
145 | 'iscrowd': 0,
146 | 'image_id': 1,
147 | 'bbox': [64, 64, 128, 128],
148 | 'category_id': 1,
149 | 'id': 1,
150 | 'ignore': 0,
151 | 'segmentation': []
152 | }, {
153 | 'area': 16384,
154 | 'iscrowd': 0,
155 | 'image_id': 1,
156 | 'bbox': [128, 128, 128, 128],
157 | 'category_id': 2,
158 | 'id': 2,
159 | 'ignore': 0,
160 | 'segmentation': []
161 | }],
162 | 'categories': [],
163 | 'images': [{
164 | 'file_name': '2012_12.jpg',
165 | 'height': 256,
166 | 'width': 256,
167 | 'id': 1
168 | }]
169 | }
170 | self.assertEqual(ann_json_dict, expected_ann_json_dict)
171 |
172 |
173 | if __name__ == '__main__':
174 | logging.set_verbosity(logging.WARNING)
175 | tf.test.main()
176 |
--------------------------------------------------------------------------------
/efficientdet/dataset/inspect_tfrecords.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Inspect tfrecord dataset."""
16 | import os
17 | from absl import app
18 | from absl import flags
19 | from absl import logging
20 | import numpy as np
21 | from PIL import Image
22 | import tensorflow as tf
23 |
24 | import dataloader
25 | import hparams_config
26 | import utils
27 | from visualize import vis_utils
28 |
29 | flags.DEFINE_string('save_samples_dir', 'tfrecord_samples',
30 | 'Location of samples to save')
31 | flags.DEFINE_string('model_name', 'efficientdet-d0',
32 | 'model name for config and image_size')
33 | flags.DEFINE_string(
34 | 'hparams', '', 'Comma separated k=v pairs of hyperparameters or a module'
35 | ' containing attributes to use as hyperparameters.')
36 | flags.DEFINE_integer('samples', 10,
37 | 'Number of random samples for visualization.')
38 | flags.DEFINE_string('file_pattern', None,
39 | 'Glob for data files (e.g., COCO train - minival set)')
40 | flags.DEFINE_bool('eval', False, 'flag for file pattern mode i.e eval')
41 | FLAGS = flags.FLAGS
42 |
43 |
44 | class RecordInspect:
45 | """Inspection Class."""
46 |
47 | def __init__(self, config):
48 | """Initializes RecordInspect with passed config.
49 |
50 | Args:
51 | config: config file to initialize input_fn.
52 | """
53 | self.input_fn = dataloader.InputReader(
54 | FLAGS.file_pattern,
55 | is_training=not FLAGS.eval,
56 | use_fake_data=False,
57 | max_instances_per_image=config.max_instances_per_image)
58 |
59 | self.params = dict(
60 | config.as_dict(), batch_size=FLAGS.samples, model_name=FLAGS.model_name)
61 | logging.info(self.params)
62 | self.cls_to_label = config.label_map
63 | os.makedirs(FLAGS.save_samples_dir, exist_ok=True)
64 |
65 | def visualize(self):
66 | """save tfrecords images with bounding boxes."""
67 | vis_ds = self.input_fn(params=self.params)
68 | data = next(iter(vis_ds)) # iterable.
69 | images = data[0]
70 | gt_data = data[1]['groundtruth_data']
71 |
72 | # scales
73 | scale_to_org = data[1]['image_scales']
74 | scales = 1.0 / scale_to_org
75 | offset = tf.constant([0.485, 0.456, 0.406])
76 | offset = tf.reshape(offset, (1, 1, -1))
77 | scale_image = tf.constant([0.229, 0.224, 0.225])
78 | scale_image = tf.reshape(scale_image, (1, 1, -1))
79 |
80 | logging.info('Visualizing TfRecords %s', FLAGS.file_pattern)
81 | for i, zip_data in enumerate(zip(gt_data, images, scales)):
82 | gt, image, scale = zip_data
83 | boxes = gt[:, :4]
84 | boxes = boxes[np.any(boxes > 0, axis=1)].numpy()
85 | if boxes.shape[0] > 0:
86 | classes = gt[:boxes.shape[0], -1].numpy()
87 | try:
88 | category_index = {idx: {'id': idx, 'name': self.cls_to_label[idx]}
89 | for idx in np.asarray(classes, dtype=np.int)}
90 | except Exception: # pylint: disable=broad-except
91 | category_index = {}
92 |
93 | # unnormalize image.
94 | image *= scale_image
95 | image += offset
96 |
97 | # 0-255. range
98 | image = np.asarray(image.numpy() * 255., dtype=np.uint8)
99 |
100 | # scale to image_size
101 | boxes *= scale.numpy()
102 |
103 | image = vis_utils.visualize_boxes_and_labels_on_image_array(
104 | image,
105 | boxes=boxes,
106 | classes=classes.astype(int),
107 | scores=np.ones(boxes.shape[0]),
108 | category_index=category_index,
109 | line_thickness=2,
110 | skip_scores=True)
111 | image = Image.fromarray(image)
112 | image.save(os.path.join(FLAGS.save_samples_dir, f'sample{i}.jpg'))
113 |
114 |
115 | def main(_):
116 | # Parse and override hparams
117 | config = hparams_config.get_detection_config(FLAGS.model_name)
118 | config.override(FLAGS.hparams)
119 |
120 | # Parse image size in case it is in string format.
121 | config.image_size = utils.parse_image_size(config.image_size)
122 | try:
123 | recordinspect = RecordInspect(config)
124 | recordinspect.visualize()
125 | except Exception as e: # pylint: disable=broad-except
126 | logging.error(e)
127 | else:
128 | logging.info('Done, please find samples at %s', FLAGS.save_samples_dir)
129 |
130 |
131 | if __name__ == '__main__':
132 | app.run(main)
133 |
--------------------------------------------------------------------------------
/efficientdet/dataset/label_map_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Label map utility functions."""
16 | from absl import logging
17 |
18 |
19 | def _validate_label_map(label_map):
20 | """Checks if a label map is valid.
21 |
22 | Args:
23 | label_map: StringIntLabelMap to validate.
24 |
25 | Raises:
26 | ValueError: if label map is invalid.
27 | """
28 | for item in label_map.item:
29 | if item.id < 0:
30 | raise ValueError('Label map ids should be >= 0.')
31 | if (item.id == 0 and item.name != 'background' and
32 | item.display_name != 'background'):
33 | raise ValueError('Label map id 0 is reserved for the background label')
34 |
35 |
36 | def create_category_index(categories):
37 | """Creates dictionary of COCO compatible categories keyed by category id.
38 |
39 | Args:
40 | categories: a list of dicts, each of which has the following keys:
41 | 'id': (required) an integer id uniquely identifying this category.
42 | 'name': (required) string representing category name
43 | e.g., 'cat', 'dog', 'pizza'.
44 |
45 | Returns:
46 | category_index: a dict containing the same entries as categories, but keyed
47 | by the 'id' field of each category.
48 | """
49 | category_index = {}
50 | for cat in categories:
51 | category_index[cat['id']] = cat
52 | return category_index
53 |
54 |
55 | def get_max_label_map_index(label_map):
56 | """Get maximum index in label map.
57 |
58 | Args:
59 | label_map: a StringIntLabelMapProto
60 |
61 | Returns:
62 | an integer
63 | """
64 | return max([item.id for item in label_map.item])
65 |
66 |
67 | def convert_label_map_to_categories(label_map,
68 | max_num_classes,
69 | use_display_name=True):
70 | """Given label map proto returns categories list compatible with eval.
71 |
72 | This function converts label map proto and returns a list of dicts, each of
73 | which has the following keys:
74 | 'id': (required) an integer id uniquely identifying this category.
75 | 'name': (required) string representing category name
76 | e.g., 'cat', 'dog', 'pizza'.
77 | 'keypoints': (optional) a dictionary of keypoint string 'label' to integer
78 | 'id'.
79 |   We only allow a class into the list if its id-label_id_offset is
80 |   between 0 (inclusive) and max_num_classes (exclusive).
81 | If there are several items mapping to the same id in the label map,
82 | we will only keep the first one in the categories list.
83 |
84 | Args:
85 | label_map: a StringIntLabelMapProto or None. If None, a default categories
86 | list is created with max_num_classes categories.
87 | max_num_classes: maximum number of (consecutive) label indices to include.
88 | use_display_name: (boolean) choose whether to load 'display_name' field as
89 | category name. If False or if the display_name field does not exist, uses
90 | 'name' field as category names instead.
91 |
92 | Returns:
93 | categories: a list of dictionaries representing all possible categories.
94 | """
95 | categories = []
96 | list_of_ids_already_added = []
97 | if not label_map:
98 | label_id_offset = 1
99 | for class_id in range(max_num_classes):
100 | categories.append({
101 | 'id': class_id + label_id_offset,
102 | 'name': 'category_{}'.format(class_id + label_id_offset)
103 | })
104 | return categories
105 | for item in label_map.item:
106 | if not 0 < item.id <= max_num_classes:
107 | logging.info(
108 | 'Ignore item %d since it falls outside of requested '
109 | 'label range.', item.id)
110 | continue
111 | if use_display_name and item.HasField('display_name'):
112 | name = item.display_name
113 | else:
114 | name = item.name
115 | if item.id not in list_of_ids_already_added:
116 | list_of_ids_already_added.append(item.id)
117 | category = {'id': item.id, 'name': name}
118 | if item.keypoints:
119 | keypoints = {}
120 | list_of_keypoint_ids = []
121 | for kv in item.keypoints:
122 | if kv.id in list_of_keypoint_ids:
123 | raise ValueError('Duplicate keypoint ids are not allowed. '
124 | 'Found {} more than once'.format(kv.id))
125 | keypoints[kv.label] = kv.id
126 | list_of_keypoint_ids.append(kv.id)
127 | category['keypoints'] = keypoints
128 | categories.append(category)
129 | return categories
130 |
131 |
132 | def create_class_agnostic_category_index():
133 | """Creates a category index with a single `object` class."""
134 | return {1: {'id': 1, 'name': 'object'}}
135 |
--------------------------------------------------------------------------------
/efficientdet/dataset/tfrecord_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""TFRecord related utilities."""
16 | import tensorflow as tf
17 |
18 |
19 | def int64_feature(value):
20 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
21 |
22 |
23 | def int64_list_feature(value):
24 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
25 |
26 |
27 | def bytes_feature(value):
28 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
29 |
30 |
31 | def bytes_list_feature(value):
32 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
33 |
34 |
35 | def float_list_feature(value):
36 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
37 |
38 |
39 | def read_examples_list(path):
40 | """Read list of training or validation examples.
41 |
42 | The file is assumed to contain a single example per line where the first
43 | token in the line is an identifier that allows us to find the image and
44 | annotation xml for that example.
45 |
46 | For example, the line:
47 | xyz 3
48 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored).
49 |
50 | Args:
51 | path: absolute path to examples list file.
52 |
53 | Returns:
54 | list of example identifiers (strings).
55 | """
56 | with tf.io.gfile.GFile(path) as fid:
57 | lines = fid.readlines()
58 | return [line.strip().split(' ')[0] for line in lines]
59 |
60 |
61 | def recursive_parse_xml_to_dict(xml):
62 | """Recursively parses XML contents to python dict.
63 |
64 | We assume that `object` tags are the only ones that can appear
65 | multiple times at the same level of a tree.
66 |
67 | Args:
68 | xml: xml tree obtained by parsing XML file contents using lxml.etree
69 |
70 | Returns:
71 | Python dictionary holding XML contents.
72 | """
73 | if not len(xml): # pylint: disable=g-explicit-length-test
74 | return {xml.tag: xml.text if xml.text else ''}
75 | result = {}
76 | for child in xml:
77 | child_result = recursive_parse_xml_to_dict(child)
78 | if child.tag != 'object':
79 | result[child.tag] = child_result[child.tag]
80 | else:
81 | if child.tag not in result:
82 | result[child.tag] = []
83 | result[child.tag].append(child_result[child.tag])
84 | return {xml.tag: result}
85 |
86 |
87 | def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
88 | """Opens all TFRecord shards for writing and adds them to an exit stack.
89 |
90 | Args:
91 |     exit_stack: A contextlib2.ExitStack used to automatically close the
92 |       TFRecords opened in this function.
93 | base_path: The base path for all shards
94 | num_shards: The number of shards
95 |
96 | Returns:
97 | The list of opened TFRecords. Position k in the list corresponds to shard k.
98 | """
99 | tf_record_output_filenames = [
100 | '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards)
101 | for idx in range(num_shards)
102 | ]
103 |
104 | tfrecords = [
105 | exit_stack.enter_context(tf.io.TFRecordWriter(file_name))
106 | for file_name in tf_record_output_filenames
107 | ]
108 |
109 | return tfrecords
110 |
--------------------------------------------------------------------------------
/efficientdet/det_model_fn_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for det_model_fn."""
16 | import tensorflow as tf
17 | import det_model_fn
18 |
19 |
20 | def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _=0):
21 | """A legacy focal loss that does not support label smoothing."""
22 | with tf.name_scope('focal_loss'):
23 | positive_label_mask = tf.equal(targets, 1.0)
24 | cross_entropy = (
25 | tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
26 |
27 | neg_logits = -1.0 * logits
28 | modulator = tf.exp(gamma * targets * neg_logits -
29 | gamma * tf.math.log1p(tf.exp(neg_logits)))
30 | loss = modulator * cross_entropy
31 | weighted_loss = tf.where(positive_label_mask, alpha * loss,
32 | (1.0 - alpha) * loss)
33 | weighted_loss /= normalizer
34 | return weighted_loss
35 |
36 |
37 | class FocalLossTest(tf.test.TestCase):
38 |
39 | def test_focal_loss(self):
40 | tf.random.set_seed(1111)
41 | y_pred = tf.random.uniform([4, 32, 32, 90])
42 | y_true = tf.ones([4, 32, 32, 90])
43 | alpha, gamma, n = 0.25, 1.5, 100
44 | legacy_output = legacy_focal_loss(y_pred, y_true, alpha, gamma, n)
45 | new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n)
46 | self.assertAllClose(legacy_output, new_output)
47 |
48 | def test_focal_loss_with_label_smoothing(self):
49 | tf.random.set_seed(1111)
50 | shape = [2, 2, 2, 2]
51 | y_pred = tf.random.uniform(shape)
52 |
53 | # A binary classification target [0.0, 1.0] becomes [.1, .9]
54 | # with smoothing .2
55 | y_true = tf.ones(shape) * [0.0, 1.0]
56 | y_true_presmoothed = tf.ones(shape) * [0.1, 0.9]
57 |
58 | alpha, gamma, n = 1, 0, 100
59 | presmoothed = det_model_fn.focal_loss(y_pred, y_true_presmoothed, alpha,
60 | gamma, n, 0)
61 | alpha, gamma, n = 0.9, 0, 100
62 | unsmoothed = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n, 0.2)
63 |
64 | self.assertAllClose(presmoothed, unsmoothed)
65 |
66 |
67 | if __name__ == '__main__':
68 | tf.test.main()
69 |
--------------------------------------------------------------------------------
/efficientdet/g3doc/Det-AdvProp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/Det-AdvProp.png
--------------------------------------------------------------------------------
/efficientdet/g3doc/coco_ids.yaml:
--------------------------------------------------------------------------------
1 | 0: background
2 | 1: person
3 | 2: bicycle
4 | 3: car
5 | 4: motorcycle
6 | 5: airplane
7 | 6: bus
8 | 7: train
9 | 8: truck
10 | 9: boat
11 | 10: traffic light
12 | 11: fire hydrant
13 | 13: stop sign
14 | 14: parking meter
15 | 15: bench
16 | 16: bird
17 | 17: cat
18 | 18: dog
19 | 19: horse
20 | 20: sheep
21 | 21: cow
22 | 22: elephant
23 | 23: bear
24 | 24: zebra
25 | 25: giraffe
26 | 27: backpack
27 | 28: umbrella
28 | 31: handbag
29 | 32: tie
30 | 33: suitcase
31 | 34: frisbee
32 | 35: skis
33 | 36: snowboard
34 | 37: sports ball
35 | 38: kite
36 | 39: baseball bat
37 | 40: baseball glove
38 | 41: skateboard
39 | 42: surfboard
40 | 43: tennis racket
41 | 44: bottle
42 | 46: wine glass
43 | 47: cup
44 | 48: fork
45 | 49: knife
46 | 50: spoon
47 | 51: bowl
48 | 52: banana
49 | 53: apple
50 | 54: sandwich
51 | 55: orange
52 | 56: broccoli
53 | 57: carrot
54 | 58: hot dog
55 | 59: pizza
56 | 60: donut
57 | 61: cake
58 | 62: chair
59 | 63: couch
60 | 64: potted plant
61 | 65: bed
62 | 67: dining table
63 | 70: toilet
64 | 72: tv
65 | 73: laptop
66 | 74: mouse
67 | 75: remote
68 | 76: keyboard
69 | 77: cell phone
70 | 78: microwave
71 | 79: oven
72 | 80: toaster
73 | 81: sink
74 | 82: refrigerator
75 | 84: book
76 | 85: clock
77 | 86: vase
78 | 87: scissors
79 | 88: teddy bear
80 | 89: hair drier
81 | 90: toothbrush
82 |
--------------------------------------------------------------------------------
/efficientdet/g3doc/faq.md:
--------------------------------------------------------------------------------
1 | # EfficientDet FAQ
2 |
3 |
7 |
8 | [TOC]
9 |
10 | ## 1. For Users
11 |
12 | ### 1.1 How can I convert the saved model to tflite?
13 |
14 | Unfortunately, there is no way to do that with the current public TensorFlow
15 | release due to some issues in the TF converter. We have some internal fixes,
16 | which could potentially be available with the next TensorFlow release.
17 |
18 | ### 1.2 Why do I see NaN during training and how do I debug it?
19 |
20 | Because we use batch norm, which needs a reasonable batch size. If your batch
21 | size is too small, it may cause NaN. (We may add group norm to deal with this
22 | in the future.)
23 |
24 | If you see NaN, check the following:
25 |
26 | - Is my batch size too small? It usually needs to be >=8.
27 | - Should I clip my gradient? How about h.clip_gradients_norm=5.0?
28 | - Should I use smaller jitter? How about jitter_min=0.8 and jitter_max=1.2?
29 |
30 | If you want to debug it, you can use these tools:
31 |
32 | ```
33 | tf.compat.v1.add_check_numerics_ops()  # for TensorFlow 1.x
34 | tf.debugging.enable_check_numerics()  # for TensorFlow 2.x
35 | ```
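
These knobs are ordinary hyperparameters, so they can be set through the same
--hparams override used by the training scripts. A minimal sketch in Python
(the values are the suggestions above, not tuned defaults):

```
import hparams_config

config = hparams_config.get_detection_config('efficientdet-d0')
# Clip gradients and use milder scale jitter to reduce the chance of NaN.
config.override('clip_gradients_norm=5.0,jitter_min=0.8,jitter_max=1.2')
```

Equivalently, pass
--hparams="clip_gradients_norm=5.0,jitter_min=0.8,jitter_max=1.2" on the
command line.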
36 |
37 | ### 1.3 Why is my last class eval AP always zero?
38 |
39 | The current code assumes class 0 is always reserved for background, so if you have K classes, you should set num_classes=K+1.
40 |
41 | See [#391](https://github.com/google/automl/issues/391) and [#398](https://github.com/google/automl/issues/398) for more discussion.
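
For example, if your dataset has 20 foreground classes (ids 1..20, with id 0
reserved for background as in g3doc/coco_ids.yaml), a minimal sketch is:

```
import hparams_config

config = hparams_config.get_detection_config('efficientdet-d0')
config.override('num_classes=21')  # 20 foreground classes + 1 background.
```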
42 |
43 | ### 1.4 Why does my input pipeline have an assert failure?
44 |
45 | This most likely means your dataset has some images with many objects (more
46 | than the 100-instance limit used for COCO); set --hparams="max_instances_per_image=200" or larger.
47 |
48 | See [#93](https://github.com/google/automl/issues/93) for more discussion.
49 |
50 |
51 | ## 2. For Developers
52 |
53 | ### 2.1 How can I format my code for PRs?
54 |
55 | Please use [yapf](https://github.com/google/yapf) with option
56 | --style='{based_on_style: yapf}'. You can also save the
57 | following file to ~/.config/yapf/style:
58 |
59 | [style]
60 | based_on_style = yapf
61 |
62 | If you want to check the format with lint, please run:
63 |
64 | !pylint --rcfile=../.pylintrc your_file.py
65 |
66 | ### 2.2 How can I run all tests?
67 |
68 | !export PYTHONPATH="`pwd`:$PYTHONPATH"
69 | !find . -name "*_test.py" | parallel python &> /tmp/test.log \
70 | && echo "All passed" || echo "Failed! Search keyword FAILED in /tmp/test.log"
71 |
--------------------------------------------------------------------------------
/efficientdet/g3doc/flops.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/flops.png
--------------------------------------------------------------------------------
/efficientdet/g3doc/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/network.png
--------------------------------------------------------------------------------
/efficientdet/g3doc/params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/params.png
--------------------------------------------------------------------------------
/efficientdet/g3doc/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/street.jpg
--------------------------------------------------------------------------------
/efficientdet/hparams_config_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ======================================
15 | """Tests for hparams_config."""
16 | import os
17 | import tempfile
18 | from absl import logging
19 | import tensorflow.compat.v1 as tf
20 | import yaml
21 |
22 | import hparams_config
23 |
24 |
25 | class HparamsConfigTest(tf.test.TestCase):
26 |
27 | def test_config_override(self):
28 | c = hparams_config.Config({'a': 1, 'b': 2})
29 | self.assertEqual(c.as_dict(), {'a': 1, 'b': 2})
30 |
31 | c.update({'a': 10})
32 | self.assertEqual(c.as_dict(), {'a': 10, 'b': 2})
33 |
34 | c.b = 20
35 | self.assertEqual(c.as_dict(), {'a': 10, 'b': 20})
36 |
37 | c.override('a=true,b=ss')
38 | self.assertEqual(c.as_dict(), {'a': True, 'b': 'ss'})
39 |
40 | c.override('a=100,,,b=2.3,') # extra ',' is fine.
41 | self.assertEqual(c.as_dict(), {'a': 100, 'b': 2.3})
42 |
43 | c.override('a=2x3,b=50') # a is a special format for image size.
44 | self.assertEqual(c.as_dict(), {'a': '2x3', 'b': 50})
45 |
46 |     # override string must be in the format of xx=yy.
47 | with self.assertRaises(ValueError):
48 | c.override('a=true,invalid_string')
49 |
50 | def test_config_yaml(self):
51 | tmpdir = tempfile.gettempdir()
52 | yaml_file_path = os.path.join(tmpdir, 'x.yaml')
53 | with open(yaml_file_path, 'w') as f:
54 | f.write("""
55 | x: 2
56 | y:
57 | z: 'test'
58 | """)
59 | c = hparams_config.Config(dict(x=234, y=2342))
60 | c.override(yaml_file_path)
61 | self.assertEqual(c.as_dict(), {'x': 2, 'y': {'z': 'test'}})
62 |
63 | yaml_file_path2 = os.path.join(tmpdir, 'y.yaml')
64 | c.save_to_yaml(yaml_file_path2)
65 | with open(yaml_file_path2, 'r') as f:
66 | config_dict = yaml.load(f, Loader=yaml.FullLoader)
67 | self.assertEqual(config_dict, {'x': 2, 'y': {'z': 'test'}})
68 |
69 | def test_config_override_recursive(self):
70 | c = hparams_config.Config({'x': 1})
71 | self.assertEqual(c.as_dict(), {'x': 1})
72 | c.override('y.y0=2,y.y1=3', allow_new_keys=True)
73 | self.assertEqual(c.as_dict(), {'x': 1, 'y': {'y0': 2, 'y1': 3}})
74 | c.update({'y': {'y0': 5, 'y1': {'y11': 100}}})
75 | self.assertEqual(c.as_dict(), {'x': 1, 'y': {'y0': 5, 'y1': {'y11': 100}}})
76 | self.assertEqual(c.y.y1.y11, 100)
77 |
78 | def test_config_override_list(self):
79 | c = hparams_config.Config({'x': [1.0, 2.0]})
80 | self.assertEqual(c.as_dict(), {'x': [1.0, 2.0]})
81 | c.override('x=3.0*4.0*5.0')
82 | self.assertEqual(c.as_dict(), {'x': [3.0, 4.0, 5.0]})
83 |
84 |
85 | if __name__ == '__main__':
86 | logging.set_verbosity(logging.WARNING)
87 | tf.test.main()
88 |
--------------------------------------------------------------------------------
/efficientdet/install_deps.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | for line in $(cat efficientdet/requirements.txt)
18 | do
19 | pip install $line
20 | done
21 |
--------------------------------------------------------------------------------
/efficientdet/iou_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ======================================
15 | """Tests for iou_utils."""
16 | from absl import logging
17 | import tensorflow as tf
18 | import iou_utils
19 |
20 |
21 | class IouUtilsTest(tf.test.TestCase):
22 | """IoU test class."""
23 |
24 | def setUp(self):
25 | super(IouUtilsTest, self).setUp()
26 | self.pb = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]],
27 | dtype=tf.float32)
28 | self.tb = tf.constant(
29 | [[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]], dtype=tf.float32)
30 | self.zeros = tf.constant([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=tf.float32)
31 |
32 | def test_iou(self):
33 | self.assertAllClose(
34 | iou_utils.iou_loss(self.pb, self.tb, 'iou'), [0.875, 1.])
35 |
36 | def test_ciou(self):
37 | self.assertAllClose(
38 | iou_utils.iou_loss(self.pb, self.tb, 'ciou'), [0.99931306, 1.6415315])
39 |
40 | def test_diou(self):
41 | self.assertAllClose(
42 | iou_utils.iou_loss(self.pb, self.tb, 'diou'), [0.9969512, 1.6243094])
43 |
44 | def test_giou(self):
45 | self.assertAllClose(
46 | iou_utils.iou_loss(self.pb, self.tb, 'giou'), [1.075000, 1.933333])
47 |
48 | def test_iou_zero_target(self):
49 | self.assertAllClose(
50 | iou_utils.iou_loss(self.pb, self.zeros, 'iou'), [0.0, 0.0])
51 | self.assertAllClose(
52 | iou_utils.iou_loss(self.pb, self.zeros, 'ciou'), [0.0, 0.0])
53 | self.assertAllClose(
54 | iou_utils.iou_loss(self.pb, self.zeros, 'diou'), [0.0, 0.0])
55 | self.assertAllClose(
56 | iou_utils.iou_loss(self.pb, self.zeros, 'giou'), [0.0, 0.0])
57 |
58 | def test_iou_multiple_anchors(self):
59 | pb = tf.tile(self.pb, [1, 2])
60 | tb = tf.tile(self.tb, [1, 2])
61 | self.assertAllClose(iou_utils.iou_loss(pb, tb, 'iou'), [1.75, 2.0])
62 |
63 | def test_iou_multiple_anchors_mixed(self):
64 | pb = tf.concat([self.pb, self.zeros], axis=-1)
65 | tb = tf.concat([self.tb, self.zeros], axis=-1)
66 | self.assertAllClose(iou_utils.iou_loss(pb, tb, 'iou'), [0.875, 1.0])
67 |
68 | def test_ciou_grad(self):
69 | pb = tf.concat([self.pb, self.zeros], axis=-1)
70 | tb = tf.concat([self.tb, self.zeros], axis=-1)
71 | with tf.GradientTape() as tape:
72 | tape.watch([pb, tb])
73 | loss = iou_utils.iou_loss(pb, tb, 'ciou')
74 | grad = tape.gradient(loss, [tb, pb])
75 | self.assertAlmostEqual(tf.reduce_sum(grad[0]).numpy(), 0.1476393)
76 | self.assertAlmostEqual(tf.reduce_sum(grad[1]).numpy(), -0.14763935)
77 |
78 |
79 | if __name__ == '__main__':
80 | logging.set_verbosity(logging.WARNING)
81 | tf.test.main()
82 |
--------------------------------------------------------------------------------
/efficientdet/object_detection/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/efficientdet/object_detection/box_coder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Base box coder.
16 |
17 | Box coders convert between coordinate frames, namely image-centric
18 | (with (0,0) on the top left of image) and anchor-centric (with (0,0) being
19 | defined by a specific anchor).
20 |
21 | Users of a BoxCoder can call two methods:
22 | encode: which encodes a box with respect to a given anchor
23 | (or rather, a tensor of boxes wrt a corresponding tensor of anchors) and
24 | decode: which inverts this encoding with a decode operation.
25 | In both cases, the arguments are assumed to be in 1-1 correspondence already;
26 | it is not the job of a BoxCoder to perform matching.
27 | """
28 | from abc import ABCMeta
29 | from abc import abstractmethod
30 | from abc import abstractproperty
31 |
32 | import tensorflow.compat.v1 as tf
33 |
34 |
35 | # Box coder types.
36 | FASTER_RCNN = 'faster_rcnn'
37 | KEYPOINT = 'keypoint'
38 | MEAN_STDDEV = 'mean_stddev'
39 | SQUARE = 'square'
40 |
41 |
42 | class BoxCoder(object):
43 | """Abstract base class for box coder."""
44 | __metaclass__ = ABCMeta
45 |
46 | @abstractproperty
47 | def code_size(self):
48 | """Return the size of each code.
49 |
50 | This number is a constant and should agree with the output of the `encode`
51 | op (e.g. if rel_codes is the output of self.encode(...), then it should have
52 | shape [N, code_size()]). This abstractproperty should be overridden by
53 | implementations.
54 |
55 | Returns:
56 | an integer constant
57 | """
58 | pass
59 |
60 | def encode(self, boxes, anchors):
61 | """Encode a box list relative to an anchor collection.
62 |
63 | Args:
64 | boxes: BoxList holding N boxes to be encoded
65 | anchors: BoxList of N anchors
66 |
67 | Returns:
68 | a tensor representing N relative-encoded boxes
69 | """
70 | with tf.name_scope('Encode'):
71 | return self._encode(boxes, anchors)
72 |
73 | def decode(self, rel_codes, anchors):
74 | """Decode boxes that are encoded relative to an anchor collection.
75 |
76 | Args:
77 | rel_codes: a tensor representing N relative-encoded boxes
78 | anchors: BoxList of anchors
79 |
80 | Returns:
81 | boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
82 | with corners y_min, x_min, y_max, x_max)
83 | """
84 | with tf.name_scope('Decode'):
85 | return self._decode(rel_codes, anchors)
86 |
87 | @abstractmethod
88 | def _encode(self, boxes, anchors):
89 | """Method to be overridden by implementations.
90 |
91 | Args:
92 | boxes: BoxList holding N boxes to be encoded
93 | anchors: BoxList of N anchors
94 |
95 | Returns:
96 | a tensor representing N relative-encoded boxes
97 | """
98 | pass
99 |
100 | @abstractmethod
101 | def _decode(self, rel_codes, anchors):
102 | """Method to be overridden by implementations.
103 |
104 | Args:
105 | rel_codes: a tensor representing N relative-encoded boxes
106 | anchors: BoxList of anchors
107 |
108 | Returns:
109 | boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
110 | with corners y_min, x_min, y_max, x_max)
111 | """
112 | pass
113 |
114 |
115 | def batch_decode(encoded_boxes, box_coder, anchors):
116 | """Decode a batch of encoded boxes.
117 |
118 | This op takes a batch of encoded bounding boxes and transforms
119 | them to a batch of bounding boxes specified by their corners in
120 | the order of [y_min, x_min, y_max, x_max].
121 |
122 | Args:
123 | encoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
124 | code_size] representing the location of the objects.
125 | box_coder: a BoxCoder object.
126 | anchors: a BoxList of anchors used to encode `encoded_boxes`.
127 |
128 | Returns:
129 | decoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
130 |       code_size] representing the corners of the objects in the order
131 | of [y_min, x_min, y_max, x_max].
132 |
133 | Raises:
134 | ValueError: if batch sizes of the inputs are inconsistent, or if
135 | the number of anchors inferred from encoded_boxes and anchors are
136 | inconsistent.
137 | """
138 | encoded_boxes.get_shape().assert_has_rank(3)
139 | if encoded_boxes.get_shape()[1].value != anchors.num_boxes_static():
140 | raise ValueError('The number of anchors inferred from encoded_boxes'
141 | ' and anchors are inconsistent: shape[1] of encoded_boxes'
142 | ' %s should be equal to the number of anchors: %s.' %
143 | (encoded_boxes.get_shape()[1].value,
144 | anchors.num_boxes_static()))
145 |
146 | decoded_boxes = tf.stack([
147 | box_coder.decode(boxes, anchors).get()
148 | for boxes in tf.unstack(encoded_boxes)
149 | ])
150 | return decoded_boxes
151 |
--------------------------------------------------------------------------------
/efficientdet/object_detection/faster_rcnn_box_coder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Faster RCNN box coder.
16 |
17 | Faster RCNN box coder follows the coding schema described below:
18 | ty = (y - ya) / ha
19 | tx = (x - xa) / wa
20 | th = log(h / ha)
21 | tw = log(w / wa)
22 | where x, y, w, h denote the box's center coordinates, width and height
23 | respectively. Similarly, xa, ya, wa, ha denote the anchor's center
24 | coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
25 | center, width and height respectively.
26 |
27 | See http://arxiv.org/abs/1506.01497 for details.
28 | """
29 |
30 | import tensorflow.compat.v1 as tf
31 |
32 | from object_detection import box_coder
33 | from object_detection import box_list
34 |
35 | EPSILON = 1e-8
36 |
37 |
38 | class FasterRcnnBoxCoder(box_coder.BoxCoder):
39 | """Faster RCNN box coder."""
40 |
41 | def __init__(self, scale_factors=None):
42 | """Constructor for FasterRcnnBoxCoder.
43 |
44 | Args:
45 | scale_factors: List of 4 positive scalars to scale ty, tx, th and tw.
46 | If set to None, does not perform scaling. For Faster RCNN,
47 | the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0].
48 | """
49 | if scale_factors:
50 | assert len(scale_factors) == 4
51 | for scalar in scale_factors:
52 | assert scalar > 0
53 | self._scale_factors = scale_factors
54 |
55 | @property
56 | def code_size(self):
57 | return 4
58 |
59 | def _encode(self, boxes, anchors):
60 | """Encode a box collection with respect to anchor collection.
61 |
62 | Args:
63 | boxes: BoxList holding N boxes to be encoded.
64 | anchors: BoxList of anchors.
65 |
66 | Returns:
67 | a tensor representing N anchor-encoded boxes of the format
68 | [ty, tx, th, tw].
69 | """
70 | # Convert anchors to the center coordinate representation.
71 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
72 | ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes()
73 | # Avoid NaN in division and log below.
74 | ha = tf.maximum(EPSILON, ha)
75 | wa = tf.maximum(EPSILON, wa)
76 | h = tf.maximum(EPSILON, h)
77 | w = tf.maximum(EPSILON, w)
78 |
79 | tx = (xcenter - xcenter_a) / wa
80 | ty = (ycenter - ycenter_a) / ha
81 | tw = tf.log(w / wa)
82 | th = tf.log(h / ha)
83 | # Scales location targets as used in paper for joint training.
84 | if self._scale_factors:
85 | ty *= self._scale_factors[0]
86 | tx *= self._scale_factors[1]
87 | th *= self._scale_factors[2]
88 | tw *= self._scale_factors[3]
89 | return tf.transpose(tf.stack([ty, tx, th, tw]))
90 |
91 | def _decode(self, rel_codes, anchors):
92 | """Decode relative codes to boxes.
93 |
94 | Args:
95 | rel_codes: a tensor representing N anchor-encoded boxes.
96 | anchors: BoxList of anchors.
97 |
98 | Returns:
99 | boxes: BoxList holding N bounding boxes.
100 | """
101 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
102 |
103 | ty, tx, th, tw = tf.unstack(tf.transpose(rel_codes))
104 | if self._scale_factors:
105 | ty /= self._scale_factors[0]
106 | tx /= self._scale_factors[1]
107 | th /= self._scale_factors[2]
108 | tw /= self._scale_factors[3]
109 | w = tf.exp(tw) * wa
110 | h = tf.exp(th) * ha
111 | ycenter = ty * ha + ycenter_a
112 | xcenter = tx * wa + xcenter_a
113 | ymin = ycenter - h / 2.
114 | xmin = xcenter - w / 2.
115 | ymax = ycenter + h / 2.
116 | xmax = xcenter + w / 2.
117 | return box_list.BoxList(tf.transpose(tf.stack([ymin, xmin, ymax, xmax])))
118 |
--------------------------------------------------------------------------------
/efficientdet/object_detection/region_similarity_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Region Similarity Calculators for BoxLists.
16 |
17 | Region Similarity Calculators compare a pairwise measure of similarity
18 | between the boxes in two BoxLists.
19 | """
20 | from abc import ABCMeta
21 | from abc import abstractmethod
22 |
23 | import tensorflow.compat.v1 as tf
24 |
25 |
26 | def area(boxlist, scope=None):
27 | """Computes area of boxes.
28 |
29 | Args:
30 | boxlist: BoxList holding N boxes
31 | scope: name scope.
32 |
33 | Returns:
34 | a tensor with shape [N] representing box areas.
35 | """
36 | with tf.name_scope(scope, 'Area'):
37 | y_min, x_min, y_max, x_max = tf.split(
38 | value=boxlist.get(), num_or_size_splits=4, axis=1)
39 | return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
40 |
41 |
42 | def intersection(boxlist1, boxlist2, scope=None):
43 | """Compute pairwise intersection areas between boxes.
44 |
45 | Args:
46 | boxlist1: BoxList holding N boxes
47 | boxlist2: BoxList holding M boxes
48 | scope: name scope.
49 |
50 | Returns:
51 | a tensor with shape [N, M] representing pairwise intersections
52 | """
53 | with tf.name_scope(scope, 'Intersection'):
54 | y_min1, x_min1, y_max1, x_max1 = tf.split(
55 | value=boxlist1.get(), num_or_size_splits=4, axis=1)
56 | y_min2, x_min2, y_max2, x_max2 = tf.split(
57 | value=boxlist2.get(), num_or_size_splits=4, axis=1)
58 | all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
59 | all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
60 | intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin)
61 | all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
62 | all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
63 | intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin)
64 | return intersect_heights * intersect_widths
65 |
66 |
67 | def iou(boxlist1, boxlist2, scope=None):
68 | """Computes pairwise intersection-over-union between box collections.
69 |
70 | Args:
71 | boxlist1: BoxList holding N boxes
72 | boxlist2: BoxList holding M boxes
73 | scope: name scope.
74 |
75 | Returns:
76 | a tensor with shape [N, M] representing pairwise iou scores.
77 | """
78 | with tf.name_scope(scope, 'IOU'):
79 | intersections = intersection(boxlist1, boxlist2)
80 | areas1 = area(boxlist1)
81 | areas2 = area(boxlist2)
82 | unions = (
83 | tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
84 | return tf.where(
85 | tf.equal(intersections, 0.0),
86 | tf.zeros_like(intersections), tf.truediv(intersections, unions))
87 |
88 |
89 | class RegionSimilarityCalculator(object):
90 | """Abstract base class for region similarity calculator."""
91 | __metaclass__ = ABCMeta
92 |
93 | def compare(self, boxlist1, boxlist2, scope=None):
94 | """Computes matrix of pairwise similarity between BoxLists.
95 |
96 | This op (to be overridden) computes a measure of pairwise similarity between
97 | the boxes in the given BoxLists. Higher values indicate more similarity.
98 |
99 | Note that this method simply measures similarity and does not explicitly
100 | perform a matching.
101 |
102 | Args:
103 | boxlist1: BoxList holding N boxes.
104 | boxlist2: BoxList holding M boxes.
105 | scope: Op scope name. Defaults to 'Compare' if None.
106 |
107 | Returns:
108 | a (float32) tensor of shape [N, M] with pairwise similarity score.
109 | """
110 | with tf.name_scope(scope, 'Compare', [boxlist1, boxlist2]) as scope:
111 | return self._compare(boxlist1, boxlist2)
112 |
113 | @abstractmethod
114 | def _compare(self, boxlist1, boxlist2):
115 | pass
116 |
117 |
118 | class IouSimilarity(RegionSimilarityCalculator):
119 | """Class to compute similarity based on Intersection over Union (IOU) metric.
120 |
121 | This class computes pairwise similarity between two BoxLists based on IOU.
122 | """
123 |
124 | def _compare(self, boxlist1, boxlist2):
125 | """Compute pairwise IOU similarity between the two BoxLists.
126 |
127 | Args:
128 | boxlist1: BoxList holding N boxes.
129 | boxlist2: BoxList holding M boxes.
130 |
131 | Returns:
132 | A tensor with shape [N, M] representing pairwise iou scores.
133 | """
134 | return iou(boxlist1, boxlist2)
135 |
--------------------------------------------------------------------------------
/efficientdet/object_detection/shape_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utils used to manipulate tensor shapes."""
16 |
17 | import tensorflow.compat.v1 as tf
18 |
19 |
20 | def assert_shape_equal(shape_a, shape_b):
21 | """Asserts that shape_a and shape_b are equal.
22 |
23 | If the shapes are static, raises a ValueError when the shapes
24 | mismatch.
25 |
26 | If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
27 | mismatch.
28 |
29 | Args:
30 | shape_a: a list containing shape of the first tensor.
31 | shape_b: a list containing shape of the second tensor.
32 |
33 | Returns:
34 |     Either a tf.no_op() when shapes are all static or a tf.assert_equal() op
35 | when the shapes are dynamic.
36 |
37 | Raises:
38 | ValueError: When shapes are both static and unequal.
39 | """
40 | if (all(isinstance(dim, int) for dim in shape_a) and
41 | all(isinstance(dim, int) for dim in shape_b)):
42 | if shape_a != shape_b:
43 | raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
44 | else: return tf.no_op()
45 | else:
46 | return tf.assert_equal(shape_a, shape_b)
47 |
48 |
49 | def combined_static_and_dynamic_shape(tensor):
50 | """Returns a list containing static and dynamic values for the dimensions.
51 |
52 | Returns a list of static and dynamic values for shape dimensions. This is
53 | useful to preserve static shapes when available in reshape operation.
54 |
55 | Args:
56 | tensor: A tensor of any type.
57 |
58 | Returns:
59 | A list of size tensor.shape.ndims containing integers or a scalar tensor.
60 | """
61 | static_tensor_shape = tensor.shape.as_list()
62 | dynamic_tensor_shape = tf.shape(tensor)
63 | combined_shape = []
64 | for index, dim in enumerate(static_tensor_shape):
65 | if dim is not None:
66 | combined_shape.append(dim)
67 | else:
68 | combined_shape.append(dynamic_tensor_shape[index])
69 | return combined_shape
70 |
--------------------------------------------------------------------------------
/efficientdet/requirements.txt:
--------------------------------------------------------------------------------
1 | lxml>=4.6.1
2 | absl-py>=0.10.0
3 | matplotlib>=3.0.3
4 | numpy>=1.19.4,<1.24.0
5 | Pillow>=9.5.0
6 | PyYAML>=5.1
7 | six>=1.15.0
8 | tensorflow>=2.10.0,<2.16.0
9 | tensorflow-addons>=0.18.0
10 | tensorflow-hub>=0.11
11 | neural-structured-learning>=1.3.1
12 | Cython>=0.29.13
13 | git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
14 |
--------------------------------------------------------------------------------
/efficientdet/run_tflite.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Run TF Lite model."""
16 | from absl import app
17 | from absl import flags
18 |
19 | from PIL import Image
20 | import tensorflow as tf
21 |
22 | import inference
23 |
24 | FLAGS = flags.FLAGS
25 |
26 |
27 | def define_flags():
28 | """Define flags."""
29 | flags.DEFINE_string('tflite_path', None, 'Path of tflite file.')
30 | flags.DEFINE_string('sample_image', None, 'Sample image path')
31 | flags.DEFINE_string('output_image', None, 'Output image path')
32 | flags.DEFINE_string('image_size', '512x512', 'Image size "WxH".')
33 |
34 |
35 | def load_image(image_path, image_size):
36 | """Loads an image, and returns numpy.ndarray.
37 |
38 | Args:
39 | image_path: str, path to image.
40 | image_size: list of int, representing [width, height].
41 |
42 | Returns:
43 | image_batch: numpy.ndarray of shape [1, H, W, C].
44 | """
45 | input_data = tf.io.gfile.GFile(image_path, 'rb').read()
46 | image = tf.io.decode_image(input_data, channels=3, dtype=tf.uint8)
47 | image = tf.image.resize(
48 | image, image_size, method='bilinear', antialias=True)
49 | return tf.expand_dims(tf.cast(image, tf.uint8), 0).numpy()
50 |
51 |
52 | def save_visualized_image(image, prediction, output_path):
53 | """Saves the visualized image with prediction.
54 |
55 | Args:
56 | image: numpy.ndarray of shape [H, W, C].
57 | prediction: numpy.ndarray of shape [num_predictions, 7].
58 | output_path: str, output image path.
59 | """
60 | output_image = inference.visualize_image_prediction(
61 | image,
62 | prediction,
63 | label_map='coco')
64 | Image.fromarray(output_image).save(output_path)
65 |
66 |
67 | class TFLiteRunner:
68 | """Wrapper to run TFLite model."""
69 |
70 | def __init__(self, model_path):
71 | """Init.
72 |
73 | Args:
74 | model_path: str, path to tflite model.
75 | """
76 | self.interpreter = tf.lite.Interpreter(model_path=model_path)
77 | self.interpreter.allocate_tensors()
78 | self.input_index = self.interpreter.get_input_details()[0]['index']
79 | self.output_index = self.interpreter.get_output_details()[0]['index']
80 |
81 | def run(self, image):
82 | """Run inference on a single images.
83 |
84 | Args:
85 | image: numpy.ndarray of shape [1, H, W, C].
86 |
87 | Returns:
88 | prediction: numpy.ndarray of shape [1, num_detections, 7].
89 | """
90 | self.interpreter.set_tensor(self.input_index, image)
91 | self.interpreter.invoke()
92 | return self.interpreter.get_tensor(self.output_index)
93 |
94 |
95 | def main(_):
96 | image_size = [int(dim) for dim in FLAGS.image_size.split('x')]
97 | image = load_image(FLAGS.sample_image, image_size)
98 |
99 | runner = TFLiteRunner(FLAGS.tflite_path)
100 | prediction = runner.run(image)
101 |
102 | save_visualized_image(image[0], prediction[0], FLAGS.output_image)
103 |
104 |
105 | if __name__ == '__main__':
106 | define_flags()
107 | app.run(main)
108 |
--------------------------------------------------------------------------------
/efficientdet/tensorrt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Simple tools for TensorRT.
16 |
17 | Example usage:
18 |
19 | $ export ROOT=/tmp/d4
20 | $ python model_inspect.py --runmode=freeze --model_name=efficientdet-d4 \
21 | --logdir=$ROOT # --hparams=xyz.yaml
22 | $ python tensorrt.py --tf_savedmodel_dir=$ROOT/savedmodel \
23 | --trt_savedmodel_dir=$ROOT/trtmodel
24 | """
25 | import time
26 | from absl import app
27 | from absl import flags
28 | import numpy as np
29 | import tensorflow.compat.v1 as tf
30 | # pylint: disable=g-direct-tensorflow-import
31 | from tensorflow.python.compiler.tensorrt import trt_convert as trt
32 |
33 | flags.DEFINE_string('tf_savedmodel_dir', None, 'TensorFlow saved model dir.')
34 | flags.DEFINE_string('trt_savedmodel_dir', None, 'TensorRT saved model dir.')
35 | FLAGS = flags.FLAGS
36 |
37 |
38 | def convert2trt(tf_savedmodel_dir: str, trt_savedmodel_dir: str):
39 | converter = trt.TrtGraphConverter(
40 | input_saved_model_dir=tf_savedmodel_dir,
41 | max_workspace_size_bytes=(2 << 20),
42 | precision_mode='FP16',
43 | maximum_cached_engines=1)
44 | converter.convert()
45 | converter.save(trt_savedmodel_dir)
46 |
47 |
48 | def benchmark(trt_savedmodel_dir: str, warmup_runs: int = 5, bm_runs: int = 20):
49 | """Benchmark TRT latency for a given TRT saved model."""
50 | with tf.Session() as sess:
51 | # First load the Saved Model into the session
52 | tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
53 | trt_savedmodel_dir)
54 | graph = tf.get_default_graph()
55 | input_shape = graph.get_tensor_by_name('input:0').shape
56 | x = np.ones(input_shape).astype(np.float32)
57 | ss = lambda i: '' if i == 0 else '_%d' % i
58 | outputs = ['box_net/box-predict%s/BiasAdd:0' % ss(i) for i in range(1)]
59 | outputs += ['class_net/class-predict%s/BiasAdd:0' % ss(i) for i in range(5)]
60 | # Apply reduce_sum to avoid massive data move between GPU and CPU.
61 | outputs = [tf.reduce_sum(graph.get_tensor_by_name(i)) for i in outputs]
62 |
63 | # warmup
64 | for _ in range(warmup_runs):
65 | sess.run(outputs, feed_dict={'input:0': x})
66 | # benchmark
67 | s = time.perf_counter()
68 | for _ in range(bm_runs):
69 | sess.run(outputs, feed_dict={'input:0': x})
70 | e = time.perf_counter()
71 |     print('Benchmark latency=%.4f FPS=%.2f' %
72 |           ((e - s) / bm_runs, bm_runs / (e - s)))
73 |
74 |
75 | def main(_):
76 | if FLAGS.tf_savedmodel_dir:
77 | convert2trt(FLAGS.tf_savedmodel_dir, FLAGS.trt_savedmodel_dir)
78 |   benchmark(FLAGS.trt_savedmodel_dir)  # warmup_runs/bm_runs use the defaults.
79 |
80 |
81 | if __name__ == '__main__':
82 | flags.mark_flag_as_required('trt_savedmodel_dir')
83 | tf.disable_v2_behavior()
84 | app.run(main)
85 |
--------------------------------------------------------------------------------
/efficientdet/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | for file in `find $PWD/efficientdet -name '*.py'`
17 | do
18 | pylint --rcfile=.pylintrc $file
19 | done
20 |
21 | cd efficientdet
22 | for file in `find $PWD -name '*_test.py'`
23 | do
24 | PYTHONPATH=$PWD TF_CPP_MIN_LOG_LEVEL=1 python $file
25 | done
--------------------------------------------------------------------------------
/efficientdet/test_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Test utilities."""
16 | import os
17 |
18 | import tensorflow as tf
19 | from dataset import tfrecord_util
20 |
21 |
22 | def make_fake_tfrecord(temp_dir):
23 | """Makes fake TFRecord to test input."""
24 | tfrecord_path = os.path.join(temp_dir, 'test.tfrecords')
25 | writer = tf.io.TFRecordWriter(tfrecord_path)
26 | encoded_jpg = tf.io.encode_jpeg(tf.ones([512, 512, 3], dtype=tf.uint8))
27 | example = tf.train.Example(
28 | features=tf.train.Features(
29 | feature={
30 | 'image/height':
31 | tfrecord_util.int64_feature(512),
32 | 'image/width':
33 | tfrecord_util.int64_feature(512),
34 | 'image/filename':
35 | tfrecord_util.bytes_feature('test_file_name.jpg'.encode(
36 | 'utf8')),
37 | 'image/source_id':
38 | tfrecord_util.bytes_feature('123456'.encode('utf8')),
39 | 'image/key/sha256':
40 | tfrecord_util.bytes_feature('qwdqwfw12345'.encode('utf8')),
41 | 'image/encoded':
42 | tfrecord_util.bytes_feature(encoded_jpg.numpy()),
43 | 'image/format':
44 | tfrecord_util.bytes_feature('jpeg'.encode('utf8')),
45 | 'image/object/bbox/xmin':
46 | tfrecord_util.float_list_feature([0.1]),
47 | 'image/object/bbox/xmax':
48 | tfrecord_util.float_list_feature([0.1]),
49 | 'image/object/bbox/ymin':
50 | tfrecord_util.float_list_feature([0.2]),
51 | 'image/object/bbox/ymax':
52 | tfrecord_util.float_list_feature([0.2]),
53 | 'image/object/class/text':
54 | tfrecord_util.bytes_list_feature(['test'.encode('utf8')]),
55 | 'image/object/class/label':
56 | tfrecord_util.int64_list_feature([1]),
57 | 'image/object/difficult':
58 | tfrecord_util.int64_list_feature([]),
59 | 'image/object/truncated':
60 | tfrecord_util.int64_list_feature([]),
61 | 'image/object/view':
62 | tfrecord_util.bytes_list_feature([]),
63 | }))
64 | writer.write(example.SerializeToString())
65 | return tfrecord_path
66 |
--------------------------------------------------------------------------------
/efficientdet/testdata/img1-d1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/testdata/img1-d1.jpg
--------------------------------------------------------------------------------
/efficientdet/testdata/img1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/testdata/img1.jpg
--------------------------------------------------------------------------------
/efficientdet/tf2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/tf2/__init__.py
--------------------------------------------------------------------------------
/efficientdet/tf2/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Eval libraries."""
16 | from absl import app
17 | from absl import flags
18 | from absl import logging
19 | import tensorflow as tf
20 |
21 | import coco_metric
22 | import dataloader
23 | import hparams_config
24 | import utils
25 | from tf2 import anchors
26 | from tf2 import efficientdet_keras
27 | from tf2 import label_util
28 | from tf2 import postprocess
29 | from tf2 import util_keras
30 |
31 | # Cloud TPU Cluster Resolvers
32 | flags.DEFINE_string('tpu', None, 'The Cloud TPU name.')
33 | flags.DEFINE_string('gcp_project', None, 'Project name.')
34 | flags.DEFINE_string('tpu_zone', None, 'GCE zone name.')
35 |
36 | flags.DEFINE_integer('eval_samples', None, 'Number of eval samples.')
37 | flags.DEFINE_string('val_file_pattern', None,
38 | 'Glob for eval tfrecords, e.g. coco/val-*.tfrecord.')
39 | flags.DEFINE_string('val_json_file', None,
40 |                     'Groundtruth, e.g. annotations/instances_val2017.json.')
41 | flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name to use.')
42 | flags.DEFINE_string('model_dir', None, 'Location of the checkpoint to run.')
43 | flags.DEFINE_integer('batch_size', 8, 'Global batch size.')
44 | flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file')
45 | FLAGS = flags.FLAGS
46 |
47 |
48 | def main(_):
49 | config = hparams_config.get_efficientdet_config(FLAGS.model_name)
50 | config.override(FLAGS.hparams)
51 | config.val_json_file = FLAGS.val_json_file
52 | config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS
53 | config.drop_remainder = False # eval all examples w/o drop.
54 | config.image_size = utils.parse_image_size(config['image_size'])
55 |
56 | if config.strategy == 'tpu':
57 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
58 | FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
59 | tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
60 | tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
61 | ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
62 | logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
63 | elif config.strategy == 'gpus':
64 | ds_strategy = tf.distribute.MirroredStrategy()
65 | logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
66 | else:
67 | if tf.config.list_physical_devices('GPU'):
68 | ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
69 | else:
70 | ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')
71 |
72 | with ds_strategy.scope():
73 | # Network
74 | model = efficientdet_keras.EfficientDetNet(config=config)
75 | model.build((None, *config.image_size, 3))
76 | util_keras.restore_ckpt(model,
77 | tf.train.latest_checkpoint(FLAGS.model_dir),
78 | config.moving_average_decay,
79 | skip_mismatch=False)
80 | @tf.function
81 | def model_fn(images, labels):
82 | cls_outputs, box_outputs = model(images, training=False)
83 | detections = postprocess.generate_detections(config,
84 | cls_outputs,
85 | box_outputs,
86 | labels['image_scales'],
87 | labels['source_ids'])
88 | tf.numpy_function(evaluator.update_state,
89 | [labels['groundtruth_data'],
90 | postprocess.transform_detections(detections)], [])
91 |
92 | # Evaluator for AP calculation.
93 | label_map = label_util.get_label_map(config.label_map)
94 | evaluator = coco_metric.EvaluationMetric(
95 | filename=config.val_json_file, label_map=label_map)
96 |
97 | # dataset
98 | batch_size = FLAGS.batch_size # global batch size.
99 | ds = dataloader.InputReader(
100 | FLAGS.val_file_pattern,
101 | is_training=False,
102 | max_instances_per_image=config.max_instances_per_image)(
103 | config, batch_size=batch_size)
104 | if FLAGS.eval_samples:
105 | ds = ds.take((FLAGS.eval_samples + batch_size - 1) // batch_size)
106 | ds = ds_strategy.experimental_distribute_dataset(ds)
107 |
108 | # evaluate all images.
109 | eval_samples = FLAGS.eval_samples or 5000
110 | pbar = tf.keras.utils.Progbar((eval_samples + batch_size - 1) // batch_size)
111 | for i, (images, labels) in enumerate(ds):
112 | ds_strategy.run(model_fn, (images, labels))
113 | pbar.update(i)
114 |
115 | # compute the final eval results.
116 | metrics = evaluator.result()
117 | metric_dict = {}
118 | for i, name in enumerate(evaluator.metric_names):
119 | metric_dict[name] = metrics[i]
120 |
121 | if label_map:
122 | for i, cid in enumerate(sorted(label_map.keys())):
123 | name = 'AP_/%s' % label_map[cid]
124 | metric_dict[name] = metrics[i + len(evaluator.metric_names)]
125 | print(FLAGS.model_name, metric_dict)
126 |
127 |
128 | if __name__ == '__main__':
129 | flags.mark_flag_as_required('val_file_pattern')
130 | flags.mark_flag_as_required('model_dir')
131 | logging.set_verbosity(logging.ERROR)
132 | app.run(main)
133 |
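
A minimal sketch of the pattern model_fn relies on above: tf.numpy_function lets an eager-side Python accumulator (here a hypothetical Accumulator standing in for coco_metric.EvaluationMetric) be updated from inside a tf.function; because the op is stateful, tf.function keeps it alive even though its output list is empty.

import tensorflow as tf

class Accumulator:
  """Stand-in for an eager-side metric such as coco_metric.EvaluationMetric."""

  def __init__(self):
    self.total = 0.0

  def update_state(self, values):
    # Runs eagerly in Python even when invoked from graph code.
    self.total += float(values.sum())

acc = Accumulator()

@tf.function
def step(x):
  y = x * 2.0
  # Tout=[] because update_state returns nothing; the call is kept purely for
  # its side effect of accumulating into `acc`.
  tf.numpy_function(acc.update_state, [y], [])
  return y

step(tf.ones([4]))
print(acc.total)  # 8.0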
--------------------------------------------------------------------------------
/efficientdet/tf2/fpn_configs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """BiFPN/QuFPN and other FPN configs.
16 |
17 | BiFPN is presented in the EfficientDet paper.
18 | QuFPN is proposed in https://github.com/google/automl/pull/580
19 | """
20 | import itertools
21 | import hparams_config
22 |
23 |
24 | def bifpn_config(min_level, max_level, weight_method):
25 | """A dynamic bifpn config that can adapt to different min/max levels."""
26 | p = hparams_config.Config()
27 | p.weight_method = weight_method or 'fastattn'
28 |
29 |   # Node ids start from the input features and increase monotonically whenever
30 | # a new node is added. Here is an example for level P3 - P7:
31 | # P7 (4) P7" (12)
32 | # P6 (3) P6' (5) P6" (11)
33 | # P5 (2) P5' (6) P5" (10)
34 | # P4 (1) P4' (7) P4" (9)
35 | # P3 (0) P3" (8)
36 | # So output would be like:
37 | # [
38 | # {'feat_level': 6, 'inputs_offsets': [3, 4]}, # for P6'
39 | # {'feat_level': 5, 'inputs_offsets': [2, 5]}, # for P5'
40 | # {'feat_level': 4, 'inputs_offsets': [1, 6]}, # for P4'
41 | # {'feat_level': 3, 'inputs_offsets': [0, 7]}, # for P3"
42 | # {'feat_level': 4, 'inputs_offsets': [1, 7, 8]}, # for P4"
43 | # {'feat_level': 5, 'inputs_offsets': [2, 6, 9]}, # for P5"
44 | # {'feat_level': 6, 'inputs_offsets': [3, 5, 10]}, # for P6"
45 | # {'feat_level': 7, 'inputs_offsets': [4, 11]}, # for P7"
46 | # ]
47 | num_levels = max_level - min_level + 1
48 | node_ids = {min_level + i: [i] for i in range(num_levels)}
49 |
50 | level_last_id = lambda level: node_ids[level][-1]
51 | level_all_ids = lambda level: node_ids[level]
52 | id_cnt = itertools.count(num_levels)
53 |
54 | p.nodes = []
55 | for i in range(max_level - 1, min_level - 1, -1):
56 | # top-down path.
57 | p.nodes.append({
58 | 'feat_level': i,
59 | 'inputs_offsets': [level_last_id(i),
60 | level_last_id(i + 1)]
61 | })
62 | node_ids[i].append(next(id_cnt))
63 |
64 | for i in range(min_level + 1, max_level + 1):
65 | # bottom-up path.
66 | p.nodes.append({
67 | 'feat_level': i,
68 | 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)]
69 | })
70 | node_ids[i].append(next(id_cnt))
71 |
72 | return p
73 |
74 |
75 | def qufpn_config(min_level, max_level, weight_method=None):
76 | """A dynamic quad fpn config that can adapt to different min/max levels."""
77 | # It extends the idea of BiFPN, and has four paths:
78 | # (up_down -> bottom_up) + (bottom_up -> up_down).
79 |   # See the test for an example with levels 2 and 7.
80 | p = hparams_config.Config()
81 | p.weight_method = weight_method or 'fastattn'
82 | p.quad_method = 'fastattn'
83 | num_levels = max_level - min_level + 1
84 | node_ids = {min_level + i: [i] for i in range(num_levels)}
85 | level_last_id = lambda level: node_ids[level][-1]
86 | level_all_ids = lambda level: node_ids[level]
87 | level_first_id = lambda level: node_ids[level][0]
88 | id_cnt = itertools.count(num_levels)
89 |
90 | p.nodes = []
91 | for i in range(max_level - 1, min_level - 1, -1):
92 | # top-down path 1.
93 | p.nodes.append({
94 | 'feat_level': i,
95 | 'inputs_offsets': [level_last_id(i),
96 | level_last_id(i + 1)],
97 | 'weight_method': p.weight_method
98 | })
99 | node_ids[i].append(next(id_cnt))
100 | node_ids[max_level].append(node_ids[max_level][-1])
101 |
102 | for i in range(min_level + 1, max_level):
103 | # bottom-up path 2.
104 | p.nodes.append({
105 | 'feat_level': i,
106 | 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)],
107 | 'weight_method': p.weight_method
108 | })
109 | node_ids[i].append(next(id_cnt))
110 |
111 | i = max_level
112 | p.nodes.append({
113 | 'feat_level': i,
114 | 'inputs_offsets': [level_first_id(i)] + [level_last_id(i - 1)],
115 | 'weight_method': p.weight_method
116 | })
117 | node_ids[i].append(next(id_cnt))
118 | node_ids[min_level].append(node_ids[min_level][-1])
119 |
120 | for i in range(min_level + 1, max_level + 1, 1):
121 | # bottom-up path 3.
122 | p.nodes.append({
123 | 'feat_level': i,
124 | 'inputs_offsets': [
125 | level_first_id(i),
126 | level_last_id(i - 1) if i != min_level + 1 else level_first_id(i -
127 | 1)
128 | ],
129 | 'weight_method': p.weight_method
130 | })
131 | node_ids[i].append(next(id_cnt))
132 | node_ids[min_level].append(node_ids[min_level][-1])
133 |
134 | for i in range(max_level - 1, min_level, -1):
135 | # top-down path 4.
136 | p.nodes.append({
137 | 'feat_level':
138 | i,
139 | 'inputs_offsets': [node_ids[i][0]] + [node_ids[i][-1]] +
140 | [level_last_id(i + 1)],
141 | 'weight_method':
142 | p.weight_method
143 | })
144 | node_ids[i].append(next(id_cnt))
145 | i = min_level
146 | p.nodes.append({
147 | 'feat_level': i,
148 | 'inputs_offsets': [node_ids[i][0]] + [level_last_id(i + 1)],
149 | 'weight_method': p.weight_method
150 | })
151 | node_ids[i].append(next(id_cnt))
152 | node_ids[max_level].append(node_ids[max_level][-1])
153 |
154 | for i in range(max_level, min_level - 1, -1):
155 | # quad-add path.
156 | p.nodes.append({
157 | 'feat_level': i,
158 | 'inputs_offsets': [node_ids[i][2], node_ids[i][4]],
159 | 'weight_method': p.quad_method
160 | })
161 | node_ids[i].append(next(id_cnt))
162 |
163 | return p
164 |
165 |
166 | def get_fpn_config(fpn_name, min_level, max_level, weight_method):
167 | """Get fpn related configuration."""
168 | if not fpn_name:
169 | fpn_name = 'bifpn'
170 | name_to_config = {
171 | 'bifpn': bifpn_config(min_level, max_level, weight_method),
172 | 'qufpn': qufpn_config(min_level, max_level, weight_method),
173 | # legacy only: to be deprecated.
174 | 'bifpn_dyn': bifpn_config(min_level, max_level, weight_method),
175 | }
176 | return name_to_config[fpn_name]
177 |
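
As a quick illustration of the node bookkeeping above (a sketch, assuming it is run from the efficientdet/ directory so the imports resolve): for levels P3-P7, bifpn_config adds 2 * num_levels - 2 fused nodes and qufpn_config adds 5 * num_levels - 4, which matches the expected node lists in fpn_configs_test.py below.

from tf2 import fpn_configs

num_levels = 7 - 3 + 1  # P3..P7
p_bi = fpn_configs.bifpn_config(3, 7, None)
assert len(p_bi.nodes) == 2 * num_levels - 2   # 8 fused nodes
p_qu = fpn_configs.qufpn_config(3, 7, None)
assert len(p_qu.nodes) == 5 * num_levels - 4   # 21 fused nodes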
--------------------------------------------------------------------------------
/efficientdet/tf2/fpn_configs_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for fpn_configs."""
16 | from absl import logging
17 | import tensorflow as tf
18 | from tf2 import fpn_configs
19 |
20 |
21 | class FpnConfigTest(tf.test.TestCase):
22 |
23 | def test_bifpn_l3l7(self):
24 | p1 = fpn_configs.bifpn_config(3, 7, None)
25 | # pyformat: disable
26 | self.assertEqual(
27 | p1.nodes,
28 | [
29 | {'feat_level': 6, 'inputs_offsets': [3, 4]},
30 | {'feat_level': 5, 'inputs_offsets': [2, 5]},
31 | {'feat_level': 4, 'inputs_offsets': [1, 6]},
32 | {'feat_level': 3, 'inputs_offsets': [0, 7]},
33 | {'feat_level': 4, 'inputs_offsets': [1, 7, 8]},
34 | {'feat_level': 5, 'inputs_offsets': [2, 6, 9]},
35 | {'feat_level': 6, 'inputs_offsets': [3, 5, 10]},
36 | {'feat_level': 7, 'inputs_offsets': [4, 11]},
37 | ])
38 | # pyformat: enable
39 |
40 | def test_bifpn_l2l7(self):
41 | p = fpn_configs.bifpn_config(2, 7, None)
42 |
43 | # pyformat: disable
44 | self.assertEqual(
45 | p.nodes,
46 | [
47 | {'feat_level': 6, 'inputs_offsets': [4, 5]},
48 | {'feat_level': 5, 'inputs_offsets': [3, 6]},
49 | {'feat_level': 4, 'inputs_offsets': [2, 7]},
50 | {'feat_level': 3, 'inputs_offsets': [1, 8]},
51 | {'feat_level': 2, 'inputs_offsets': [0, 9]},
52 | {'feat_level': 3, 'inputs_offsets': [1, 9, 10]},
53 | {'feat_level': 4, 'inputs_offsets': [2, 8, 11]},
54 | {'feat_level': 5, 'inputs_offsets': [3, 7, 12]},
55 | {'feat_level': 6, 'inputs_offsets': [4, 6, 13]},
56 | {'feat_level': 7, 'inputs_offsets': [5, 14]},
57 | ])
58 | # pyformat: enable
59 |
60 | def test_qufpn_dynamic_l3l7(self):
61 | p = fpn_configs.qufpn_config(3, 7, None)
62 |
63 | # pyformat: disable
64 | # pylint: disable=line-too-long
65 | self.assertEqual(
66 | p.nodes,
67 | [
68 | {'feat_level': 6, 'inputs_offsets': [3, 4], 'weight_method': 'fastattn'},
69 | {'feat_level': 5, 'inputs_offsets': [2, 5], 'weight_method': 'fastattn'},
70 | {'feat_level': 4, 'inputs_offsets': [1, 6], 'weight_method': 'fastattn'},
71 | {'feat_level': 3, 'inputs_offsets': [0, 7], 'weight_method': 'fastattn'},
72 | {'feat_level': 4, 'inputs_offsets': [1, 7, 8], 'weight_method': 'fastattn'},
73 | {'feat_level': 5, 'inputs_offsets': [2, 6, 9], 'weight_method': 'fastattn'},
74 | {'feat_level': 6, 'inputs_offsets': [3, 5, 10], 'weight_method': 'fastattn'},
75 | {'feat_level': 7, 'inputs_offsets': [4, 11], 'weight_method': 'fastattn'},
76 | {'feat_level': 4, 'inputs_offsets': [1, 0], 'weight_method': 'fastattn'},
77 | {'feat_level': 5, 'inputs_offsets': [2, 13], 'weight_method': 'fastattn'},
78 | {'feat_level': 6, 'inputs_offsets': [3, 14], 'weight_method': 'fastattn'},
79 | {'feat_level': 7, 'inputs_offsets': [4, 15], 'weight_method': 'fastattn'},
80 | {'feat_level': 6, 'inputs_offsets': [3, 15, 16], 'weight_method': 'fastattn'},
81 | {'feat_level': 5, 'inputs_offsets': [2, 14, 17], 'weight_method': 'fastattn'},
82 | {'feat_level': 4, 'inputs_offsets': [1, 13, 18], 'weight_method': 'fastattn'},
83 | {'feat_level': 3, 'inputs_offsets': [0, 19], 'weight_method': 'fastattn'},
84 | {'feat_level': 7, 'inputs_offsets': [12, 16], 'weight_method': 'fastattn'},
85 | {'feat_level': 6, 'inputs_offsets': [11, 17], 'weight_method': 'fastattn'},
86 | {'feat_level': 5, 'inputs_offsets': [10, 18], 'weight_method': 'fastattn'},
87 | {'feat_level': 4, 'inputs_offsets': [9, 19], 'weight_method': 'fastattn'},
88 | {'feat_level': 3, 'inputs_offsets': [8, 20], 'weight_method': 'fastattn'},
89 | ])
90 | # pylint: enable=line-too-long
91 | # pyformat: enable
92 |
93 |
94 | if __name__ == '__main__':
95 | logging.set_verbosity(logging.WARNING)
96 | tf.test.main()
97 |
--------------------------------------------------------------------------------
/efficientdet/tf2/infer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A simple example on how to use keras model for inference."""
16 | import os
17 | from absl import app
18 | from absl import flags
19 | from absl import logging
20 | import numpy as np
21 | from PIL import Image
22 | import tensorflow as tf
23 |
24 | import hparams_config
25 | import inference
26 | from tf2 import efficientdet_keras
27 |
28 | flags.DEFINE_string('image_path', None, 'Location of test image.')
29 | flags.DEFINE_string('output_dir', None, 'Directory of annotated output images.')
30 | flags.DEFINE_string('model_dir', None, 'Location of the checkpoint to run.')
31 | flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name to use.')
32 | flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file')
33 | flags.DEFINE_bool('debug', False, 'If true, run function in eager for debug.')
34 | flags.DEFINE_string('saved_model_dir', None, 'Saved model directory')
35 | FLAGS = flags.FLAGS
36 |
37 |
38 | def main(_):
39 |
40 | # pylint: disable=line-too-long
41 | # Prepare images and checkpoints: please run these commands in shell.
42 | # !mkdir tmp
43 | # !wget https://user-images.githubusercontent.com/11736571/77320690-099af300-6d37-11ea-9d86-24f14dc2d540.png -O tmp/img.png
44 | # !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-d0.tar.gz -O tmp/efficientdet-d0.tar.gz
45 | # !tar zxf tmp/efficientdet-d0.tar.gz -C tmp
46 | imgs = [np.array(Image.open(FLAGS.image_path))]
47 | # Create model config.
48 | config = hparams_config.get_efficientdet_config(FLAGS.model_name)
49 | config.is_training_bn = False
50 | config.image_size = '1920x1280'
51 | config.nms_configs.score_thresh = 0.4
52 | config.nms_configs.max_output_size = 100
53 | config.override(FLAGS.hparams)
54 |
55 | # Use 'mixed_float16' if running on GPUs.
56 | policy = tf.keras.mixed_precision.Policy('float32')
57 | tf.keras.mixed_precision.set_global_policy(policy)
58 | tf.config.run_functions_eagerly(FLAGS.debug)
59 |
60 | # Create and run the model.
61 | model = efficientdet_keras.EfficientDetModel(config=config)
62 | model.build((None, None, None, 3))
63 | model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir))
64 | model.summary(expand_nested=True)
65 |
66 | class ExportModel(tf.Module):
67 |
68 | def __init__(self, model):
69 | super().__init__()
70 | self.model = model
71 |
72 | @tf.function
73 | def f(self, imgs):
74 | return self.model(imgs, training=False, post_mode='global')
75 |
76 | imgs = tf.convert_to_tensor(imgs, dtype=tf.uint8)
77 | export_model = ExportModel(model)
78 | if FLAGS.saved_model_dir:
79 | tf.saved_model.save(
80 | export_model,
81 | FLAGS.saved_model_dir,
82 | signatures=export_model.f.get_concrete_function(
83 | tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8)))
84 | export_model = tf.saved_model.load(FLAGS.saved_model_dir)
85 |
86 | boxes, scores, classes, valid_len = export_model.f(imgs)
87 |
88 | # Visualize results.
89 | for i, img in enumerate(imgs):
90 | length = valid_len[i]
91 | img = inference.visualize_image(
92 | img,
93 | boxes[i].numpy()[:length],
94 |         classes[i].numpy().astype(int)[:length],
95 | scores[i].numpy()[:length],
96 | label_map=config.label_map,
97 | min_score_thresh=config.nms_configs.score_thresh,
98 | max_boxes_to_draw=config.nms_configs.max_output_size)
99 | output_image_path = os.path.join(FLAGS.output_dir, str(i) + '.jpg')
100 | Image.fromarray(img).save(output_image_path)
101 | print('writing annotated image to %s' % output_image_path)
102 |
103 |
104 | if __name__ == '__main__':
105 | flags.mark_flag_as_required('image_path')
106 | flags.mark_flag_as_required('output_dir')
107 | flags.mark_flag_as_required('model_dir')
108 | logging.set_verbosity(logging.ERROR)
109 | app.run(main)
110 |
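
The shell commands in the comments above can also be done from Python; a sketch using tf.keras.utils.get_file (the destination paths are whatever Keras chooses under its cache directory, so treat them as illustrative):

import tensorflow as tf

# Downloads the sample image and the pretrained d0 checkpoint referenced above.
image_path = tf.keras.utils.get_file(
    'img.png',
    'https://user-images.githubusercontent.com/11736571/77320690-099af300-6d37-11ea-9d86-24f14dc2d540.png')
ckpt_archive = tf.keras.utils.get_file(
    'efficientdet-d0.tar.gz',
    'https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-d0.tar.gz',
    extract=True)
# --image_path can then point at image_path, and --model_dir at the extracted
# efficientdet-d0 directory produced next to the cached archive.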
--------------------------------------------------------------------------------
/efficientdet/tf2/infer_lib_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Inference test cases."""
16 | import os
17 | import tempfile
18 | from absl import logging
19 | import tensorflow as tf
20 | import test_util
21 | from tf2 import efficientdet_keras
22 | from tf2 import infer_lib
23 |
24 |
25 | class InferenceTest(tf.test.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | tf.random.set_seed(111111)
30 | model = efficientdet_keras.EfficientDetModel('efficientdet-d0')
31 | self.tmp_path = tempfile.mkdtemp()
32 | model.build([1, 512, 512, 3])
33 | model.save_weights(os.path.join(self.tmp_path, 'model'))
34 |
35 | lite_model = efficientdet_keras.EfficientDetModel('efficientdet-lite0')
36 | self.lite_tmp_path = tempfile.mkdtemp()
37 | lite_model.build([1, 512, 512, 3])
38 | lite_model.save_weights(os.path.join(self.lite_tmp_path, 'model'))
39 |
40 | def test_export(self):
41 | saved_model_path = os.path.join(self.tmp_path, 'saved_model')
42 | driver = infer_lib.KerasDriver(self.tmp_path, False, 'efficientdet-d0')
43 | driver.export(saved_model_path)
44 | has_saved_model = tf.saved_model.contains_saved_model(saved_model_path)
45 | self.assertAllEqual(has_saved_model, True)
46 | driver = infer_lib.SavedModelDriver(saved_model_path, 'efficientdet-d0')
47 | fg_path = os.path.join(saved_model_path, 'efficientdet-d0_frozen.pb')
48 | driver = infer_lib.SavedModelDriver(fg_path, 'efficientdet-d0')
49 |
50 | def test_export_tflite_only_network(self):
51 | saved_model_path = os.path.join(self.lite_tmp_path, 'saved_model')
52 | driver = infer_lib.KerasDriver(
53 | self.lite_tmp_path, False, 'efficientdet-lite0', only_network=True)
54 | driver.export(saved_model_path, tflite='FP32')
55 | self.assertTrue(
56 | tf.io.gfile.exists(os.path.join(saved_model_path, 'fp32.tflite')))
57 | tf.io.gfile.rmtree(saved_model_path)
58 | driver.export(saved_model_path, tflite='FP16')
59 | self.assertTrue(
60 | tf.io.gfile.exists(os.path.join(saved_model_path, 'fp16.tflite')))
61 | tf.io.gfile.rmtree(saved_model_path)
62 | tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir())
63 | driver.export(
64 | saved_model_path,
65 | tflite='INT8',
66 | file_pattern=[tfrecord_path],
67 | num_calibration_steps=1)
68 | self.assertTrue(
69 | tf.io.gfile.exists(os.path.join(saved_model_path, 'int8.tflite')))
70 |
71 | def test_export_tflite_with_post_processing(self):
72 | saved_model_path = os.path.join(self.lite_tmp_path, 'saved_model')
73 | driver = infer_lib.KerasDriver(
74 | self.lite_tmp_path, False, 'efficientdet-lite0', only_network=False)
75 | driver.export(saved_model_path, tflite='FP32')
76 | self.assertTrue(
77 | tf.io.gfile.exists(os.path.join(saved_model_path, 'fp32.tflite')))
78 | tf.io.gfile.rmtree(saved_model_path)
79 | tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir())
80 | driver.export(
81 | saved_model_path,
82 | tflite='INT8',
83 | file_pattern=[tfrecord_path],
84 | num_calibration_steps=1)
85 | self.assertTrue(
86 | tf.io.gfile.exists(os.path.join(saved_model_path, 'int8.tflite')))
87 |
88 | def test_infer_lib(self):
89 | driver = infer_lib.KerasDriver(self.tmp_path, False, 'efficientdet-d0')
90 | images = tf.ones((1, 512, 512, 3))
91 | boxes, scores, classes, valid_lens = driver.serve(images)
92 | self.assertEqual(tf.reduce_mean(boxes), 163.09)
93 | self.assertEqual(tf.reduce_mean(scores), 0.01)
94 | self.assertEqual(tf.reduce_mean(classes), 1)
95 | self.assertEqual(tf.reduce_mean(valid_lens), 100)
96 | self.assertEqual(boxes.shape, (1, 100, 4))
97 | self.assertEqual(scores.shape, (1, 100))
98 | self.assertEqual(classes.shape, (1, 100))
99 | self.assertEqual(valid_lens.shape, (1,))
100 |
101 | def test_infer_lib_without_ema(self):
102 | driver = infer_lib.KerasDriver(
103 | self.tmp_path,
104 | False,
105 | 'efficientdet-d0',
106 | model_params={'moving_average_decay': 0})
107 | images = tf.ones((1, 512, 512, 3))
108 | boxes, scores, classes, valid_lens = driver.serve(images)
109 | self.assertEqual(tf.reduce_mean(boxes), 163.09)
110 | self.assertEqual(tf.reduce_mean(scores), 0.01)
111 | self.assertEqual(tf.reduce_mean(classes), 1)
112 | self.assertEqual(tf.reduce_mean(valid_lens), 100)
113 | self.assertEqual(boxes.shape, (1, 100, 4))
114 | self.assertEqual(scores.shape, (1, 100))
115 | self.assertEqual(classes.shape, (1, 100))
116 | self.assertEqual(valid_lens.shape, (1,))
117 |
118 | def test_network_infer_lib(self):
119 | driver = infer_lib.KerasDriver(
120 | self.tmp_path, False, 'efficientdet-d0', only_network=True)
121 | images = tf.ones((1, 512, 512, 3))
122 | class_outputs, box_outputs = driver.predict(images)
123 | self.assertLen(class_outputs, 5)
124 | self.assertLen(box_outputs, 5)
125 |
126 | def test_infer_lib_mixed_precision(self):
127 | driver = infer_lib.KerasDriver(
128 | self.tmp_path,
129 | False,
130 | 'efficientdet-d0',
131 | model_params={'mixed_precision': True})
132 | images = tf.ones((1, 512, 512, 3))
133 | boxes, scores, classes, valid_lens = driver.serve(images)
134 | policy = tf.keras.mixed_precision.global_policy()
135 | if policy.name == 'float32':
136 | self.assertEqual(tf.reduce_mean(boxes), 163.09)
137 | self.assertEqual(tf.reduce_mean(scores), 0.01)
138 | self.assertEqual(tf.reduce_mean(classes), 1)
139 | self.assertEqual(tf.reduce_mean(valid_lens), 100)
140 | elif policy.name == 'float16':
141 | pass
142 | elif policy.name == 'bfloat16':
143 | pass
144 | self.assertEqual(boxes.shape, (1, 100, 4))
145 | self.assertEqual(scores.shape, (1, 100))
146 | self.assertEqual(classes.shape, (1, 100))
147 | self.assertEqual(valid_lens.shape, (1,))
148 |
149 |
150 | if __name__ == '__main__':
151 | logging.set_verbosity(logging.WARNING)
152 | tf.test.main()
153 |
--------------------------------------------------------------------------------
/efficientdet/tf2/inspector_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Tests for model inspect tool."""
16 | import os
17 | import shutil
18 | import tempfile
19 | from absl import flags
20 | from absl import logging
21 | from absl.testing import flagsaver
22 | import numpy as np
23 | from PIL import Image
24 | import tensorflow as tf
25 |
26 | from tf2 import inspector
27 | FLAGS = flags.FLAGS
28 |
29 |
30 | class InspectorTest(tf.test.TestCase):
31 | """Model inspect tests."""
32 |
33 | def setUp(self):
34 | super().setUp()
35 | self.tempdir = tempfile.mkdtemp()
36 | FLAGS.model_dir = '_'
37 |
38 | def tearDown(self):
39 | super().tearDown()
40 | shutil.rmtree(self.tempdir)
41 |
42 | @flagsaver.flagsaver(mode='dry')
43 | def test_dry(self):
44 | FLAGS.export_ckpt = os.path.join(self.tempdir, 'model')
45 | inspector.main(None)
46 | self.assertIsNot(tf.train.get_checkpoint_state(self.tempdir), None)
47 |
48 | @flagsaver.flagsaver(mode='infer', saved_model_dir=None)
49 | def test_infer(self):
50 | test_image = np.random.randint(0, 244, (640, 720, 3)).astype(np.uint8)
51 | FLAGS.input_image = os.path.join(self.tempdir, 'img.jpg')
52 | Image.fromarray(test_image).save(FLAGS.input_image)
53 | FLAGS.output_image_dir = self.tempdir
54 | inspector.main(None)
55 | self.assertTrue(tf.io.gfile.exists(os.path.join(self.tempdir, '0.jpg')))
56 |
57 | @flagsaver.flagsaver(mode='benchmark', saved_model_dir=None)
58 | def test_benchmark(self):
59 | inspector.main(None)
60 | self.assertFalse(tf.io.gfile.exists(os.path.join(self.tempdir, '0.jpg')))
61 |
62 | @flagsaver.flagsaver(mode='export', tflite='FP32')
63 | def test_export(self):
64 | FLAGS.saved_model_dir = os.path.join(self.tempdir, 'savedmodel')
65 | tflite_path = os.path.join(FLAGS.saved_model_dir, 'fp32.tflite')
66 | inspector.main(None)
67 | self.assertTrue(tf.saved_model.contains_saved_model(FLAGS.saved_model_dir))
68 | self.assertTrue(tf.io.gfile.exists(tflite_path))
69 |
70 |
71 | if __name__ == '__main__':
72 | logging.set_verbosity(logging.WARNING)
73 | tf.test.main()
74 |
--------------------------------------------------------------------------------
/efficientdet/tf2/label_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A few predefined label id mapping."""
16 | import tensorflow as tf
17 | import yaml
18 | import hparams_config
19 |
20 | coco = {
21 | # 0: 'background',
22 | 1: 'person',
23 | 2: 'bicycle',
24 | 3: 'car',
25 | 4: 'motorcycle',
26 | 5: 'airplane',
27 | 6: 'bus',
28 | 7: 'train',
29 | 8: 'truck',
30 | 9: 'boat',
31 | 10: 'traffic light',
32 | 11: 'fire hydrant',
33 | 13: 'stop sign',
34 | 14: 'parking meter',
35 | 15: 'bench',
36 | 16: 'bird',
37 | 17: 'cat',
38 | 18: 'dog',
39 | 19: 'horse',
40 | 20: 'sheep',
41 | 21: 'cow',
42 | 22: 'elephant',
43 | 23: 'bear',
44 | 24: 'zebra',
45 | 25: 'giraffe',
46 | 27: 'backpack',
47 | 28: 'umbrella',
48 | 31: 'handbag',
49 | 32: 'tie',
50 | 33: 'suitcase',
51 | 34: 'frisbee',
52 | 35: 'skis',
53 | 36: 'snowboard',
54 | 37: 'sports ball',
55 | 38: 'kite',
56 | 39: 'baseball bat',
57 | 40: 'baseball glove',
58 | 41: 'skateboard',
59 | 42: 'surfboard',
60 | 43: 'tennis racket',
61 | 44: 'bottle',
62 | 46: 'wine glass',
63 | 47: 'cup',
64 | 48: 'fork',
65 | 49: 'knife',
66 | 50: 'spoon',
67 | 51: 'bowl',
68 | 52: 'banana',
69 | 53: 'apple',
70 | 54: 'sandwich',
71 | 55: 'orange',
72 | 56: 'broccoli',
73 | 57: 'carrot',
74 | 58: 'hot dog',
75 | 59: 'pizza',
76 | 60: 'donut',
77 | 61: 'cake',
78 | 62: 'chair',
79 | 63: 'couch',
80 | 64: 'potted plant',
81 | 65: 'bed',
82 | 67: 'dining table',
83 | 70: 'toilet',
84 | 72: 'tv',
85 | 73: 'laptop',
86 | 74: 'mouse',
87 | 75: 'remote',
88 | 76: 'keyboard',
89 | 77: 'cell phone',
90 | 78: 'microwave',
91 | 79: 'oven',
92 | 80: 'toaster',
93 | 81: 'sink',
94 | 82: 'refrigerator',
95 | 84: 'book',
96 | 85: 'clock',
97 | 86: 'vase',
98 | 87: 'scissors',
99 | 88: 'teddy bear',
100 | 89: 'hair drier',
101 | 90: 'toothbrush',
102 | }
103 |
104 | voc = {
105 | # 0: 'background',
106 | 1: 'aeroplane',
107 | 2: 'bicycle',
108 | 3: 'bird',
109 | 4: 'boat',
110 | 5: 'bottle',
111 | 6: 'bus',
112 | 7: 'car',
113 | 8: 'cat',
114 | 9: 'chair',
115 | 10: 'cow',
116 | 11: 'diningtable',
117 | 12: 'dog',
118 | 13: 'horse',
119 | 14: 'motorbike',
120 | 15: 'person',
121 | 16: 'pottedplant',
122 | 17: 'sheep',
123 | 18: 'sofa',
124 | 19: 'train',
125 | 20: 'tvmonitor',
126 | }
127 |
128 | waymo = {
129 | # 0: 'background',
130 | 1: 'vehicle',
131 | 2: 'pedestrian',
132 | 3: 'cyclist',
133 | }
134 |
135 |
136 | def get_label_map(mapping):
137 | """Get label id map based on the name, filename, or dict."""
138 | # case 1: if it is None or dict, just return it.
139 | if not mapping or isinstance(mapping, dict):
140 | return mapping
141 |
142 | if isinstance(mapping, hparams_config.Config):
143 | return mapping.as_dict()
144 |
145 | # case 2: if it is a yaml file, load it to a dict and return the dict.
146 | assert isinstance(mapping, str), 'mapping must be dict or str.'
147 | if mapping.endswith('.yaml'):
148 | with tf.io.gfile.GFile(mapping) as f:
149 | return yaml.load(f, Loader=yaml.FullLoader)
150 |
151 | # case 3: it is a name of a predefined dataset.
152 | return {'coco': coco, 'voc': voc, 'waymo': waymo}[mapping]
153 |
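
Illustrative calls covering the three cases handled by get_label_map (the yaml filename is hypothetical; run from the efficientdet/ directory so the import resolves):

from tf2 import label_util

label_util.get_label_map(None)                     # case 1: passthrough -> None
label_util.get_label_map({1: 'person', 2: 'car'})  # case 1: passthrough -> same dict
label_util.get_label_map('my_labels.yaml')         # case 2: dict loaded from a yaml file
label_util.get_label_map('voc')                    # case 3: predefined Pascal VOC mapping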
--------------------------------------------------------------------------------
/efficientdet/tf2/segmentation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A demo script to show to train a segmentation model."""
16 | from absl import app
17 | from absl import logging
18 | import tensorflow as tf
19 | import tensorflow_datasets as tfds
20 |
21 | import hparams_config
22 | from tf2 import efficientdet_keras
23 |
24 |
25 | def create_mask(pred_mask):
26 | pred_mask = tf.argmax(pred_mask, axis=-1)
27 | pred_mask = pred_mask[..., tf.newaxis]
28 | return pred_mask[0]
29 |
30 |
31 | def normalize(input_image, input_mask):
32 | input_image = tf.cast(input_image, tf.float32) / 255.0
33 | input_mask -= 1
34 | return input_image, input_mask
35 |
36 |
37 | def load_image_train(datapoint):
38 | """Load images for training."""
39 | input_image = tf.image.resize(datapoint['image'], (512, 512))
40 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
41 |
42 | if tf.random.uniform(()) > 0.5:
43 | input_image = tf.image.flip_left_right(input_image)
44 | input_mask = tf.image.flip_left_right(input_mask)
45 |
46 | input_image, input_mask = normalize(input_image, input_mask)
47 |
48 | return input_image, input_mask
49 |
50 |
51 | def load_image_test(datapoint):
52 | input_image = tf.image.resize(datapoint['image'], (512, 512))
53 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
54 |
55 | input_image, input_mask = normalize(input_image, input_mask)
56 |
57 | return input_image, input_mask
58 |
59 |
60 | def main(_):
61 | dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
62 | train_examples = info.splits['train'].num_examples
63 | batch_size = 8
64 | steps_per_epoch = train_examples // batch_size
65 |
66 | train = dataset['train'].map(
67 | load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
68 | test = dataset['test'].map(load_image_test)
69 |
70 | train_dataset = train.cache().shuffle(1000).batch(batch_size).repeat()
71 | train_dataset = train_dataset.prefetch(
72 | buffer_size=tf.data.experimental.AUTOTUNE)
73 | test_dataset = test.batch(batch_size)
74 | config = hparams_config.get_efficientdet_config('efficientdet-d0')
75 | config.heads = ['segmentation']
76 | model = efficientdet_keras.EfficientDetNet(config=config)
77 | model.build((1, 512, 512, 3))
78 | model.compile(
79 | optimizer='adam',
80 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
81 | metrics=['accuracy'])
82 |
83 | val_subsplits = 5
84 | val_steps = info.splits['test'].num_examples // batch_size // val_subsplits
85 | model.fit(
86 | train_dataset,
87 | epochs=20,
88 | steps_per_epoch=steps_per_epoch,
89 | validation_steps=val_steps,
90 | validation_data=test_dataset,
91 | callbacks=[])
92 |
93 | model.save_weights(
94 | './testdata/segmentation')
95 |
96 | print(create_mask(model(tf.ones((1, 512, 512, 3)), False)))
97 |
98 |
99 | if __name__ == '__main__':
100 | logging.set_verbosity(logging.WARNING)
101 | app.run(main)
102 |
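
A tiny check of what create_mask does (a sketch, not part of the script): it takes per-pixel class logits, argmaxes over the channel axis, keeps a trailing channel dimension, and returns only the first image of the batch.

import tensorflow as tf
from tf2 import segmentation

pred = tf.random.uniform((2, 128, 128, 3))  # batch of per-pixel class logits
mask = segmentation.create_mask(pred)
assert mask.shape == (128, 128, 1)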
--------------------------------------------------------------------------------
/efficientdet/tf2/tfmot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A tool for model optimization."""
16 | import functools
17 |
18 | import tensorflow_model_optimization as tfmot
19 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
20 | from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
21 |
22 |
23 | def quantize(layer, quantize_config=None):
24 | if quantize_config is None:
25 | quantize_config = default_8bit_quantize_configs.Default8BitOutputQuantizeConfig(
26 | )
27 | return quantize_wrapper.QuantizeWrapper(
28 | layer, quantize_config=quantize_config)
29 |
30 |
31 | optimzation_methods = {
32 | 'prune': tfmot.sparsity.keras.prune_low_magnitude,
33 | 'quantize': quantize
34 | }
35 |
36 |
37 | def set_config(configs):
38 | for key in configs:
39 | if key == 'prune':
40 | optimzation_methods[key] = functools.partial(
41 | tfmot.sparsity.keras.prune_low_magnitude, **configs[key])
42 | if key == 'quantize':
43 | optimzation_methods[key] = functools.partial(quantize, **configs[key])
44 |
45 |
46 | def get_method(method):
47 | if method not in optimzation_methods:
48 | raise KeyError(f'only support {optimzation_methods.keys()}')
49 | return optimzation_methods[method]
50 |
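
Illustrative usage (a sketch, not part of this module): configure the pruning parameters once via set_config, then fetch and apply a method to a layer.

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tf2 import tfmot as opt

opt.set_config({
    'prune': {
        'pruning_schedule':
            tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0)
    }
})
prune = opt.get_method('prune')
layer = prune(tf.keras.layers.Conv2D(8, 3))      # prunable wrapper around the conv
quantize = opt.get_method('quantize')
qlayer = quantize(tf.keras.layers.Conv2D(8, 3))  # 8-bit QuantizeWrapper around the conv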
--------------------------------------------------------------------------------
/efficientdet/tf2/util_keras_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from absl import logging
16 | from absl.testing import parameterized
17 | import tensorflow as tf
18 |
19 | import utils
20 | from tf2 import util_keras
21 |
22 |
23 | class KerasUtilTest(tf.test.TestCase, parameterized.TestCase):
24 |
25 | @parameterized.named_parameters(
26 | ('train_local', True, ''), ('eval_local', False, ''),
27 | ('train_tpu', True, 'tpu'), ('eval_tpu', False, 'tpu'))
28 | def test_batch_norm(self, is_training, strategy):
29 | inputs = tf.random.uniform([8, 40, 40, 3])
30 | expect_results = utils.batch_norm_act(
31 | inputs, is_training, None, strategy=strategy)
32 |
33 | # Call batch norm layer with is_training parameter.
34 | bn_layer = util_keras.build_batch_norm(is_training, strategy=strategy)
35 | self.assertAllClose(expect_results, bn_layer(inputs, is_training))
36 |
37 |
38 | if __name__ == '__main__':
39 | logging.set_verbosity(logging.WARNING)
40 | tf.test.main()
41 |
--------------------------------------------------------------------------------
/efficientdet/tf2/wbf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """WBF for test-time augmentation."""
16 | import tensorflow as tf
17 |
18 |
19 | def vectorized_iou(clusters, detection):
20 |   """Calculates the IoUs of the detection box with each element of clusters."""
21 | x11, y11, x12, y12 = tf.split(clusters[:, 1:5], 4, axis=1)
22 | x21, y21, x22, y22 = tf.split(detection[1:5], 4)
23 |
24 | xa = tf.maximum(x11, x21)
25 | ya = tf.maximum(y11, y21)
26 | xb = tf.minimum(x12, x22)
27 | yb = tf.minimum(y12, y22)
28 |
29 | inter_area = tf.maximum((xb - xa), 0) * tf.maximum((yb - ya), 0)
30 |
31 | boxa_area = (x12 - x11) * (y12 - y11)
32 | boxb_area = (x22 - x21) * (y22 - y21)
33 |
34 | iou = inter_area / (boxa_area + boxb_area - inter_area)
35 |
36 | return iou
37 |
38 |
39 | def find_matching_cluster(clusters, detection):
40 | """Returns the index of the highest iou matching cluster for detection."""
41 | if not clusters:
42 | return -1
43 | ious = vectorized_iou(tf.stack(clusters), detection)
44 | ious = tf.reshape(ious, [len(clusters)])
45 | if tf.math.reduce_max(ious) < 0.55:
46 | # returns -1 if no iou is higher than 0.55.
47 | return -1
48 | return tf.argmax(ious)
49 |
50 |
51 | def weighted_average(samples, weights):
52 | return tf.math.reduce_sum(samples * weights) / tf.math.reduce_sum(weights)
53 |
54 |
55 | def average_detections(detections, num_models):
56 | """Takes a list of detections and returns the average, both in box co-ordinates and confidence."""
57 | num_detections = len(detections)
58 | detections = tf.stack(detections)
59 | return [
60 | detections[0][0],
61 | weighted_average(detections[:, 1], detections[:, 5]),
62 | weighted_average(detections[:, 2], detections[:, 5]),
63 | weighted_average(detections[:, 3], detections[:, 5]),
64 | weighted_average(detections[:, 4], detections[:, 5]),
65 | tf.math.reduce_mean(detections[:, 5]) * min(1, num_detections/num_models),
66 | detections[0][6],
67 | ]
68 |
69 |
70 | def ensemble_detections(params, detections, num_models):
71 | """Ensembles a group of detections by clustering the detections and returning the average of the clusters."""
72 | all_clusters = []
73 |
74 | for cid in range(params['num_classes']):
75 | indices = tf.where(tf.equal(detections[:, 6], cid))
76 | if indices.shape[0] == 0:
77 | continue
78 | class_detections = tf.gather_nd(detections, indices)
79 |
80 | clusters = []
81 | cluster_averages = []
82 | for d in class_detections:
83 | cluster_index = find_matching_cluster(cluster_averages, d)
84 | if cluster_index == -1:
85 | clusters.append([d])
86 | cluster_averages.append(average_detections([d], num_models))
87 | else:
88 | clusters[cluster_index].append(d)
89 | cluster_averages[cluster_index] = average_detections(
90 | clusters[cluster_index], num_models)
91 |
92 | all_clusters.extend(cluster_averages)
93 |
94 | all_clusters.sort(reverse=True, key=lambda d: d[5])
95 | return tf.stack(all_clusters)
96 |
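
For reference, each detection handled above is a 7-element vector: index 0 is the image id, indices 1-4 are the box corner coordinates, index 5 is the confidence score, and index 6 is the class id (this is what the slicing in vectorized_iou and average_detections assumes). A minimal illustrative ensemble of two overlapping detections, assuming the efficientdet/ directory is on the path:

import tensorflow as tf
from tf2 import wbf

# [image_id, box coordinates (4), score, class]
d1 = tf.constant([0., 10., 10., 50., 50., 0.9, 1.], dtype=tf.float32)
d2 = tf.constant([0., 12., 11., 49., 52., 0.8, 1.], dtype=tf.float32)
merged = wbf.ensemble_detections({'num_classes': 2}, tf.stack([d1, d2]), 2)
# The two boxes overlap above the 0.55 IoU threshold, so they fuse into one
# score-weighted cluster.
print(merged)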
--------------------------------------------------------------------------------
/efficientdet/tf2/wbf_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Test for wbf."""
16 | from absl import logging
17 | import tensorflow as tf
18 |
19 | from tf2 import wbf
20 |
21 |
22 | class WbfTest(tf.test.TestCase):
23 |
24 | def test_detection_iou_same(self):
25 | d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32)
26 | d2 = tf.constant([1, 1, 1, 3, 3, 1, 1], dtype=tf.float32)
27 |
28 | iou = wbf.vectorized_iou(d1, d2)
29 |
30 | self.assertAllClose(iou[0][0], 1.0)
31 |
32 | def test_detection_iou_corners(self):
33 | d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32)
34 | d2 = tf.constant([1, 2, 2, 4, 4, 1, 1], dtype=tf.float32)
35 |
36 | iou = wbf.vectorized_iou(d1, d2)
37 |
38 | self.assertAllClose(iou[0][0], 1.0 / 7.0)
39 |
40 | def test_detection_iou_ends(self):
41 | d1 = tf.constant([[1, 1, 1, 3, 2, 1, 1]], dtype=tf.float32)
42 | d2 = tf.constant([1, 2, 1, 4, 2, 1, 1], dtype=tf.float32)
43 |
44 | iou = wbf.vectorized_iou(d1, d2)
45 |
46 | self.assertAllClose(iou[0][0], 1.0 / 3.0)
47 |
48 | def test_detection_iou_none(self):
49 | d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32)
50 | d2 = tf.constant([1, 3, 3, 5, 5, 1, 1], dtype=tf.float32)
51 |
52 | iou = wbf.vectorized_iou(d1, d2)
53 |
54 | self.assertAllClose(iou[0][0], 0)
55 |
56 | def test_detection_iou_vector(self):
57 | vector_to_match = tf.constant(
58 | [
59 | [1, 1, 1, 3, 3, 1, 1],
60 | [1, 2, 2, 4, 4, 1, 1],
61 | [1, 3, 3, 5, 5, 1, 1],
62 | ],
63 | dtype=tf.float32,
64 | )
65 |
66 | detection = tf.constant([1, 1, 1, 3, 3, 1, 1], dtype=tf.float32)
67 |
68 | ious = wbf.vectorized_iou(vector_to_match, detection)
69 | self.assertAllClose(tf.reshape(ious, [3]), [1, 1.0 / 7.0, 0])
70 |
71 | def test_find_matching_cluster_matches(self):
72 | matching_cluster = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32)
73 | non_matching_cluster = tf.constant([1, 3, 3, 2, 2, 1, 1], dtype=tf.float32)
74 |
75 | box = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32)
76 |
77 | cluster_index = wbf.find_matching_cluster(
78 | (matching_cluster, non_matching_cluster), box)
79 |
80 | self.assertAllClose(cluster_index, 0)
81 |
82 | cluster_index = wbf.find_matching_cluster(
83 | (non_matching_cluster, matching_cluster), box)
84 |
85 | self.assertAllClose(cluster_index, 1)
86 |
87 | def test_find_matching_cluster_best_overlap(self):
88 | overlaps = tf.constant([1, 1, 1, 11, 2, 1, 1], dtype=tf.float32)
89 | overlaps_better = tf.constant([1, 2, 1, 12, 2, 1, 1], dtype=tf.float32)
90 |
91 | box = tf.constant([1, 3, 1, 13, 2, 1, 1], dtype=tf.float32)
92 |
93 | cluster_index = wbf.find_matching_cluster((overlaps,), box)
94 |
95 | self.assertAllClose(cluster_index, 0)
96 |
97 | cluster_index = wbf.find_matching_cluster((overlaps, overlaps_better), box)
98 |
99 | self.assertAllClose(cluster_index, 1)
100 |
101 | def test_weighted_average(self):
102 | samples = tf.constant([1, 3], dtype=tf.float32)
103 |
104 | weights1 = tf.constant([0.5, 0.5], dtype=tf.float32)
105 | weighted_average1 = wbf.weighted_average(samples, weights1)
106 |
107 | self.assertAllClose(weighted_average1, 2)
108 |
109 | weights2 = tf.constant([1, 0], dtype=tf.float32)
110 | weighted_average2 = wbf.weighted_average(samples, weights2)
111 |
112 | self.assertAllClose(weighted_average2, 1)
113 |
114 | weights3 = tf.constant([1, 2], dtype=tf.float32)
115 | weighted_average3 = wbf.weighted_average(samples, weights3)
116 |
117 | self.assertAllClose(weighted_average3, 7.0 / 3.0)
118 |
119 | def test_average_detections(self):
120 | d1 = tf.constant([1, 1, 1, 2, 2, 0.3, 1], dtype=tf.float32)
121 | d2 = tf.constant([1, 3, 3, 4, 4, 0.7, 1], dtype=tf.float32)
122 |
123 | averaged_single_model = wbf.average_detections((d1, d2), 1)
124 | self.assertAllClose(averaged_single_model, [1, 2.4, 2.4, 3.4, 3.4, 0.5, 1])
125 |
126 | averaged_multi_model = wbf.average_detections((d1, d2), 3)
127 | self.assertAllClose(averaged_multi_model,
128 | [1, 2.4, 2.4, 3.4, 3.4, 0.333333, 1])
129 |
130 | averaged_single_detection = wbf.average_detections((d2,), 2)
131 | self.assertAllClose(averaged_single_detection, [1, 3, 3, 4, 4, 0.35, 1])
132 |
133 | def test_ensemble_boxes(self):
134 | d1 = tf.constant([1, 2, 1, 10, 1, 0.75, 1], dtype=tf.float32)
135 | d2 = tf.constant([1, 3, 1, 10, 1, 0.75, 1], dtype=tf.float32)
136 | d3 = tf.constant([1, 3, 1, 10, 1, 1, 2], dtype=tf.float32)
137 |
138 | ensembled = wbf.ensemble_detections({'num_classes': 3},
139 | tf.stack([d1, d2, d3]), 2)
140 |
141 | self.assertAllClose(ensembled,
142 | [[1, 2.5, 1, 10, 1, 0.75, 1], [1, 3, 1, 10, 1, 0.5, 2]])
143 |
144 |
145 | if __name__ == '__main__':
146 | logging.set_verbosity(logging.WARNING)
147 | tf.test.main()
148 |
--------------------------------------------------------------------------------
/efficientdet/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for utils."""
16 | import os
17 | from absl import logging
18 | import tensorflow.compat.v1 as tf
19 |
20 | import utils
21 |
22 |
23 | class UtilsTest(tf.test.TestCase):
24 |
25 | def setUp(self):
26 | super(UtilsTest, self).setUp()
27 | self.model_dir = os.path.join(tf.test.get_temp_dir(), 'model_dir')
28 |
29 | def build_model(self):
30 | x = tf.Variable(1.0)
31 | y = tf.Variable(2.0)
32 | z = x + y
33 | return z
34 |
35 | def test_archive_ckpt(self):
36 | model_dir = os.path.join(tf.test.get_temp_dir(), 'model_dir')
37 | ckpt_path = os.path.join(model_dir, 'ckpt')
38 | self.build_model()
39 | saver = tf.train.Saver()
40 | with self.session() as sess:
41 | sess.run(tf.global_variables_initializer())
42 | saver.save(sess, ckpt_path)
43 |
44 | # Save checkpoint if the new objective is better.
45 | self.assertTrue(utils.archive_ckpt('eval1', 0.1, ckpt_path))
46 | logging.info(os.listdir(model_dir))
47 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'archive')))
48 | self.assertFalse(tf.io.gfile.exists(os.path.join(model_dir, 'backup')))
49 |
50 | # Save checkpoint if the new objective is better.
51 | self.assertTrue(utils.archive_ckpt('eval2', 0.2, ckpt_path))
52 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'archive')))
53 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'backup')))
54 |
55 | # Skip checkpoint if the new objective is worse.
56 | self.assertFalse(utils.archive_ckpt('eval3', 0.1, ckpt_path))
57 |
58 | # Save checkpoint if the new objective is better.
59 | self.assertTrue(utils.archive_ckpt('eval4', 0.3, ckpt_path))
60 |
61 | # Save checkpoint if the new objective is equal.
62 | self.assertTrue(utils.archive_ckpt('eval5', 0.3, ckpt_path))
63 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'archive')))
64 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'backup')))
65 |
66 | def test_image_size(self):
67 | self.assertEqual(utils.parse_image_size('1280x640'), (640, 1280))
68 | self.assertEqual(utils.parse_image_size(1280), (1280, 1280))
69 | self.assertEqual(utils.parse_image_size((1280, 640)), (1280, 640))
70 |
71 | def test_get_feat_sizes(self):
72 | feats = utils.get_feat_sizes(640, 2)
73 | self.assertEqual(feats, [{
74 | 'height': 640,
75 | 'width': 640
76 | }, {
77 | 'height': 320,
78 | 'width': 320
79 | }, {
80 | 'height': 160,
81 | 'width': 160
82 | }])
83 |
84 | feats = utils.get_feat_sizes((640, 300), 2)
85 | self.assertEqual(feats, [{
86 | 'height': 640,
87 | 'width': 300,
88 | }, {
89 | 'height': 320,
90 | 'width': 150,
91 | }, {
92 | 'height': 160,
93 | 'width': 75,
94 | }])
95 |
96 | def test_precision_float16(self):
97 | def _model(inputs):
98 | x = tf.ones((4, 4, 4, 4), dtype='float32')
99 | conv = tf.keras.layers.Conv2D(filters=4, kernel_size=2, use_bias=False)
100 | a = tf.Variable(1.0)
101 | return tf.cast(a, inputs.dtype) * conv(x) * inputs
102 |
103 | x = tf.constant(2.0, dtype=tf.float32) # input can be any type.
104 | out = utils.build_model_with_precision('mixed_float16', _model, x)
105 | # Variables should be float32.
106 | for v in tf.global_variables():
107 | self.assertIn(v.dtype, (tf.float32, tf.dtypes.as_dtype('float32_ref')))
108 | self.assertIs(out.dtype, tf.float16) # output should be float16.
109 |
110 |
111 | class ActivationTest(tf.test.TestCase):
112 |
113 | def test_swish(self):
114 | features = tf.constant([.5, 10])
115 |
116 | result = utils.activation_fn(features, 'swish')
117 | expected = features * tf.sigmoid(features)
118 | self.assertAllClose(result, expected)
119 |
120 | result = utils.activation_fn(features, 'swish_native')
121 | self.assertAllClose(result, expected)
122 |
123 | def test_hswish(self):
124 | features = tf.constant([.5, 10])
125 | result = utils.activation_fn(features, 'hswish')
126 | self.assertAllClose(result, [0.29166667, 10.0])
127 |
128 | def test_relu(self):
129 | features = tf.constant([.5, 10])
130 | result = utils.activation_fn(features, 'relu')
131 | self.assertAllClose(result, [0.5, 10])
132 |
133 | def test_relu6(self):
134 | features = tf.constant([.5, 10])
135 | result = utils.activation_fn(features, 'relu6')
136 | self.assertAllClose(result, [0.5, 6])
137 |
138 | def test_mish(self):
139 | features = tf.constant([.5, 10])
140 | result = utils.activation_fn(features, 'mish')
141 | self.assertAllClose(result, [0.37524524, 10.0])
142 |
143 |
144 | if __name__ == '__main__':
145 | logging.set_verbosity(logging.WARNING)
146 | tf.disable_eager_execution()
147 | tf.test.main()
148 |
--------------------------------------------------------------------------------
/efficientdet/visualize/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # Visualization library is mostly based on TensorFlow object detection API:
16 | # https://github.com/tensorflow/models/tree/master/research/object_detection
17 |
--------------------------------------------------------------------------------
/efficientdet/visualize/static_shape.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Helper functions to access TensorShape values.
16 |
17 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth].
18 | """
19 |
20 |
21 | def get_dim_as_int(dim):
22 | """Utility to get v1 or v2 TensorShape dim as an int.
23 |
24 | Args:
25 | dim: The TensorShape dimension to get as an int
26 |
27 | Returns:
28 | None or an int.
29 | """
30 | try:
31 | return dim.value
32 | except AttributeError:
33 | return dim
34 |
35 |
36 | def get_batch_size(tensor_shape):
37 | """Returns batch size from the tensor shape.
38 |
39 | Args:
40 | tensor_shape: A rank 4 TensorShape.
41 |
42 | Returns:
43 | An integer representing the batch size of the tensor.
44 | """
45 | tensor_shape.assert_has_rank(rank=4)
46 | return get_dim_as_int(tensor_shape[0])
47 |
48 |
49 | def get_height(tensor_shape):
50 | """Returns height from the tensor shape.
51 |
52 | Args:
53 | tensor_shape: A rank 4 TensorShape.
54 |
55 | Returns:
56 | An integer representing the height of the tensor.
57 | """
58 | tensor_shape.assert_has_rank(rank=4)
59 | return get_dim_as_int(tensor_shape[1])
60 |
61 |
62 | def get_width(tensor_shape):
63 | """Returns width from the tensor shape.
64 |
65 | Args:
66 | tensor_shape: A rank 4 TensorShape.
67 |
68 | Returns:
69 | An integer representing the width of the tensor.
70 | """
71 | tensor_shape.assert_has_rank(rank=4)
72 | return get_dim_as_int(tensor_shape[2])
73 |
74 |
75 | def get_depth(tensor_shape):
76 | """Returns depth from the tensor shape.
77 |
78 | Args:
79 | tensor_shape: A rank 4 TensorShape.
80 |
81 | Returns:
82 | An integer representing the depth of the tensor.
83 | """
84 | tensor_shape.assert_has_rank(rank=4)
85 | return get_dim_as_int(tensor_shape[3])
86 |
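A quick usage sketch for these helpers (assuming the efficientdet directory is on the import path). They accept both a v1 TensorShape, whose entries are Dimension objects with a .value attribute, and a v2 TensorShape, whose entries are already ints or None:

import tensorflow as tf
from visualize import static_shape

shape = tf.TensorShape([None, 512, 640, 3])   # [batch, height, width, depth]
print(static_shape.get_batch_size(shape))     # None (unknown until runtime)
print(static_shape.get_height(shape))         # 512
print(static_shape.get_width(shape))          # 640
print(static_shape.get_depth(shape))          # 3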
--------------------------------------------------------------------------------
/efficientnetv2/autoaugment_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for autoaugment."""
16 | import tensorflow.compat.v1 as tf
17 | import autoaugment
18 |
19 |
20 | class AutoaugmentTest(tf.test.TestCase):
21 |
22 | def test_autoaugment(self):
23 | """Smoke test to be sure no syntax errors."""
24 | image = tf.zeros((224, 224, 3), dtype=tf.uint8)
25 | aug_image = autoaugment.distort_image_with_autoaugment(image, 'v0')
26 | self.assertEqual((224, 224, 3), aug_image.shape)
27 |
28 | def test_randaug(self):
29 | """Smoke test to be sure no syntax errors."""
30 | num_layers = 2
31 | magnitude = 15
32 | image = tf.zeros((224, 224, 3), dtype=tf.uint8)
33 | aug_image = autoaugment.distort_image_with_randaugment(
34 | image, num_layers, magnitude)
35 | self.assertEqual((224, 224, 3), aug_image.shape)
36 |
37 |
38 | if __name__ == '__main__':
39 | tf.test.main()
40 |
--------------------------------------------------------------------------------
/efficientnetv2/cflags.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Command line flags, shared by all executable binaries."""
16 | from absl import flags
17 | FLAGS = flags.FLAGS
18 |
19 |
20 | def define_flags():
21 | """A shared function to define flags."""
22 | flags.DEFINE_string('model_name', 'efficientnetv2-b0', 'model name.')
23 | flags.DEFINE_string('dataset_cfg', 'Imagenet', 'dataset config name.')
24 | flags.DEFINE_string('hparam_str', '', 'Comma separated k=v pairs of hparams.')
25 | flags.DEFINE_string('sweeps', '', 'Comma separated k=v pairs for sweeping.')
26 | flags.DEFINE_bool('use_tpu', True, 'If true, use TPU; otherwise use CPU/GPU.')
27 | flags.DEFINE_string('tpu_job_name', None, 'job name, default to tpu_worker.')
28 | # Cloud TPU Cluster Resolvers
29 | flags.DEFINE_string('tpu', None, 'address e.g. grpc://ip.address.of.tpu:8470')
30 | flags.DEFINE_string('gcp_project', None, 'Project name.')
31 | flags.DEFINE_string('tpu_zone', None, 'GCE zone')
32 | # Model specific flags
33 | flags.DEFINE_string('data_dir', None, 'The directory for training images.')
34 | flags.DEFINE_string('eval_name', None, 'Evaluation name.')
35 | flags.DEFINE_bool('archive_ckpt', True, 'If true, archive the best ckpt.')
36 | flags.DEFINE_string('model_dir', None, 'Dir for checkpoint and summaries.')
37 | flags.DEFINE_string('mode', 'train', 'One of {"train", "eval"}.')
38 | flags.DEFINE_bool('export_to_tpu', False, 'Export metagraph.')
39 |
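The flags only exist after define_flags() runs, so each binary is expected to call it before absl parses argv. A minimal sketch of a hypothetical runner reusing the shared flags (the runner name and body are illustrative, not part of the repo):

from absl import app
from absl import flags

import cflags

cflags.define_flags()
FLAGS = flags.FLAGS


def main(_):
  # e.g. python runner.py --model_name=efficientnetv2-s --mode=eval --nouse_tpu
  print(FLAGS.model_name, FLAGS.mode, FLAGS.use_tpu)


if __name__ == '__main__':
  app.run(main)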
--------------------------------------------------------------------------------
/efficientnetv2/datasets_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests datasets."""
16 | import tensorflow as tf
17 | import datasets
18 |
19 |
20 | class ImagenetInputTest(tf.test.TestCase):
21 |
22 | def test_imagenet(self):
23 | ds_class = datasets.get_dataset_class('imagenet')
24 | ds = ds_class(
25 | is_training=False,
26 | data_dir='null',
27 | cache=False,
28 | image_size=224,
29 | image_dtype=None,
30 | augname=None,
31 | mixup_alpha=0,
32 | ra_num_layers=2,
33 | ra_magnitude=20)
34 | params = {'batch_size': 2}
35 | for _, labels in ds.input_fn(params):
36 | label = labels['label']
37 | self.assertAllClose(label[:, 0:4], [[0, 0, 0, 0], [0, 0, 0, 0]])
38 | break
39 |
40 | def test_imagenet21k(self):
41 | ds_class = datasets.get_dataset_class('imagenet21k')
42 | ds = ds_class(
43 | is_training=False,
44 | data_dir='null',
45 | cache=False,
46 | image_size=224,
47 | image_dtype=None,
48 | augname=None,
49 | mixup_alpha=0,
50 | ra_num_layers=2,
51 | ra_magnitude=20)
52 | params = {'batch_size': 2}
53 | for _, labels in ds.input_fn(params):
54 | label = labels['label']
55 | self.assertAllClose(label[:, 0:4], [[0, 0, 1, 1], [0, 0, 1, 1]])
56 | break
57 |
58 |
59 | class DatasetConfigTest(tf.test.TestCase):
60 |
61 | def test_dataset_config(self):
62 | cfg = datasets.get_dataset_config('cifar10ft')
63 | self.assertEqual(cfg.data.ds_name, 'cifar10')
64 |
65 |
66 | if __name__ == '__main__':
67 | tf.test.main()
68 |
--------------------------------------------------------------------------------
/efficientnetv2/effnetv2_configs_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests fo effnetv2_configs."""
16 | from absl import logging
17 | import tensorflow as tf
18 | import effnetv2_configs
19 |
20 |
21 | class EffnetV2ConfigsTest(tf.test.TestCase):
22 |
23 | def test_model_config(self):
24 | cfg = effnetv2_configs.get_model_config('efficientnet-b0')
25 | self.assertEqual(cfg.model.model_name, 'efficientnet-b0')
26 |
27 | cfg = effnetv2_configs.get_model_config('efficientnetv2-s')
28 | self.assertEqual(cfg.model.model_name, 'efficientnetv2-s')
29 |
30 |
31 | if __name__ == '__main__':
32 | logging.set_verbosity(logging.WARNING)
33 | tf.test.main()
34 |
35 |
--------------------------------------------------------------------------------
/efficientnetv2/effnetv2_model_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for effnetv2_model."""
16 | from absl import logging
17 | from absl.testing import parameterized
18 | import tensorflow.compat.v1 as tf
19 | import effnetv2_model
20 |
21 |
22 | class EffNetV2ModelTest(tf.test.TestCase, parameterized.TestCase):
23 |
24 | @parameterized.named_parameters(('v1_b0', 'efficientnet-b0', 5330564),
25 | ('v1_b1', 'efficientnet-b1', 7856232),
26 | ('v1_b2', 'efficientnet-b2', 9177562),
27 | ('v1_b3', 'efficientnet-b3', 12314268),
28 | ('v1_b4', 'efficientnet-b4', 19466816),
29 | ('v1_b5', 'efficientnet-b5', 30562520),
30 | ('v1_b6', 'efficientnet-b6', 43265136))
31 | def test_effnetv1(self, model_name, expected_params):
32 | images = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
33 | model = effnetv2_model.EffNetV2Model(model_name)
34 | _ = model(images)
35 | self.assertEqual(model.count_params(), expected_params)
36 |
37 | @parameterized.named_parameters(('v2-b0', 'efficientnetv2-b0', 7200312),
38 | ('v2-b1', 'efficientnetv2-b1', 8212124),
39 | ('v2-b2', 'efficientnetv2-b2', 10178374),
40 | ('v2-b3', 'efficientnetv2-b3', 14467622),
41 | ('s', 'efficientnetv2-s', 21612360),
42 | ('m', 'efficientnetv2-m', 54431388),
43 | ('l', 'efficientnetv2-l', 119027848),
44 | ('xl', 'efficientnetv2-xl', 208896832))
45 | def test_effnetv2(self, model_name, expected_params):
46 | images = tf.zeros((10, 224, 224, 3), dtype=tf.float32)
47 | model = effnetv2_model.EffNetV2Model(model_name)
48 | _ = model(images)
49 | self.assertEqual(model.count_params(), expected_params)
50 |
51 |
52 | if __name__ == '__main__':
53 | logging.set_verbosity(logging.WARNING)
54 | tf.test.main()
55 |
--------------------------------------------------------------------------------
/efficientnetv2/g3doc/effnetv2-l-gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-l-gpu.png
--------------------------------------------------------------------------------
/efficientnetv2/g3doc/effnetv2-m-gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-m-gpu.png
--------------------------------------------------------------------------------
/efficientnetv2/g3doc/effnetv2-s-gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-s-gpu.png
--------------------------------------------------------------------------------
/efficientnetv2/g3doc/effnetv2-s-relu6-gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-s-relu6-gpu.png
--------------------------------------------------------------------------------
/efficientnetv2/g3doc/param_flops.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/param_flops.png
--------------------------------------------------------------------------------
/efficientnetv2/g3doc/train_params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/train_params.png
--------------------------------------------------------------------------------
/efficientnetv2/mlir.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A simple example on how to export MLIR."""
16 | import os
17 | from absl import app
18 | from absl import flags
19 | from absl import logging
20 | import tensorflow as tf
21 |
22 | import effnetv2_model
23 | import utils
24 | # pylint: disable=g-direct-tensorflow-import
25 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
26 | from tensorflow.lite.python.util import get_grappler_config
27 | from tensorflow.lite.python.util import run_graph_optimizations
28 |
29 | FLAGS = flags.FLAGS
30 |
31 |
32 | def define_flags():
33 | """Define all flags for binary run."""
34 | flags.DEFINE_string('model_dir', None, 'Location of the checkpoint to run.')
35 | flags.DEFINE_string('model_name', 'efficientnetv2-b0', 'Model name to use.')
36 | flags.DEFINE_string('dataset_cfg', 'Imagenet', 'dataset config name.')
37 | flags.DEFINE_string('hparam_str', '', 'k=v,x=y pairs or yaml file.')
38 | flags.DEFINE_string('export_dir', None, 'Export or saved model directory')
39 |
40 |
41 | def main(_):
42 | """Export model to MLIR."""
43 | model = effnetv2_model.get_model(
44 | FLAGS.model_name,
45 | FLAGS.hparam_str,
46 | include_top=True,
47 | pretrained=FLAGS.model_dir or True)
48 | # Use call (not build) to match the namescope: tensorflow issues/29576
49 | model(tf.ones([1, 224, 224, 3]), False)
50 | if FLAGS.model_dir:
51 | ckpt = FLAGS.model_dir
52 | if tf.io.gfile.isdir(ckpt):
53 | ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
54 | utils.restore_tf2_ckpt(model, ckpt, exclude_layers=('_head', 'optimizer'))
55 | model.summary()
56 |
57 | fff = tf.function(model).get_concrete_function(
58 | tf.TensorSpec([1, 224, 224, 3], tf.float32))
59 |
60 | frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(fff)
61 |
62 | input_tensors = [
63 | tensor for tensor in frozen_func.inputs if tensor.dtype != tf.resource
64 | ]
65 | output_tensors = frozen_func.outputs
66 |
67 | graph_def = run_graph_optimizations(
68 | graph_def,
69 | input_tensors,
70 | output_tensors,
71 | config=get_grappler_config([
72 | 'pruning', 'function', 'constfold', 'shape', 'remap', 'memory',
73 | 'common_subgraph_elimination', 'arithmetic', 'loop', 'dependency',
74 | 'debug_stripper'
75 | ]),
76 | graph=frozen_func.graph)
77 |
78 | tf_mlir_graph = tf.mlir.experimental.convert_graph_def(graph_def)
79 |
80 | print('export model to {}.mlir'.format(FLAGS.model_name))
81 | export_dir = FLAGS.export_dir
82 | if export_dir is None:
83 | export_dir = '.'
84 | os.makedirs(export_dir, exist_ok=True)
85 | outfile = open('{}/{}.mlir'.format(export_dir, FLAGS.model_name), 'wb')
86 | outfile.write(tf_mlir_graph.encode())
87 | outfile.close()
88 |
89 |
90 | if __name__ == '__main__':
91 | logging.set_verbosity(logging.ERROR)
92 | define_flags()
93 | app.run(main)
94 |
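The core conversion path here is concrete function -> frozen GraphDef -> MLIR text. A minimal sketch of the same path on a toy Keras model, skipping the EffNetV2 checkpoint restore and the grappler passes (the toy model and output file name are illustrative):

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph)

toy = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(8, 3, input_shape=(224, 224, 3))])
concrete = tf.function(toy).get_concrete_function(
    tf.TensorSpec([1, 224, 224, 3], tf.float32))
_, graph_def = convert_variables_to_constants_v2_as_graph(concrete)
mlir_text = tf.mlir.experimental.convert_graph_def(graph_def)
with open('toy_model.mlir', 'w') as f:
  f.write(mlir_text)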
--------------------------------------------------------------------------------
/efficientnetv2/preprocessing_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for preprocessing."""
16 | from absl import logging
17 | from absl.testing import parameterized
18 | import tensorflow as tf
19 | import preprocessing
20 |
21 |
22 | class PreprocessingTest(tf.test.TestCase, parameterized.TestCase):
23 |
24 | @parameterized.parameters('effnetv1_autoaug', 'effnetv1_randaug', None)
25 | def test_preprocessing_legacy(self, augname):
26 | image = tf.zeros((300, 300, 3), dtype=tf.float32)
27 | try:
28 | preprocessing.preprocess_image(image, 224, False, None, augname)
29 | except tf.errors.InvalidArgumentError as e:
30 | if 'ExtractJpegShape' not in str(e):
31 | raise e
32 |
33 | @parameterized.parameters('autoaug', 'randaug', 'ft', 'ft_autoaug', None)
34 | def test_preprocessing(self, augname):
35 | image = tf.zeros((300, 300, 3), dtype=tf.float32)
36 | preprocessing.preprocess_image(image, 224, True, None, augname)
37 |
38 |
39 | if __name__ == '__main__':
40 | logging.set_verbosity(logging.WARNING)
41 | tf.test.main()
42 |
--------------------------------------------------------------------------------
/efficientnetv2/smoke_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for EfficientNetV2 train smoke tests."""
16 | import sys
17 | import tempfile
18 | from absl import flags
19 | from absl.testing import flagsaver
20 | import tensorflow as tf
21 | import main as main_lib
22 |
23 | FLAGS = flags.FLAGS
24 | GPU_TEST = 'gpu_test' in sys.argv[0]
25 | TPU_TEST = 'test_tpu' in sys.argv[0]
26 |
27 |
28 | class EfficientNetV2Test(tf.test.TestCase):
29 |
30 | @classmethod
31 | def setUpClass(cls):
32 | super().setUpClass()
33 | FLAGS.tpu = ''
34 | FLAGS.model_dir = tempfile.mkdtemp()
35 | FLAGS.data_dir = 'null'
36 | cls.hparam_str = (
37 | 'train.batch_size=2,eval.batch_size=2,train.epochs=0,train.min_steps=1,'
38 | 'train.stages=0,train.lr_base=0,data.splits.eval.num_images=6,')
39 |
40 | def _run_single_step_train_and_eval(self, hparam_str=''):
41 | """Single step run with TPUEstimator."""
42 | FLAGS.hparam_str = self.hparam_str + hparam_str
43 | FLAGS.mode = 'train'
44 | main_lib.main(None)
45 |
46 | tf.compat.v1.reset_default_graph()
47 | FLAGS.mode = 'eval'
48 | main_lib.main(None)
49 |
50 | @flagsaver.flagsaver(
51 | use_tpu=False, model_name='efficientnetv2-s', dataset_cfg='ImageNet')
52 | def test_cpu_b0_model_single_step(self):
53 | self._run_single_step_train_and_eval()
54 |
55 | @flagsaver.flagsaver(use_tpu=True)
56 | def test_tpu_b0_model_bfloat_single_step(self):
57 | if TPU_TEST:
58 | self._run_single_step_train_and_eval('')
59 | else:
60 | self.skipTest('Skip because no TPU is available.')
61 |
62 | @flagsaver.flagsaver(use_tpu=False)
63 | def test_tpu_b0_model_single_step_gpu(self):
64 | if GPU_TEST:
65 | # Disables export as tflite does not support NCHW layout.
66 | self._run_single_step_train_and_eval('model.data_format=channels_first')
67 | else:
68 | self.skipTest('Skip because no GPU is available.')
69 |
70 |
71 | if __name__ == '__main__':
72 | tf.compat.v1.disable_eager_execution()
73 | tf.test.main()
74 |
--------------------------------------------------------------------------------
/efficientnetv2/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for utils."""
16 |
17 | import tensorflow as tf
18 | import utils
19 |
20 |
21 | class UtilsTest(tf.test.TestCase):
22 |
23 | def test_constant_lr(self):
24 | constant_schedule = utils.WarmupLearningRateSchedule(
25 | 1.0, lr_decay_type='constant', warmup_epochs=None)
26 |
27 | lr = constant_schedule(10)
28 | self.assertAllClose(lr, 1.0)
29 |
30 | def test_linear_lr(self):
31 | linear_schedule = utils.WarmupLearningRateSchedule(
32 | 1.0, total_steps=10, lr_decay_type='linear', warmup_epochs=None)
33 |
34 | lr = linear_schedule(0)
35 | self.assertAllClose(lr, 1.0)
36 |
37 | lr = linear_schedule(5)
38 | self.assertAllClose(lr, 0.5)
39 |
40 | lr = linear_schedule(10)
41 | self.assertAllClose(lr, 0.0)
42 |
43 | def test_cosine_lr(self):
44 | cosine_schedule = utils.WarmupLearningRateSchedule(
45 | 1.0, total_steps=10, lr_decay_type='cosine', warmup_epochs=None)
46 |
47 | lr = cosine_schedule(4)
48 | self.assertAllClose(lr, 0.654508)
49 |
50 | lr = cosine_schedule(5)
51 | self.assertAllClose(lr, 0.5)
52 |
53 | lr = cosine_schedule(6)
54 | self.assertAllClose(lr, 0.345491)
55 |
56 | def test_exponential_lr(self):
57 | exponential_schedule = utils.WarmupLearningRateSchedule(
58 | 1.0,
59 | total_steps=100,
60 | steps_per_epoch=10,
61 | decay_epochs=2,
62 | decay_factor=0.5,
63 | lr_decay_type='exponential',
64 | warmup_epochs=None)
65 |
66 | lr = exponential_schedule(5)
67 | self.assertAllClose(lr, 1.0)
68 |
69 | lr = exponential_schedule(25)
70 | self.assertAllClose(lr, 0.5)
71 |
72 | lr = exponential_schedule(70)
73 | self.assertAllClose(lr, 0.125)
74 |
75 | def test_warmup(self):
76 | warmup_schedule = utils.WarmupLearningRateSchedule(
77 | 1.0,
78 | total_steps=100,
79 | steps_per_epoch=10,
80 | warmup_epochs=2,
81 | lr_decay_type='constant')
82 |
83 | lr = warmup_schedule(5)
84 | self.assertAllClose(lr, 0.25)
85 |
86 | lr = warmup_schedule(35)
87 | self.assertAllClose(lr, 1.0)
88 |
89 |
90 | if __name__ == '__main__':
91 | tf.test.main()
92 |
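The expected values in these tests are consistent with the usual schedule formulas (stated here as an assumption about utils.WarmupLearningRateSchedule, whose implementation is not shown): cosine decay 0.5 * base * (1 + cos(pi * step / total_steps)), linear decay base * (1 - step / total_steps), staircase exponential decay base * decay_factor ** (step // (decay_epochs * steps_per_epoch)), and linear warmup that scales the rate by step / (warmup_epochs * steps_per_epoch). A quick check:

import numpy as np

base = 1.0
print(0.5 * base * (1 + np.cos(np.pi * 4 / 10)))  # 0.654508 (cosine, step 4 of 10)
print(base * (1 - 5 / 10))                        # 0.5      (linear, step 5 of 10)
print(base * 0.5 ** (25 // (2 * 10)))             # 0.5      (exponential, step 25)
print(base * min(1.0, 5 / (2 * 10)))              # 0.25     (warmup, step 5)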
--------------------------------------------------------------------------------
/hero/example_commands.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Example command for installing dependencies.
4 | # JAX installation is environment-specific (CPU, GPU, TPU). Check the official JAX installation guide at https://docs.jax.dev/en/latest/installation.html.
5 | # Examples:
6 | # CPU: `pip install -U jax`
7 | # GPU: `pip install -U "jax[cuda12]"` (Replace `cuda12`; ensure CUDA/cuDNN/drivers are correct)
8 | # TPU: `pip install -U "jax[tpu]"`
9 | pip install -r requirements.txt
10 |
11 | # Example command for local run.
12 | # Add "export JAX_DISABLE_JIT=True;" to disable `jit` for easier debugging.
13 | # Change "TransformerLMTest" to other experiment config names in `config_lib.py` to run other experiments.
14 | EXP=local_test_1; rm -rf /tmp/${EXP}; python main.py --experiment_config TransformerLMTest --experiment_dir /tmp/${EXP} --verbosity=-1
15 |
16 | # Example command for checking learning curves with tensorboard.
17 | tensorboard --logdir /tmp/${EXP}
18 |
--------------------------------------------------------------------------------
/hero/fn_lib.py:
--------------------------------------------------------------------------------
1 | """The functions used in the program search.
2 |
3 | This library implements the functions that will be used in the program search.
4 | A shared library of functions makes it easier to share the discovered
5 | algorithms between different tasks, for example when evaluating one
6 | algorithm on both language and vision tasks.
7 | """
8 |
9 | from typing import Sequence, Optional
10 |
11 | import jax
12 | import jax.nn
13 |
14 | import jax.numpy as jnp
15 |
16 | import core
17 |
18 |
19 | @jax.jit
20 | def interpolate(x, y, weight):
21 | return core.tree_add(
22 | core.tree_mult(1.0 - weight, x), core.tree_mult(weight, y))
23 |
24 |
25 | @jax.jit
26 | def global_norm(tree):
27 | leaves = jax.tree_leaves(tree)
28 | norm = jnp.sqrt(sum([jnp.vdot(x, x) for x in leaves]))
29 | return norm
30 |
31 |
32 | @jax.jit
33 | def tree_dot(tree1, tree2):
34 | tree_result = jax.tree_map(jnp.vdot, tree1, tree2)
35 | return sum(jax.tree_leaves(tree_result))
36 |
37 |
38 | @jax.jit
39 | def tree_cosine_sim(tree1, tree2):
40 | tree_result = jax.tree_map(jnp.vdot, tree1, tree2)
41 | dot_result = sum(jax.tree_leaves(tree_result))
42 | norm1 = global_norm(tree1)
43 | norm2 = global_norm(tree2)
44 | return dot_result / (norm1 * norm2)
45 |
46 |
47 | @jax.jit
48 | def clip_by_global_norm(tree, clip_norm):
49 | l2_g = global_norm(tree)
50 | g_factor = jnp.minimum(1.0, clip_norm / l2_g)
51 | return core.tree_mult(g_factor, tree)
52 |
53 |
54 | def get_math_fns(allowed_fns: Optional[Sequence[str]] = None):
55 | """Get the dictionary containing the math functions."""
56 |
57 | fn_dict = {}
58 |
59 | noarg_fn_dict = dict(
60 | get_pi=lambda: jnp.pi, get_e=lambda: jnp.e, get_eps=lambda: 1e-8)
61 |
62 | for k, v in noarg_fn_dict.items():
63 | fn_dict[k] = core.Function(v, 0, [])
64 |
65 | def nonneg(f):
66 | def g(x):
67 | return f(jnp.fabs(x))
68 | return g
69 |
70 | def map_to_tree(f):
71 | def g(x):
72 | return jax.tree_map(f, x)
73 | return g
74 |
75 | unary_fn_dict = dict(
76 | abs=jnp.abs,
77 | cos=jnp.cos,
78 | sin=jnp.sin,
79 | tan=jnp.tan,
80 | arcsin=jnp.arcsin,
81 | arccos=jnp.arccos,
82 | arctan=jnp.arctan,
83 | sinh=jnp.sinh,
84 | cosh=jnp.cosh,
85 | tanh=jnp.tanh,
86 | arcsinh=jnp.arcsinh,
87 | arccosh=jnp.arccosh,
88 | arctanh=jnp.arctanh,
89 | exp=jnp.exp,
90 | exp2=jnp.exp2,
91 | exp10=lambda x: jnp.power(10, x),
92 | expm1=jnp.expm1,
93 | log=nonneg(jnp.log),
94 | log10=nonneg(jnp.log10),
95 | log2=nonneg(jnp.log2),
96 | log1p=lambda x: jnp.log(jnp.fabs(1 + x)),
97 | square=jnp.square,
98 | sqrt=nonneg(jnp.sqrt),
99 | cube=lambda x: jnp.power(x, 3),
100 | cbrt=jnp.cbrt,
101 | sign=jnp.sign,
102 | reciprocal=jnp.reciprocal,
103 | norm=jnp.linalg.norm,
104 | invert=jnp.invert,
105 | negative=jnp.negative)
106 |
107 | for k, v in unary_fn_dict.items():
108 | fn_dict[k] = core.Function(map_to_tree(v), 1, [core.is_numeric])
109 |
110 | fn_dict['global_norm'] = core.Function(global_norm, 1, [core.is_numeric])
111 |
112 | fn_dict['interpolate'] = core.Function(
113 | interpolate, 3,
114 | [core.ExampleAnnotation(1.0).check, core.is_numeric, core.is_numeric])
115 |
116 | fn_dict['dot'] = core.Function(
117 | tree_dot, 2, [core.is_numeric, core.is_numeric])
118 |
119 | fn_dict['cosine_sim'] = core.Function(
120 | tree_cosine_sim, 2, [core.is_numeric, core.is_numeric])
121 |
122 | fn_dict['clip_by_global_norm'] = core.Function(
123 | clip_by_global_norm, 2,
124 | [core.is_numeric, core.ExampleAnnotation(1.0).check])
125 |
126 | fn_dict['power'] = core.Function(
127 | jnp.power, 2, [core.is_numeric, core.is_numeric])
128 |
129 | if allowed_fns is not None:
130 | new_fn_dict = {}
131 | allowed_fns_set = set(allowed_fns)
132 | for key in fn_dict:
133 | if key in allowed_fns_set:
134 | new_fn_dict[key] = fn_dict[key]
135 | else:
136 | new_fn_dict = fn_dict
137 |
138 | return new_fn_dict
139 |
140 |
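A small usage sketch for the tree-level helpers, assuming the hero/ directory is on PYTHONPATH and a JAX version compatible with this code (it still uses jax.tree_map / jax.tree_leaves). The core.Function wrappers returned by get_math_fns are consumed by the search machinery in core.py rather than called directly:

import jax.numpy as jnp

import fn_lib

params = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.0])}
print(fn_lib.global_norm(params))                  # 5.0
print(fn_lib.tree_cosine_sim(params, params))      # ~1.0
clipped = fn_lib.clip_by_global_norm(params, 1.0)  # rescale to global norm <= 1
print(fn_lib.global_norm(clipped))                 # ~1.0

# Restrict the searchable function set to a whitelist.
fns = fn_lib.get_math_fns(allowed_fns=['log', 'sqrt', 'interpolate'])
print(sorted(fns))                                 # ['interpolate', 'log', 'sqrt']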
--------------------------------------------------------------------------------
/hero/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Chen Liang
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Simply a language model."""
15 |
16 | import os
17 | import re
18 | from typing import Sequence
19 |
20 | from absl import app
21 | from absl import flags
22 | from absl import logging
23 | import config_lib
24 | import data_lib
25 | import model_lib
26 |
27 | _EXPERIMENT_CONFIG = flags.DEFINE_string(
28 | 'experiment_config', 'TransformerLMTest', 'Name of the experiment config.')
29 |
30 | _SHARDING_CONFIG = flags.DEFINE_string(
31 | 'sharding_config', 'GSPMDSharding', 'Name of the sharding config.')
32 |
33 | _EXPERIMENT_DIR = flags.DEFINE_string(
34 | 'experiment_dir', '/tmp/simply_lm/', 'Path to save the experiment data.')
35 |
36 | _MESH_SHAPE = flags.DEFINE_list(
37 | 'mesh_shape',
38 | None,
39 | 'Shape for the mesh, comma separated integers, e.g. 1,265,1',
40 | )
41 |
42 | _DCN_MESH_SHAPE = flags.DEFINE_list(
43 | 'dcn_mesh_shape',
44 | None,
45 | 'Shape for the dcn mesh, comma separated integers, e.g. 2,1,1',
46 | )
47 |
48 |
49 | def main(argv: Sequence[str]) -> None:
50 | del argv
51 | if mesh_shape := _MESH_SHAPE.value:
52 | mesh_shape = [int(i) for i in mesh_shape]
53 | if dcn_mesh_shape := _DCN_MESH_SHAPE.value:
54 | dcn_mesh_shape = [int(i) for i in dcn_mesh_shape]
55 | config = config_lib.ExperimentConfigRegistry.get_config(
56 | _EXPERIMENT_CONFIG.value)
57 | sharding_config = config_lib.ShardingConfigRegistry.get_config(
58 | _SHARDING_CONFIG.value)
59 | logging.info('config: %s', config)
60 | logging.info('sharding_config: %s', sharding_config)
61 | logging.info('mesh_shape: %s', mesh_shape)
62 | logging.info('dcn_mesh_shape: %s', dcn_mesh_shape)
63 | model_lib.run_experiment(
64 | config=config, sharding_config=sharding_config,
65 | mesh_shape=mesh_shape,
66 | dcn_mesh_shape=dcn_mesh_shape,
67 | create_dataset=data_lib.create_dataset,
68 | experiment_dir=_EXPERIMENT_DIR.value)
69 |
70 |
71 | _TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+')
72 |
73 | if __name__ == '__main__':
74 | app.run(main)
75 |
--------------------------------------------------------------------------------
/hero/requirements.txt:
--------------------------------------------------------------------------------
1 | jax
2 | orbax
3 | seqio
4 | einops
5 | absl-py
6 | clu
7 | t5
8 |
--------------------------------------------------------------------------------
/hero/vb100864_openmix_v1.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/hero/vb100864_openmix_v1.model
--------------------------------------------------------------------------------
/hero/vb32000_t5_cc.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/hero/vb32000_t5_cc.model
--------------------------------------------------------------------------------
/lion/fig/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/ablation.png
--------------------------------------------------------------------------------
/lion/fig/alg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/alg.png
--------------------------------------------------------------------------------
/lion/fig/basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/basic.png
--------------------------------------------------------------------------------
/lion/fig/diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/diffusion.png
--------------------------------------------------------------------------------
/lion/fig/ft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/ft.png
--------------------------------------------------------------------------------
/lion/fig/i1k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/i1k.png
--------------------------------------------------------------------------------
/lion/fig/imagen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/imagen.png
--------------------------------------------------------------------------------
/lion/fig/jft-ft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/jft-ft.png
--------------------------------------------------------------------------------
/lion/fig/jft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/jft.png
--------------------------------------------------------------------------------
/lion/fig/lit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/lit.png
--------------------------------------------------------------------------------
/lion/fig/llm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/llm.png
--------------------------------------------------------------------------------
/lion/fig/lm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/lm.png
--------------------------------------------------------------------------------
/lion/fig/retrieval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/retrieval.png
--------------------------------------------------------------------------------
/lion/lion_optax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Optax implementation of the Lion optimizer."""
16 |
17 | from typing import Any, Callable, NamedTuple, Optional, Union
18 |
19 | import chex
20 | import jax
21 | import jax.numpy as jnp
22 | import optax
23 |
24 |
25 | def _scale_by_learning_rate(
26 | learning_rate: optax.ScalarOrSchedule, flip_sign=True):
27 | m = -1 if flip_sign else 1
28 | if callable(learning_rate):
29 | return optax.scale_by_schedule(lambda count: m * learning_rate(count))
30 | return optax.scale(m * learning_rate)
31 |
32 |
33 | def lion(
34 | learning_rate: optax.ScalarOrSchedule,
35 | b1: float = 0.9,
36 | b2: float = 0.99,
37 | mu_dtype: Optional[Any] = None,
38 | weight_decay: float = 0.0,
39 | mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
40 | ) -> optax.GradientTransformation:
41 | """Lion.
42 |
43 | Args:
44 | learning_rate: A fixed global scaling factor.
45 | b1: Exponential decay rate to combine the gradient and the moment.
46 | b2: Exponential decay rate to track the moment of past gradients.
47 | mu_dtype: Optional `dtype` to be used for the first order accumulator; if
48 | `None` then the `dtype` is inferred from `params` and `updates`.
49 | weight_decay: Strength of the weight decay regularization. Note that this
50 | weight decay is multiplied with the learning rate. This is consistent
51 | with other frameworks such as PyTorch, but different from
52 | (Loshchilov et al, 2019) where the weight decay is only multiplied with
53 | the "schedule multiplier", but not the base learning rate.
54 | mask: A tree with same structure as (or a prefix of) the params PyTree,
55 | or a Callable that returns such a pytree given the params/updates.
56 | The leaves should be booleans, `True` for leaves/subtrees you want to
57 | apply the weight decay to, and `False` for those you want to skip. Note
58 | that the Lion gradient transformations are applied to all parameters.
59 |
60 | Returns:
61 | The corresponding `GradientTransformation`.
62 | """
63 | return optax.chain(
64 | scale_by_lion(
65 | b1=b1, b2=b2, mu_dtype=mu_dtype),
66 | optax.add_decayed_weights(weight_decay, mask),
67 | _scale_by_learning_rate(learning_rate),
68 | )
69 |
70 |
71 | def update_moment(updates, moments, decay, order):
72 | """Compute the exponential moving average of the `order`-th moment."""
73 | return jax.tree_util.tree_map(
74 | lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
75 |
76 |
77 | class ScaleByLionState(NamedTuple):
78 | """State for the Lion algorithm."""
79 | count: chex.Array # shape=(), dtype=jnp.int32.
80 | mu: optax.Updates
81 |
82 |
83 | def scale_by_lion(
84 | b1: float = 0.9,
85 | b2: float = 0.99,
86 | mu_dtype: Optional[Any] = None,
87 | ) -> optax.GradientTransformation:
88 | """Rescale updates according to the Lion algorithm.
89 |
90 | Args:
91 | b1: rate for combining moment and the current grad.
92 | b2: decay rate for the exponentially weighted average of grads.
93 | mu_dtype: optional `dtype` to be used for the first order accumulator; if
94 | `None` then the `dtype` is inferred from `params` and `updates`.
95 |
96 | Returns:
97 | A `GradientTransformation` object.
98 | """
99 |
100 | def init_fn(params):
101 | mu = jax.tree_util.tree_map( # moment
102 | lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
103 | return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=mu)
104 |
105 | def update_fn(updates, state, params=None):
106 | del params
107 | mu = update_moment(updates, state.mu, b2, 1)
108 | mu = jax.tree_map(lambda x: x.astype(mu_dtype), mu)
109 | count_inc = optax.safe_int32_increment(state.count)
110 | updates = jax.tree_util.tree_map(
111 | lambda g, m: jnp.sign((1. - b1) * g + b1 * m), updates, state.mu)
112 | return updates, ScaleByLionState(count=count_inc, mu=mu)
113 |
114 | return optax.GradientTransformation(init_fn, update_fn)
115 |
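A minimal usage sketch (assuming lion_optax.py is importable as a local module). Because the decoupled weight decay is added before the learning-rate scaling, opt.update must be given the current params:

import jax.numpy as jnp
import optax

import lion_optax

params = {'w': jnp.ones((3,))}
opt = lion_optax.lion(learning_rate=1e-4, b1=0.9, b2=0.99, weight_decay=0.1)
state = opt.init(params)

grads = {'w': jnp.array([0.5, -2.0, 0.0])}
updates, state = opt.update(grads, state, params)  # params needed for weight decay
params = optax.apply_updates(params, updates)
# On the first step mu is zero, so the direction is sign(grad) plus wd * params,
# scaled by -learning_rate.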
--------------------------------------------------------------------------------
/lion/lion_pytorch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """PyTorch implementation of the Lion optimizer."""
16 | import torch
17 | from torch.optim.optimizer import Optimizer
18 |
19 |
20 | class Lion(Optimizer):
21 | r"""Implements Lion algorithm."""
22 |
23 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
24 | """Initialize the hyperparameters.
25 |
26 | Args:
27 | params (iterable): iterable of parameters to optimize or dicts defining
28 | parameter groups
29 | lr (float, optional): learning rate (default: 1e-4)
30 | betas (Tuple[float, float], optional): coefficients for combining the
31 | gradient with the momentum and for the momentum update (default: (0.9, 0.99))
32 | weight_decay (float, optional): weight decay coefficient (default: 0)
33 | """
34 |
35 | if not 0.0 <= lr:
36 | raise ValueError('Invalid learning rate: {}'.format(lr))
37 | if not 0.0 <= betas[0] < 1.0:
38 | raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
39 | if not 0.0 <= betas[1] < 1.0:
40 | raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
41 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
42 | super().__init__(params, defaults)
43 |
44 | @torch.no_grad()
45 | def step(self, closure=None):
46 | """Performs a single optimization step.
47 |
48 | Args:
49 | closure (callable, optional): A closure that reevaluates the model
50 | and returns the loss.
51 |
52 | Returns:
53 | the loss.
54 | """
55 | loss = None
56 | if closure is not None:
57 | with torch.enable_grad():
58 | loss = closure()
59 |
60 | for group in self.param_groups:
61 | for p in group['params']:
62 | if p.grad is None:
63 | continue
64 |
65 | # Perform decoupled weight decay.
66 | p.data.mul_(1 - group['lr'] * group['weight_decay'])
67 |
68 | grad = p.grad
69 | state = self.state[p]
70 | # State initialization
71 | if len(state) == 0:
72 | # Exponential moving average of gradient values
73 | state['exp_avg'] = torch.zeros_like(p)
74 |
75 | exp_avg = state['exp_avg']
76 | beta1, beta2 = group['betas']
77 |
78 | # Weight update
79 | update = exp_avg * beta1 + grad * (1 - beta1)
80 |
81 | p.add_(update.sign_(), alpha=-group['lr'])
82 |
83 | # Update the exponential moving average of the gradient.
84 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
85 |
86 | return loss
87 |
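A minimal training-step sketch with this optimizer. The learning rate and weight decay below are placeholders, not tuned values; Lion is typically run with a smaller learning rate than AdamW and the weight decay adjusted accordingly:

import torch
import torch.nn.functional as F

from lion_pytorch import Lion

model = torch.nn.Linear(4, 2)
opt = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
loss.backward()
opt.step()       # sign of the beta1-interpolated momentum, then beta2 EMA update
opt.zero_grad()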
--------------------------------------------------------------------------------
/lion/lion_tf1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF1 implementation of the Lion optimizer."""
16 | from typing import Optional, Union, Callable
17 |
18 | import tensorflow.compat.v1 as tf
19 | from tensorflow.python.ops import resource_variable_ops
20 |
21 | VType = Union[Callable, float, tf.Tensor]
22 |
23 |
24 | class Lion(tf.compat.v1.train.Optimizer):
25 | """Optimizer that implements the discovered algorithm in automl-hero."""
26 |
27 | def __init__(self,
28 | learning_rate: VType = 0.0001,
29 | beta1: VType = 0.9,
30 | beta2: VType = 0.99,
31 | wd: Optional[VType] = 0.0,
32 | use_locking=False,
33 | name="Lion"):
34 | r"""Construct a new Lion optimizer.
35 |
36 | Args:
37 | learning_rate: A Tensor or a floating point value. The learning rate.
38 | beta1: A float value or a constant float tensor. The rate to combine
39 | the gradient and the moment estimate.
40 | beta2: A float value or a constant float tensor. The exponential decay
41 | rate for the moment estimate.
42 | wd: An optional float value or a constant float tensor. The decoupled
43 | weight decay.
44 | use_locking: If True use locks for update operations.
45 | name: Optional name for the operations created when applying gradients.
46 | """
47 | super(Lion, self).__init__(use_locking, name)
48 | self._lr = learning_rate
49 | self._beta1 = beta1
50 | self._beta2 = beta2
51 | self._wd = None if isinstance(wd, float) and wd < 0 else wd
52 |
53 | # Tensor versions of the constructor arguments, created in _prepare().
54 | self._lr_t = None
55 | self._beta1_t = None
56 | self._beta2_t = None
57 | self._wd_t = None
58 |
59 | def _create_slots(self, var_list):
60 | # Create slots for the moment.
61 | for v in var_list:
62 | self._zeros_slot(v, "m", self._name)
63 |
64 | def _prepare(self):
65 | lr = self._call_if_callable(self._lr)
66 | beta1 = self._call_if_callable(self._beta1)
67 | beta2 = self._call_if_callable(self._beta2)
68 | wd = self._call_if_callable(self._wd)
69 |
70 | self._lr_t = tf.convert_to_tensor(lr, name="learning_rate")
71 | self._beta1_t = tf.convert_to_tensor(beta1, name="beta1")
72 | self._beta2_t = tf.convert_to_tensor(beta2, name="beta2")
73 | if wd is not None:
74 | self._wd_t = tf.convert_to_tensor(wd, name="weight_decay")
75 |
76 | def _apply_dense_shared(self, grad, var):
77 | m = self.get_slot(var, "m")
78 |
79 | lr_t = tf.cast(self._lr_t, dtype=var.dtype)
80 | beta1_t = tf.cast(self._beta1_t, dtype=var.dtype)
81 | beta2_t = tf.cast(self._beta2_t, dtype=var.dtype)
82 | if self._wd_t is None:
83 | weight_decay_t = None
84 | else:
85 | weight_decay_t = tf.cast(self._wd_t, dtype=var.dtype)
86 |
87 | updates_grad = tf.sign(m * beta1_t + grad * (1. - beta1_t))
88 | if weight_decay_t is not None:
89 | updates_grad = updates_grad + var * weight_decay_t
90 |
91 | var_update = tf.assign_sub(
92 | var, lr_t * updates_grad, use_locking=self._use_locking)
93 | with tf.control_dependencies([var_update]):
94 | m_update = tf.assign(m, m * beta2_t + grad * (1. - beta2_t))
95 | return tf.group(*[var_update, m_update])
96 |
97 | def _apply_dense(self, grad, var):
98 | return self._apply_dense_shared(grad, var)
99 |
100 | def _resource_apply_dense(self, grad, var):
101 | return self._apply_dense_shared(grad, var)
102 |
103 | def _apply_sparse_shared(self, grad, var, indices, scatter_add):
104 | m = self.get_slot(var, "m")
105 | lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
106 | beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
107 | beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
108 | wd_t = tf.cast(self._wd_t, var.dtype.base_dtype)
109 |
110 | m_update = tf.assign(m, m * beta1_t, use_locking=self._use_locking)
111 | with tf.control_dependencies([m_update]):
112 | m_update = scatter_add(m, indices, grad * (1. - beta1_t))
113 | with tf.control_dependencies([m_update]):
114 | var_update = tf.assign_sub(
115 | var,
116 | lr_t * (tf.sign(m) + var * wd_t),
117 | use_locking=self._use_locking)
118 | with tf.control_dependencies([var_update]):
119 | m_update = scatter_add(m, indices, grad * (beta1_t - 1.))
120 | with tf.control_dependencies([m_update]):
121 | m_update = tf.assign(
122 | m, m * beta2_t / beta1_t, use_locking=self._use_locking)
123 | with tf.control_dependencies([m_update]):
124 | m_update = scatter_add(m, indices, grad * (1. - beta2_t))
125 | return tf.group(*[var_update, m_update])
126 |
127 | def _apply_sparse(self, grad, var):
128 | return self._apply_sparse_shared(
129 | grad.values,
130 | var,
131 | grad.indices,
132 | lambda x, i, v: tf.scatter_add(
133 | x,
134 | i,
135 | v,
136 | use_locking=self._use_locking))
137 |
138 | def _resource_scatter_add(self, x, i, v):
139 | with tf.control_dependencies(
140 | [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
141 | return x.value()
142 |
143 | def _resource_apply_sparse(self, grad, var, indices):
144 | return self._apply_sparse_shared(grad, var, indices,
145 | self._resource_scatter_add)
146 |
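The sparse path rebuilds the momentum in place through a chain of scatter updates rather than materializing a dense update: it first forms beta1*m + (1-beta1)*grad (only its sign is needed for the variable update), then removes the gradient term, rescales the remaining beta1*m to beta2*m, and finally adds (1-beta2)*grad, landing in the same state as the dense rule m <- beta2*m + (1-beta2)*grad. A small NumPy check of that algebra:

import numpy as np

b1, b2 = 0.9, 0.99
m0 = np.array([0.2, -0.5])
g = np.array([1.0, 0.3])

m = m0 * b1
m = m + (1. - b1) * g   # sign(m) == sign(b1*m0 + (1-b1)*g), used for the var update
m = m + (b1 - 1.) * g   # undo the gradient contribution
m = m * b2 / b1         # rescale the retained momentum from beta1 to beta2
m = m + (1. - b2) * g   # final momentum

np.testing.assert_allclose(m, b2 * m0 + (1. - b2) * g)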
--------------------------------------------------------------------------------
/lion/lion_tf2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF2 implementation of the Lion optimizer."""
16 |
17 | import tensorflow.compat.v2 as tf
18 |
19 |
20 | class Lion(tf.keras.optimizers.legacy.Optimizer):
21 | r"""Optimizer that implements the Lion algorithm."""
22 |
23 | def __init__(self,
24 | learning_rate=0.0001,
25 | beta_1=0.9,
26 | beta_2=0.99,
27 | wd=0,
28 | name='lion',
29 | **kwargs):
30 | """Construct a new Lion optimizer."""
31 |
32 | super(Lion, self).__init__(name, **kwargs)
33 | self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
34 | self._set_hyper('beta_1', beta_1)
35 | self._set_hyper('beta_2', beta_2)
36 | self._set_hyper('wd', wd)
37 |
38 | def _create_slots(self, var_list):
39 | # Create a slot for the momentum, the only per-variable state that Lion
40 | # keeps.
41 | for var in var_list:
42 | self.add_slot(var, 'm')
43 |
44 | def _prepare_local(self, var_device, var_dtype, apply_state):
45 | super(Lion, self)._prepare_local(var_device, var_dtype, apply_state)
46 |
47 | beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
48 | beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
49 | wd_t = tf.identity(self._get_hyper('wd', var_dtype))
50 | lr = apply_state[(var_device, var_dtype)]['lr_t']
51 | apply_state[(var_device, var_dtype)].update(
52 | dict(
53 | lr=lr,
54 | beta_1_t=beta_1_t,
55 | one_minus_beta_1_t=1 - beta_1_t,
56 | beta_2_t=beta_2_t,
57 | one_minus_beta_2_t=1 - beta_2_t,
58 | wd_t=wd_t))
59 |
60 | @tf.function(jit_compile=True)
61 | def _resource_apply_dense(self, grad, var, apply_state=None):
62 | var_device, var_dtype = var.device, var.dtype.base_dtype
63 | coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
64 | self._fallback_apply_state(var_device, var_dtype))
65 |
66 | m = self.get_slot(var, 'm')
67 | var_t = var.assign_sub(
68 | coefficients['lr_t'] *
69 | (tf.math.sign(m * coefficients['beta_1_t'] +
70 | grad * coefficients['one_minus_beta_1_t']) +
71 | var * coefficients['wd_t']))
72 | with tf.control_dependencies([var_t]):
73 | m.assign(m * coefficients['beta_2_t'] +
74 | grad * coefficients['one_minus_beta_2_t'])
75 |
76 | @tf.function(jit_compile=True)
77 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
78 | var_device, var_dtype = var.device, var.dtype.base_dtype
79 | coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
80 | self._fallback_apply_state(var_device, var_dtype))
81 |
82 | m = self.get_slot(var, 'm')
83 | m_t = m.assign(m * coefficients['beta_1_t'])
84 | m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
85 | m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
86 | var_t = var.assign_sub(coefficients['lr'] *
87 | (tf.math.sign(m_t) + var * coefficients['wd_t']))
88 |
89 | with tf.control_dependencies([var_t]):
90 | m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices))
91 | m_t = m_t.assign(m_t * coefficients['beta_2_t'] /
92 | coefficients['beta_1_t'])
93 | m_scaled_g_values = grad * coefficients['one_minus_beta_2_t']
94 | m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
95 |
96 | def get_config(self):
97 | config = super(Lion, self).get_config()
98 | config.update({
99 | 'learning_rate': self._serialize_hyperparameter('learning_rate'),
100 | 'beta_1': self._serialize_hyperparameter('beta_1'),
101 | 'beta_2': self._serialize_hyperparameter('beta_2'),
102 | 'wd': self._serialize_hyperparameter('wd'),
103 | })
104 | return config
105 |
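A minimal Keras usage sketch, assuming a TF2 release that still provides tf.keras.optimizers.legacy (the base class used above); the toy model and data are illustrative:

import numpy as np
import tensorflow.compat.v2 as tf

import lion_tf2

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=lion_tf2.Lion(learning_rate=1e-4, wd=0.1), loss='mse')
model.fit(np.random.rand(64, 4).astype('float32'),
          np.random.rand(64, 1).astype('float32'),
          epochs=1, verbose=0)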
--------------------------------------------------------------------------------