├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── .pylintrc ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── efficientdet ├── Det-AdvProp.md ├── README.md ├── __init__.py ├── aug │ ├── __init__.py │ ├── autoaugment.py │ ├── autoaugment_test.py │ ├── gridmask.py │ ├── gridmask_test.py │ ├── mosaic.py │ └── mosaic_test.py ├── backbone │ ├── __init__.py │ ├── backbone_factory.py │ ├── efficientnet_builder.py │ ├── efficientnet_builder_test.py │ ├── efficientnet_lite_builder.py │ ├── efficientnet_lite_builder_test.py │ ├── efficientnet_model.py │ └── efficientnet_model_test.py ├── coco_metric.py ├── coco_metric_test.py ├── dataloader.py ├── dataloader_test.py ├── dataset │ ├── README.md │ ├── __init__.py │ ├── create_coco_tfrecord.py │ ├── create_coco_tfrecord_test.py │ ├── create_pascal_tfrecord.py │ ├── create_pascal_tfrecord_test.py │ ├── inspect_tfrecords.py │ ├── label_map_util.py │ └── tfrecord_util.py ├── det_advprop_tutorial.ipynb ├── det_model_fn.py ├── det_model_fn_test.py ├── efficientdet_arch.py ├── efficientdet_arch_test.py ├── g3doc │ ├── Det-AdvProp.png │ ├── coco_ids.yaml │ ├── faq.md │ ├── flops.png │ ├── network.png │ ├── params.png │ └── street.jpg ├── hparams_config.py ├── hparams_config_test.py ├── inference.py ├── install_deps.sh ├── iou_utils.py ├── iou_utils_test.py ├── main.py ├── model_inspect.py ├── model_inspect_test.py ├── nms_np.py ├── object_detection │ ├── __init__.py │ ├── argmax_matcher.py │ ├── box_coder.py │ ├── box_list.py │ ├── faster_rcnn_box_coder.py │ ├── matcher.py │ ├── preprocessor.py │ ├── region_similarity_calculator.py │ ├── shape_utils.py │ ├── target_assigner.py │ └── tf_example_decoder.py ├── requirements.txt ├── run_tflite.py ├── tensorrt.py ├── test.sh ├── test_util.py ├── testdata │ ├── img1-d1.jpg │ └── img1.jpg ├── tf2 │ ├── README.md │ ├── __init__.py │ ├── anchors.py │ ├── efficientdet_keras.py │ ├── efficientdet_keras_test.py │ ├── eval.py │ ├── eval_tflite.py │ ├── fpn_configs.py │ ├── fpn_configs_test.py │ ├── infer.py │ ├── infer_lib.py │ ├── infer_lib_test.py │ ├── inspector.py │ ├── inspector_test.py │ ├── label_util.py │ ├── postprocess.py │ ├── postprocess_test.py │ ├── segmentation.py │ ├── tfmot.py │ ├── train.py │ ├── train_lib.py │ ├── train_lib_test.py │ ├── tutorial.ipynb │ ├── util_keras.py │ ├── util_keras_test.py │ ├── wbf.py │ └── wbf_test.py ├── tutorial.ipynb ├── utils.py ├── utils_test.py └── visualize │ ├── __init__.py │ ├── shape_utils.py │ ├── standard_fields.py │ ├── static_shape.py │ ├── vis_utils.py │ └── vis_utils_test.py ├── efficientnetv2 ├── README.md ├── autoaugment.py ├── autoaugment_test.py ├── cflags.py ├── datasets.py ├── datasets_test.py ├── effnetv2_configs.py ├── effnetv2_configs_test.py ├── effnetv2_model.py ├── effnetv2_model_test.py ├── g3doc │ ├── effnetv2-l-gpu.png │ ├── effnetv2-m-gpu.png │ ├── effnetv2-s-gpu.png │ ├── effnetv2-s-relu6-gpu.png │ ├── imagenet1k_labels.txt │ ├── imagenet21k_labels.txt │ ├── param_flops.png │ └── train_params.png ├── hparams.py ├── infer.py ├── main.py ├── main_tf2.py ├── mlir.py ├── preprocess_legacy.py ├── preprocessing.py ├── preprocessing_test.py ├── smoke_test.py ├── tfhub.ipynb ├── tutorial.ipynb ├── utils.py └── utils_test.py ├── hero ├── LICENSE ├── config_lib.py ├── core.py ├── core_test.py ├── data_lib.py ├── example_commands.sh ├── fn_lib.py ├── main.py ├── model_lib.py ├── requirements.txt ├── vb100864_openmix_v1.model └── vb32000_t5_cc.model └── lion ├── README.md ├── fig ├── ablation.png ├── alg.png ├── basic.png ├── diffusion.png 
├── ft.png ├── i1k.png ├── imagen.png ├── jft-ft.png ├── jft.png ├── lit.png ├── llm.png ├── lm.png └── retrieval.png ├── lion_optax.py ├── lion_pytorch.py ├── lion_tf1.py └── lion_tf2.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Efficientdet 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-18.04 16 | strategy: 17 | matrix: 18 | python-version: [3.7, 3.8, 3.9] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install pylint 29 | bash efficientdet/install_deps.sh 30 | - name: Test with pytest 31 | run: | 32 | bash efficientdet/test.sh 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | .vscode 4 | tmp 5 | .ropeproject 6 | .pyc 7 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Brain AutoML 2 | 3 | This repository contains a list of AutoML related models and libraries. 
4 | -------------------------------------------------------------------------------- /efficientdet/Det-AdvProp.md: -------------------------------------------------------------------------------- 1 | # Det-AdvProp 2 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google/automl/blob/master/efficientdet/det_advprop_tutorial.ipynb) 3 | 4 | [1] Xiangning Chen, Cihang Xie, Mingxing Tan, Li Zhang, Cho-Jui Hsieh, Boqing 5 | Gong. CVPR 2021. Arxiv link: https://arxiv.org/abs/2103.13886 6 | 7 | Det-AdvProp is a data augmentation technique specifically designed for the 8 | fine-tuning process of object detectors. It can consistently and substantially 9 | outperform the vanilla training and AutoAugment under various settings. The 10 | obtained detector is not only more accurate on clean images, but also more 11 | robust to image distortions and domain shift. 12 | 13 |

14 | [figure: ./g3doc/Det-AdvProp.png] 15 |

16 | 17 | ## 1. Accurate on Clean Images 18 | 19 | The following table includes a list of models trained with Det-AdvProp + 20 | AutoAugment (AA): 21 | 22 | Model | APtest | AP50 | AP75 | APS | APM | APL | APval | | #params | #FLOPs 23 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------- | --------------- | --------------- | -------------- | -------------- | -------------- | ---------------- | --- | ------- | :----: 24 | EfficientDet-D0 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d0.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d0.txt)) | 35.3 | 54.1 | 37.8 | 12.7 | 39.9 | 53.2 | 35.1 | | 3.9M | 2.54B 25 | EfficientDet-D1 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d1.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d1.txt)) | 40.9 | 60.0 | 44.1 | 19.1 | 45.6 | 57.2 | 40.8 | | 6.6M | 6.10B 26 | EfficientDet-D2 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d2.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d2.txt)) | 44.3 | 63.5 | 47.9 | 23.5 | 48.5 | 59.9 | 44.3 | | 8.1M | 11.0B 27 | EfficientDet-D3 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d3.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d3.txt)) | 48.0 | 67.1 | 52.2 | 28.1 | 51.8 | 62.8 | 47.7 | | 12.0M | 24.9B 28 | EfficientDet-D4 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d4.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d4.txt)) | 50.4 | 69.5 | 54.9 | 30.9 | 54.3 | 64.4 | 50.4 | | 20.7M | 55.2B 29 | EfficientDet-D5 + Det-AdvProp + AA ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/efficientdet-d5.tar.gz), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/d5.txt)) | 52.5 | 71.8 | 57.2 | 34.6 | 55.9 | 65.2 | 52.2 | | 33.7M | 130B 30 | 31 | Unlike the vanilla EfficientDet that scales the image with mean and std, 32 | here we scale the input to the range of [-1, 1] to make it easier for performing 33 | adversarial attack. Please see [this Colab](https://github.com/google/automl/blob/master/efficientdet/det_advprop_tutorial.ipynb) for reproducing the 34 | results. 35 | 36 | ## 2. Robust Against Common Corruptions 37 | 38 | We test the detectors' robustness against common image corruptions (e.g., 39 | Gaussian Noise, Snow, etc.) based on the COCO-C dataset in 40 | [this paper](https://arxiv.org/abs/1907.07484). 
The table below shows the 41 | comparison between vanilla training and Det-AdvProp + AutoAugment (AA): 42 | 43 | Model | mAP 44 | ---------------------- | --------------- 45 | EfficientDet-D0 | 21.4 46 | **+ Det-AdvProp + AA** | **22.7 (+1.3)** 47 | EfficientDet-D1 | 24.4 48 | **+ Det-AdvProp + AA** | **26.7 (+2.3)** 49 | EfficientDet-D2 | 26.7 50 | **+ Det-AdvProp + AA** | **28.9 (+2.2)** 51 | EfficientDet-D3 | 28.8 52 | **+ Det-AdvProp + AA** | **32.0 (+3.2)** 53 | EfficientDet-D4 | 30.1 54 | **+ Det-AdvProp + AA** | **33.9 (+3.8)** 55 | EfficientDet-D5 | 31.4 56 | **+ Det-AdvProp + AA** | **35.0 (+3.6)** 57 | 58 | ## 3. Robust Against Domain Shift 59 | 60 | PASCAL VOC 2012 only contains 20 classes, which are much smaller than the 80 61 | labeled classes in COCO. The underlying distributions of the two datasets are 62 | also different in the image content or the bounding box sizes and locations. We 63 | use the trained detectors to run inference directly on the VOC dataset to test 64 | their transferibility. We maintain the COCO evaluation metrics in this 65 | experiment: 66 | 67 | Model | mAP | AP50 | AP75 68 | ---------------------- | --------------- | --------------- | --------------- 69 | EfficientDet-D0 | 55.6 | 77.6 | 61.4 70 | **+ Det-AdvProp + AA** | **56.2 (+0.6)** | **78.3 (+0.7)** | **62.3 (+0.9)** 71 | EfficientDet-D1 | 60.8 | 82.0 | 66.7 72 | **+ Det-AdvProp + AA** | **61.3 (+0.5)** | **82.5 (+0.5)** | **67.6 (+0.9)** 73 | EfficientDet-D2 | 63.3 | 83.6 | 69.3 74 | **+ Det-AdvProp + AA** | **63.6 (+0.3)** | **84.0 (+0.4)** | **70.0 (+0.7)** 75 | EfficientDet-D3 | 65.7 | 85.3 | 71.8 76 | **+ Det-AdvProp + AA** | **66.4 (+0.7)** | **85.9 (+0.6)** | **72.8 (+1.0)** 77 | EfficientDet-D4 | 67.0 | 86.0 | 73.0 78 | **+ Det-AdvProp + AA** | **67.8 (+0.8)** | **87.0 (+1.0)** | **74.3 (+1.3)** 79 | EfficientDet-D5 | 67.4 | 86.9 | 73.8 80 | **+ Det-AdvProp + AA** | **68.7 (+1.3)** | **88.0 (+1.1)** | **75.4 (+1.6)** 81 | -------------------------------------------------------------------------------- /efficientdet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /efficientdet/aug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /efficientdet/aug/autoaugment_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for Autoaugment.""" 16 | from absl import logging 17 | import tensorflow.compat.v1 as tf 18 | 19 | from aug import autoaugment 20 | 21 | 22 | class AutoaugmentTest(tf.test.TestCase): 23 | 24 | def test_autoaugment_policy(self): 25 | # A very simple test to verify no syntax error. 26 | image = tf.placeholder(tf.uint8, shape=[640, 640, 3]) 27 | bboxes = tf.placeholder(tf.float32, shape=[4, 4]) 28 | autoaugment.distort_image_with_autoaugment(image, bboxes, 'test') 29 | 30 | def test_randaugment_policy(self): 31 | image = tf.placeholder(tf.uint8, shape=[320, 320, 3]) 32 | bboxes = tf.placeholder(tf.float32, shape=[4, 4]) 33 | autoaugment.distort_image_with_randaugment(image, bboxes, 1, 15) 34 | 35 | 36 | if __name__ == '__main__': 37 | logging.set_verbosity(logging.WARNING) 38 | tf.disable_eager_execution() 39 | tf.test.main() 40 | 41 | -------------------------------------------------------------------------------- /efficientdet/aug/gridmask.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Grid Masking Augmentation Reference: https://arxiv.org/abs/2001.04086.""" 16 | import math 17 | 18 | import tensorflow as tf 19 | from tensorflow_addons import image as image_ops 20 | 21 | 22 | class GridMask(object): 23 | """GridMask class for grid masking augmentation.""" 24 | 25 | def __init__(self, 26 | prob=0.6, 27 | ratio=0.6, 28 | rotate=10, 29 | gridmask_size_ratio=0.5, 30 | fill=1, 31 | interpolation="BILINEAR"): 32 | """initialization. 33 | 34 | Args: 35 | prob: probablity of occurance. 36 | ratio: grid mask ratio i.e if 0.5 grid and spacing will be equal. 37 | rotate: Rotation of grid mesh. 38 | gridmask_size_ratio: Grid mask size, grid to image size ratio. 39 | fill: Fill value for grids. 40 | interpolation: Interpolation method for rotation. 41 | """ 42 | self.prob = prob 43 | self.ratio = ratio 44 | self.rotate = rotate 45 | self.gridmask_size_ratio = gridmask_size_ratio 46 | self.fill = fill 47 | self.interpolation = interpolation 48 | 49 | @tf.function 50 | def random_rotate(self, mask): 51 | """Randomly rotates mask on given range.""" 52 | 53 | angle = self.rotate * tf.random.normal([], -1, 1) 54 | angle = math.pi * angle / 180 55 | return image_ops.rotate(mask, angle, interpolation=self.interpolation) 56 | 57 | @staticmethod 58 | def crop(mask, h, w): 59 | """crops in middle of mask and image corners.""" 60 | ww = hh = tf.shape(mask)[0] 61 | mask = mask[(hh - h) // 2:(hh - h) // 2 + h, 62 | (ww - w) // 2:(ww - w) // 2 + w,] 63 | return mask 64 | 65 | @tf.function 66 | def mask(self, h, w): 67 | """mask helper function for initializing grid mask of required size.""" 68 | h = tf.cast(h, tf.float32) 69 | w = tf.cast(w, tf.float32) 70 | mask_w = mask_h = tf.cast( 71 | tf.cast( 72 | (self.gridmask_size_ratio + 1), tf.float32) * tf.math.maximum(h, w), 73 | tf.int32) 74 | self.mask_w = mask_w 75 | mask = tf.zeros(shape=[mask_h, mask_w], dtype=tf.int32) 76 | gridblock = tf.random.uniform( 77 | shape=[], 78 | minval=int(tf.math.minimum(h * 0.5, w * 0.3)), 79 | maxval=int(tf.math.maximum(h * 0.5, w * 0.3)) + 1, 80 | dtype=tf.int32) 81 | 82 | if self.ratio == 1: 83 | length = tf.random.uniform( 84 | shape=[], minval=1, maxval=gridblock + 1, dtype=tf.int32) 85 | else: 86 | length = tf.cast( 87 | tf.math.minimum( 88 | tf.math.maximum( 89 | int(tf.cast(gridblock, tf.float32) * self.ratio + 0.5), 1), 90 | gridblock - 1), tf.int32) 91 | 92 | for _ in range(2): 93 | start_w = tf.random.uniform( 94 | shape=[], minval=0, maxval=gridblock + 1, dtype=tf.int32) 95 | for i in range(mask_w // gridblock): 96 | start = gridblock * i + start_w 97 | end = tf.math.minimum(start + length, mask_w) 98 | indices = tf.reshape(tf.range(start, end), [end - start, 1]) 99 | updates = ( 100 | tf.ones(shape=[end - start, mask_w], dtype=tf.int32) * self.fill) 101 | mask = tf.tensor_scatter_nd_update(mask, indices, updates) 102 | mask = tf.transpose(mask) 103 | 104 | return mask 105 | 106 | def __call__(self, image, label): 107 | """Masks input image tensor with random grid mask.""" 108 | h = tf.shape(image)[0] 109 | w = tf.shape(image)[1] 110 | grid = self.mask(h, w) 111 | grid = self.random_rotate(grid) 112 | mask = self.crop(grid, h, w) 113 | mask = tf.cast(mask, image.dtype) 114 | mask = tf.reshape(mask, (h, w)) 115 | mask = (tf.expand_dims(mask, -1) if image._rank() != mask._rank() else mask) 116 | occur = tf.random.normal([], 0, 1) < self.prob 117 | image = tf.cond(occur, lambda: image * mask, lambda: image) 118 | 
return image, label 119 | 120 | 121 | def gridmask(image, 122 | boxes, 123 | prob=0.5, 124 | ratio=0.6, 125 | rotate=10, 126 | gridmask_size_ratio=0.5, 127 | fill=1): 128 | """Callable instance of GridMask and transforms input image.""" 129 | gridmask_obj = GridMask( 130 | prob=prob, 131 | ratio=ratio, 132 | rotate=rotate, 133 | gridmask_size_ratio=gridmask_size_ratio, 134 | fill=fill) 135 | image, boxes = gridmask_obj(image, boxes) 136 | return image, boxes 137 | -------------------------------------------------------------------------------- /efficientdet/aug/gridmask_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """GridMask Augmentation simple test.""" 16 | from absl import logging 17 | import tensorflow.compat.v1 as tf 18 | 19 | from aug import gridmask 20 | 21 | 22 | class GridMaskTest(tf.test.TestCase): 23 | 24 | def setUp(self): 25 | super().setUp() 26 | tf.random.set_random_seed(111111) 27 | 28 | def test_gridmask_images(self): 29 | """Verify transformed image shape is valid and syntax check.""" 30 | images = tf.random.uniform( 31 | shape=(512, 512, 3), minval=0, maxval=255, dtype=tf.float32) 32 | bboxes = tf.random.uniform( 33 | shape=(2, 4), minval=1, maxval=511, dtype=tf.int32) 34 | transform_images, _ = gridmask.gridmask(images, bboxes) 35 | self.assertEqual(images.shape[1], transform_images.shape[1]) 36 | 37 | def test_gridmask_tiny_images(self): 38 | """Verify transform image shape on very tiny image.""" 39 | images = tf.zeros(shape=(4, 4, 3)) 40 | bboxes = tf.random.uniform( 41 | shape=(2, 4), minval=1, maxval=511, dtype=tf.int32) 42 | transform_images, _ = gridmask.gridmask(images, bboxes) 43 | self.assertEqual(images.shape[1], transform_images.shape[1]) 44 | 45 | def test_rectangle_image_shape(self): 46 | """Verify transform image shape on rectangle image.""" 47 | images = tf.zeros(shape=(1028, 512, 3)) 48 | bboxes = tf.random.uniform( 49 | shape=(2, 4), minval=1, maxval=511, dtype=tf.int32) 50 | transform_images, _ = gridmask.gridmask(images, bboxes) 51 | self.assertEqual(images.shape[1], transform_images.shape[1]) 52 | 53 | 54 | if __name__ == "__main__": 55 | logging.set_verbosity(logging.WARNING) 56 | tf.disable_eager_execution() 57 | tf.test.main() 58 | -------------------------------------------------------------------------------- /efficientdet/aug/mosaic_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Mosaic Augmentation simple test.""" 16 | from absl import logging 17 | import tensorflow.compat.v1 as tf 18 | 19 | from aug import mosaic 20 | 21 | 22 | class MosaicTest(tf.test.TestCase): 23 | 24 | def __init__(self, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | self.output_size = (512, 512) 27 | self.mosaic = mosaic.Mosaic(out_size=self.output_size) 28 | tf.random.set_random_seed(111111) 29 | 30 | def test_mosaic_boxes(self): 31 | """Verify num of boxes are valid and syntax check random four images.""" 32 | images = tf.random.uniform( 33 | shape=(4, 512, 512, 3), minval=0, maxval=255, dtype=tf.float32) 34 | bboxes = tf.random.uniform( 35 | shape=(4, 2, 4), minval=1, maxval=511, dtype=tf.int32) 36 | _, mosaic_boxes = self.mosaic(images, bboxes) 37 | self.assertEqual(bboxes.shape[0], len(mosaic_boxes)) 38 | 39 | def test_mosaic_tiny_images(self): 40 | images = tf.zeros(shape=(4, 4, 4, 3)) 41 | bboxes = tf.random.uniform( 42 | shape=(4, 2, 4), minval=1, maxval=511, dtype=tf.int32) 43 | _, mosaic_boxes = self.mosaic(images, bboxes) 44 | self.assertEqual(bboxes.shape[0], len(mosaic_boxes)) 45 | 46 | 47 | if __name__ == "__main__": 48 | logging.set_verbosity(logging.WARNING) 49 | tf.disable_eager_execution() 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /efficientdet/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /efficientdet/backbone/backbone_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Backbone network factory.""" 16 | import os 17 | from absl import logging 18 | import tensorflow as tf 19 | 20 | from backbone import efficientnet_builder 21 | from backbone import efficientnet_lite_builder 22 | from backbone import efficientnet_model 23 | 24 | 25 | def get_model_builder(model_name): 26 | """Get the model_builder module for a given model name.""" 27 | if model_name.startswith('efficientnet-lite'): 28 | return efficientnet_lite_builder 29 | elif model_name.startswith('efficientnet-'): 30 | return efficientnet_builder 31 | else: 32 | raise ValueError('Unknown model name {}'.format(model_name)) 33 | 34 | 35 | def get_model(model_name, override_params=None, model_dir=None): 36 | """A helper function to create and return model. 37 | 38 | Args: 39 | model_name: string, the predefined model name. 40 | override_params: A dictionary of params for overriding. Fields must exist in 41 | efficientnet_model.GlobalParams. 42 | model_dir: string, optional model dir for saving configs. 43 | 44 | Returns: 45 | created model 46 | 47 | Raises: 48 | When model_name specified an undefined model, raises NotImplementedError. 49 | When override_params has invalid fields, raises ValueError. 50 | """ 51 | 52 | # For backward compatibility. 53 | if override_params and override_params.get('drop_connect_rate', None): 54 | override_params['survival_prob'] = 1 - override_params['drop_connect_rate'] 55 | 56 | if not override_params: 57 | override_params = {} 58 | 59 | if model_name.startswith('efficientnet-lite'): 60 | builder = efficientnet_lite_builder 61 | elif model_name.startswith('efficientnet-'): 62 | builder = efficientnet_builder 63 | else: 64 | raise ValueError('Unknown model name {}'.format(model_name)) 65 | 66 | blocks_args, global_params = builder.get_model_params(model_name, 67 | override_params) 68 | 69 | if model_dir: 70 | param_file = os.path.join(model_dir, 'model_params.txt') 71 | if not tf.io.gfile.exists(param_file): 72 | if not tf.io.gfile.exists(model_dir): 73 | tf.io.gfile.mkdir(model_dir) 74 | with tf.io.gfile.GFile(param_file, 'w') as f: 75 | logging.info('writing to %s', param_file) 76 | f.write('model_name= %s\n\n' % model_name) 77 | f.write('global_params= %s\n\n' % str(global_params)) 78 | f.write('blocks_args= %s\n\n' % str(blocks_args)) 79 | 80 | return efficientnet_model.Model(blocks_args, global_params, model_name) 81 | -------------------------------------------------------------------------------- /efficientdet/backbone/efficientnet_builder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for efficientnet_builder.""" 16 | from absl import logging 17 | import numpy as np 18 | import tensorflow.compat.v1 as tf 19 | 20 | from backbone import efficientnet_builder 21 | 22 | 23 | class EfficientnetBuilderTest(tf.test.TestCase): 24 | 25 | def _test_model_params(self, 26 | model_name, 27 | input_size, 28 | expected_params, 29 | override_params=None, 30 | features_only=False, 31 | pooled_features_only=False): 32 | images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32) 33 | efficientnet_builder.build_model( 34 | images, 35 | model_name=model_name, 36 | override_params=override_params, 37 | training=False, 38 | features_only=features_only, 39 | pooled_features_only=pooled_features_only) 40 | num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 41 | self.assertEqual(num_params, expected_params) 42 | 43 | def test_efficientnet_b0(self): 44 | self._test_model_params('efficientnet-b0', 224, expected_params=5288548) 45 | 46 | def test_efficientnet_b1(self): 47 | self._test_model_params('efficientnet-b1', 240, expected_params=7794184) 48 | 49 | def test_efficientnet_b2(self): 50 | self._test_model_params('efficientnet-b2', 260, expected_params=9109994) 51 | 52 | def test_efficientnet_b3(self): 53 | self._test_model_params('efficientnet-b3', 300, expected_params=12233232) 54 | 55 | def test_efficientnet_b4(self): 56 | self._test_model_params('efficientnet-b4', 380, expected_params=19341616) 57 | 58 | def test_efficientnet_b5(self): 59 | self._test_model_params('efficientnet-b5', 456, expected_params=30389784) 60 | 61 | def test_efficientnet_b6(self): 62 | self._test_model_params('efficientnet-b6', 528, expected_params=43040704) 63 | 64 | def test_efficientnet_b7(self): 65 | self._test_model_params('efficientnet-b7', 600, expected_params=66347960) 66 | 67 | def test_efficientnet_b0_with_customized_num_classes(self): 68 | self._test_model_params( 69 | 'efficientnet-b0', 70 | 224, 71 | expected_params=4135648, 72 | override_params={'num_classes': 100}) 73 | 74 | def test_efficientnet_b0_with_features_only(self): 75 | self._test_model_params( 76 | 'efficientnet-b0', 224, features_only=True, expected_params=3595388) 77 | 78 | def test_efficientnet_b0_with_pooled_features_only(self): 79 | self._test_model_params( 80 | 'efficientnet-b0', 81 | 224, 82 | pooled_features_only=True, 83 | expected_params=4007548) 84 | 85 | def test_efficientnet_b0_fails_if_both_features_requested(self): 86 | with self.assertRaises(AssertionError): 87 | efficientnet_builder.build_model( 88 | None, 89 | model_name='efficientnet-b0', 90 | training=False, 91 | features_only=True, 92 | pooled_features_only=True) 93 | 94 | def test_efficientnet_b0_base(self): 95 | # Creates a base model using the model configuration. 96 | images = tf.zeros((1, 224, 224, 3), dtype=tf.float32) 97 | _, endpoints = efficientnet_builder.build_model_base( 98 | images, model_name='efficientnet-b0', training=False) 99 | 100 | # reduction_1 to reduction_5 should be in endpoints 101 | self.assertEqual(len(endpoints), 5) 102 | 103 | 104 | if __name__ == '__main__': 105 | logging.set_verbosity(logging.WARNING) 106 | # Disable eager to allow tf.profile works for #params/#flops. 
107 | tf.disable_eager_execution() 108 | tf.test.main() 109 | -------------------------------------------------------------------------------- /efficientdet/backbone/efficientnet_lite_builder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for efficientnet_lite_builder.""" 16 | from absl import logging 17 | import numpy as np 18 | import tensorflow.compat.v1 as tf 19 | 20 | from backbone import efficientnet_lite_builder 21 | 22 | 23 | class EfficientnetBuilderTest(tf.test.TestCase): 24 | 25 | def _test_model_params(self, 26 | model_name, 27 | input_size, 28 | expected_params, 29 | override_params=None, 30 | features_only=False, 31 | pooled_features_only=False): 32 | images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32) 33 | efficientnet_lite_builder.build_model( 34 | images, 35 | model_name=model_name, 36 | override_params=override_params, 37 | training=False, 38 | features_only=features_only, 39 | pooled_features_only=pooled_features_only) 40 | num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 41 | 42 | self.assertEqual(num_params, expected_params) 43 | 44 | def test_efficientnet_b0(self): 45 | self._test_model_params( 46 | 'efficientnet-lite0', 224, expected_params=4652008) 47 | 48 | def test_efficientnet_b1(self): 49 | self._test_model_params( 50 | 'efficientnet-lite1', 240, expected_params=5416680) 51 | 52 | def test_efficientnet_b2(self): 53 | self._test_model_params( 54 | 'efficientnet-lite2', 260, expected_params=6092072) 55 | 56 | def test_efficientnet_b3(self): 57 | self._test_model_params( 58 | 'efficientnet-lite3', 280, expected_params=8197096) 59 | 60 | def test_efficientnet_b4(self): 61 | self._test_model_params( 62 | 'efficientnet-lite4', 300, expected_params=13006568) 63 | 64 | 65 | if __name__ == '__main__': 66 | logging.set_verbosity(logging.WARNING) 67 | # Disable eager to allow tf.profile works for #params/#flops. 68 | tf.disable_eager_execution() 69 | tf.test.main() 70 | -------------------------------------------------------------------------------- /efficientdet/coco_metric_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for coco_metric.""" 16 | 17 | from absl import logging 18 | import tensorflow.compat.v1 as tf 19 | import coco_metric 20 | 21 | 22 | class CocoMetricTest(tf.test.TestCase): 23 | 24 | def setUp(self): 25 | super(CocoMetricTest, self).setUp() 26 | # [y1, x1, y2, x2, is_crowd, area, class], in image coords. 27 | self.groundtruth_data = tf.constant([[ 28 | [10.0, 10.0, 20.0, 20.0, 0.0, 100.0, 1], 29 | [10.0, 10.0, 30.0, 15.0, 0.0, 100.0, 2], 30 | [30.0, 30.0, 40.0, 50.0, 0.0, 100.0, 3] 31 | ]], dtype=tf.float32) 32 | # [image_id, x, y, width, height, score, class] 33 | self.detections = tf.constant([[ 34 | [1.0, 10.0, 10.0, 10.0, 10.0, 0.6, 1], 35 | [1.0, 10.0, 10.0, 5.0, 20.0, 0.5, 2] 36 | ]], dtype=tf.float32) 37 | self.class_labels = {1: 'car', 2: 'truck', 3: 'bicycle'} 38 | 39 | def test_mAP(self): 40 | 41 | eval_metric = coco_metric.EvaluationMetric(label_map=self.class_labels) 42 | coco_metrics = eval_metric.estimator_metric_fn(self.detections, 43 | self.groundtruth_data) 44 | self.assertEqual(len(coco_metrics.keys()), 15) 45 | self.assertAllClose(coco_metrics['AP'][0], 2.0/3.0) 46 | self.assertAllClose(coco_metrics['AP_/car'][0], 1.0) 47 | self.assertAllClose(coco_metrics['AP_/truck'][0], 1.0) 48 | self.assertAllClose(coco_metrics['AP_/bicycle'][0], 0.0) 49 | 50 | 51 | if __name__ == '__main__': 52 | logging.set_verbosity(logging.WARNING) 53 | tf.test.main() 54 | -------------------------------------------------------------------------------- /efficientdet/dataloader_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Data loader and processing test cases.""" 16 | 17 | import tensorflow as tf 18 | 19 | import dataloader 20 | import hparams_config 21 | import test_util 22 | 23 | from tf2 import anchors 24 | from object_detection import tf_example_decoder 25 | 26 | 27 | class DataloaderTest(tf.test.TestCase): 28 | 29 | def test_parser(self): 30 | tf.random.set_seed(111111) 31 | params = hparams_config.get_detection_config('efficientdet-d0').as_dict() 32 | input_anchors = anchors.Anchors(params['min_level'], params['max_level'], 33 | params['num_scales'], 34 | params['aspect_ratios'], 35 | params['anchor_scale'], 36 | params['image_size']) 37 | anchor_labeler = anchors.AnchorLabeler(input_anchors, params['num_classes']) 38 | example_decoder = tf_example_decoder.TfExampleDecoder( 39 | regenerate_source_id=params['regenerate_source_id']) 40 | tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir()) 41 | dataset = tf.data.TFRecordDataset([tfrecord_path]) 42 | value = next(iter(dataset)) 43 | reader = dataloader.InputReader(tfrecord_path, True) 44 | result = reader.dataset_parser(value, example_decoder, anchor_labeler, 45 | params) 46 | self.assertEqual(len(result), 11) 47 | 48 | 49 | if __name__ == '__main__': 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /efficientdet/dataset/README.md: -------------------------------------------------------------------------------- 1 | This folder provides tools for converting raw coco/pascal data to tfrecord. 2 | 3 | ### 1. Convert COCO validation set to tfrecord: 4 | 5 | # Download coco data. 6 | !wget http://images.cocodataset.org/zips/val2017.zip 7 | !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 8 | !unzip val2017.zip 9 | !unzip annotations_trainval2017.zip 10 | 11 | # convert coco data to tfrecord. 12 | !mkdir tfrecord 13 | !PYTHONPATH=".:$PYTHONPATH" python dataset/create_coco_tfrecord.py \ 14 | --image_dir=val2017 \ 15 | --object_annotations_file=annotations/instances_val2017.json \ 16 | --output_file_prefix=tfrecord/val \ 17 | --num_shards=32 18 | 19 | ### 2. Convert Pascal VOC 2012 to tfrecord: 20 | 21 | # Download and convert pascal data. 22 | !wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 23 | !tar xf VOCtrainval_11-May-2012.tar 24 | !mkdir tfrecord 25 | !PYTHONPATH=".:$PYTHONPATH" python dataset/create_pascal_tfrecord.py \ 26 | --data_dir=VOCdevkit --year=VOC2012 --output_path=tfrecord/pascal 27 | 28 | Attention: soure_id (or image_id) needs to be an integer due to the official COCO library requreiments. 29 | -------------------------------------------------------------------------------- /efficientdet/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # This library is mostly based on tensorflow object detection API 16 | # https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_coco_tf_record.py 17 | -------------------------------------------------------------------------------- /efficientdet/dataset/create_pascal_tfrecord_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Test for create_pascal_tfrecord.py.""" 16 | 17 | import os 18 | 19 | from absl import logging 20 | import numpy as np 21 | import PIL.Image 22 | import six 23 | import tensorflow as tf 24 | 25 | from dataset import create_pascal_tfrecord 26 | 27 | 28 | class CreatePascalTFRecordTest(tf.test.TestCase): 29 | 30 | def _assertProtoEqual(self, proto_field, expectation): 31 | """Helper function to assert if a proto field equals some value. 32 | 33 | Args: 34 | proto_field: The protobuf field to compare. 35 | expectation: The expected value of the protobuf field. 
36 | """ 37 | proto_list = [p for p in proto_field] 38 | self.assertListEqual(proto_list, expectation) 39 | 40 | def test_dict_to_tf_example(self): 41 | image_file_name = '2012_12.jpg' 42 | image_data = np.random.rand(256, 256, 3) 43 | save_path = os.path.join(self.get_temp_dir(), image_file_name) 44 | image = PIL.Image.fromarray(image_data, 'RGB') 45 | image.save(save_path) 46 | 47 | data = { 48 | 'folder': 49 | '', 50 | 'filename': 51 | image_file_name, 52 | 'size': { 53 | 'height': 256, 54 | 'width': 256, 55 | }, 56 | 'object': [ 57 | { 58 | 'difficult': 1, 59 | 'bndbox': { 60 | 'xmin': 64, 61 | 'ymin': 64, 62 | 'xmax': 192, 63 | 'ymax': 192, 64 | }, 65 | 'name': 'person', 66 | 'truncated': 0, 67 | 'pose': '', 68 | }, 69 | { 70 | 'difficult': 0, 71 | 'bndbox': { 72 | 'xmin': 128, 73 | 'ymin': 128, 74 | 'xmax': 256, 75 | 'ymax': 256, 76 | }, 77 | 'name': 'notperson', 78 | 'truncated': 0, 79 | 'pose': '', 80 | }, 81 | ], 82 | } 83 | 84 | label_map_dict = { 85 | 'background': 0, 86 | 'person': 1, 87 | 'notperson': 2, 88 | } 89 | 90 | ann_json_dict = {'images': [], 'annotations': [], 'categories': []} 91 | unique_id = create_pascal_tfrecord.UniqueId() 92 | example = create_pascal_tfrecord.dict_to_tf_example( 93 | data, 94 | self.get_temp_dir(), 95 | label_map_dict, 96 | unique_id, 97 | ann_json_dict=ann_json_dict) 98 | self.assertEqual(unique_id.image_id, 1) 99 | self.assertEqual(unique_id.ann_id, 2) 100 | 101 | self._assertProtoEqual( 102 | example.features.feature['image/height'].int64_list.value, [256]) 103 | self._assertProtoEqual( 104 | example.features.feature['image/width'].int64_list.value, [256]) 105 | self._assertProtoEqual( 106 | example.features.feature['image/filename'].bytes_list.value, 107 | [six.b(image_file_name)]) 108 | self._assertProtoEqual( 109 | example.features.feature['image/source_id'].bytes_list.value, 110 | [six.b(str(1))]) 111 | self._assertProtoEqual( 112 | example.features.feature['image/format'].bytes_list.value, 113 | [six.b('jpeg')]) 114 | self._assertProtoEqual( 115 | example.features.feature['image/object/bbox/xmin'].float_list.value, 116 | [0.25, 0.5]) 117 | self._assertProtoEqual( 118 | example.features.feature['image/object/bbox/ymin'].float_list.value, 119 | [0.25, 0.5]) 120 | self._assertProtoEqual( 121 | example.features.feature['image/object/bbox/xmax'].float_list.value, 122 | [0.75, 1.0]) 123 | self._assertProtoEqual( 124 | example.features.feature['image/object/bbox/ymax'].float_list.value, 125 | [0.75, 1.0]) 126 | self._assertProtoEqual( 127 | example.features.feature['image/object/class/text'].bytes_list.value, 128 | [six.b('person'), six.b('notperson')]) 129 | self._assertProtoEqual( 130 | example.features.feature['image/object/class/label'].int64_list.value, 131 | [1, 2]) 132 | self._assertProtoEqual( 133 | example.features.feature['image/object/difficult'].int64_list.value, 134 | [1, 0]) 135 | self._assertProtoEqual( 136 | example.features.feature['image/object/truncated'].int64_list.value, 137 | [0, 0]) 138 | self._assertProtoEqual( 139 | example.features.feature['image/object/view'].bytes_list.value, 140 | [six.b(''), six.b('')]) 141 | 142 | expected_ann_json_dict = { 143 | 'annotations': [{ 144 | 'area': 16384, 145 | 'iscrowd': 0, 146 | 'image_id': 1, 147 | 'bbox': [64, 64, 128, 128], 148 | 'category_id': 1, 149 | 'id': 1, 150 | 'ignore': 0, 151 | 'segmentation': [] 152 | }, { 153 | 'area': 16384, 154 | 'iscrowd': 0, 155 | 'image_id': 1, 156 | 'bbox': [128, 128, 128, 128], 157 | 'category_id': 2, 158 | 'id': 2, 159 | 'ignore': 0, 160 | 
'segmentation': [] 161 | }], 162 | 'categories': [], 163 | 'images': [{ 164 | 'file_name': '2012_12.jpg', 165 | 'height': 256, 166 | 'width': 256, 167 | 'id': 1 168 | }] 169 | } 170 | self.assertEqual(ann_json_dict, expected_ann_json_dict) 171 | 172 | 173 | if __name__ == '__main__': 174 | logging.set_verbosity(logging.WARNING) 175 | tf.test.main() 176 | -------------------------------------------------------------------------------- /efficientdet/dataset/inspect_tfrecords.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Inspect tfrecord dataset.""" 16 | import os 17 | from absl import app 18 | from absl import flags 19 | from absl import logging 20 | import numpy as np 21 | from PIL import Image 22 | import tensorflow as tf 23 | 24 | import dataloader 25 | import hparams_config 26 | import utils 27 | from visualize import vis_utils 28 | 29 | flags.DEFINE_string('save_samples_dir', 'tfrecord_samples', 30 | 'Location of samples to save') 31 | flags.DEFINE_string('model_name', 'efficientdet-d0', 32 | 'model name for config and image_size') 33 | flags.DEFINE_string( 34 | 'hparams', '', 'Comma separated k=v pairs of hyperparameters or a module' 35 | ' containing attributes to use as hyperparameters.') 36 | flags.DEFINE_integer('samples', 10, 37 | 'Number of random samples for visualization.') 38 | flags.DEFINE_string('file_pattern', None, 39 | 'Glob for data files (e.g., COCO train - minival set)') 40 | flags.DEFINE_bool('eval', False, 'flag for file pattern mode i.e eval') 41 | FLAGS = flags.FLAGS 42 | 43 | 44 | class RecordInspect: 45 | """Inspection Class.""" 46 | 47 | def __init__(self, config): 48 | """Initializes RecordInspect with passed config. 49 | 50 | Args: 51 | config: config file to initialize input_fn. 52 | """ 53 | self.input_fn = dataloader.InputReader( 54 | FLAGS.file_pattern, 55 | is_training=not FLAGS.eval, 56 | use_fake_data=False, 57 | max_instances_per_image=config.max_instances_per_image) 58 | 59 | self.params = dict( 60 | config.as_dict(), batch_size=FLAGS.samples, model_name=FLAGS.model_name) 61 | logging.info(self.params) 62 | self.cls_to_label = config.label_map 63 | os.makedirs(FLAGS.save_samples_dir, exist_ok=True) 64 | 65 | def visualize(self): 66 | """save tfrecords images with bounding boxes.""" 67 | vis_ds = self.input_fn(params=self.params) 68 | data = next(iter(vis_ds)) # iterable. 
69 | images = data[0] 70 | gt_data = data[1]['groundtruth_data'] 71 | 72 | # scales 73 | scale_to_org = data[1]['image_scales'] 74 | scales = 1.0 / scale_to_org 75 | offset = tf.constant([0.485, 0.456, 0.406]) 76 | offset = tf.reshape(offset, (1, 1, -1)) 77 | scale_image = tf.constant([0.229, 0.224, 0.225]) 78 | scale_image = tf.reshape(scale_image, (1, 1, -1)) 79 | 80 | logging.info('Visualizing TfRecords %s', FLAGS.file_pattern) 81 | for i, zip_data in enumerate(zip(gt_data, images, scales)): 82 | gt, image, scale = zip_data 83 | boxes = gt[:, :4] 84 | boxes = boxes[np.any(boxes > 0, axis=1)].numpy() 85 | if boxes.shape[0] > 0: 86 | classes = gt[:boxes.shape[0], -1].numpy() 87 | try: 88 | category_index = {idx: {'id': idx, 'name': self.cls_to_label[idx]} 89 | for idx in np.asarray(classes, dtype=np.int)} 90 | except Exception: # pylint: disable=broad-except 91 | category_index = {} 92 | 93 | # unnormalize image. 94 | image *= scale_image 95 | image += offset 96 | 97 | # 0-255. range 98 | image = np.asarray(image.numpy() * 255., dtype=np.uint8) 99 | 100 | # scale to image_size 101 | boxes *= scale.numpy() 102 | 103 | image = vis_utils.visualize_boxes_and_labels_on_image_array( 104 | image, 105 | boxes=boxes, 106 | classes=classes.astype(int), 107 | scores=np.ones(boxes.shape[0]), 108 | category_index=category_index, 109 | line_thickness=2, 110 | skip_scores=True) 111 | image = Image.fromarray(image) 112 | image.save(os.path.join(FLAGS.save_samples_dir, f'sample{i}.jpg')) 113 | 114 | 115 | def main(_): 116 | # Parse and override hparams 117 | config = hparams_config.get_detection_config(FLAGS.model_name) 118 | config.override(FLAGS.hparams) 119 | 120 | # Parse image size in case it is in string format. 121 | config.image_size = utils.parse_image_size(config.image_size) 122 | try: 123 | recordinspect = RecordInspect(config) 124 | recordinspect.visualize() 125 | except Exception as e: # pylint: disable=broad-except 126 | logging.error(e) 127 | else: 128 | logging.info('Done, please find samples at %s', FLAGS.save_samples_dir) 129 | 130 | 131 | if __name__ == '__main__': 132 | app.run(main) 133 | -------------------------------------------------------------------------------- /efficientdet/dataset/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Label map utility functions.""" 16 | from absl import logging 17 | 18 | 19 | def _validate_label_map(label_map): 20 | """Checks if a label map is valid. 21 | 22 | Args: 23 | label_map: StringIntLabelMap to validate. 24 | 25 | Raises: 26 | ValueError: if label map is invalid. 
27 | """ 28 | for item in label_map.item: 29 | if item.id < 0: 30 | raise ValueError('Label map ids should be >= 0.') 31 | if (item.id == 0 and item.name != 'background' and 32 | item.display_name != 'background'): 33 | raise ValueError('Label map id 0 is reserved for the background label') 34 | 35 | 36 | def create_category_index(categories): 37 | """Creates dictionary of COCO compatible categories keyed by category id. 38 | 39 | Args: 40 | categories: a list of dicts, each of which has the following keys: 41 | 'id': (required) an integer id uniquely identifying this category. 42 | 'name': (required) string representing category name 43 | e.g., 'cat', 'dog', 'pizza'. 44 | 45 | Returns: 46 | category_index: a dict containing the same entries as categories, but keyed 47 | by the 'id' field of each category. 48 | """ 49 | category_index = {} 50 | for cat in categories: 51 | category_index[cat['id']] = cat 52 | return category_index 53 | 54 | 55 | def get_max_label_map_index(label_map): 56 | """Get maximum index in label map. 57 | 58 | Args: 59 | label_map: a StringIntLabelMapProto 60 | 61 | Returns: 62 | an integer 63 | """ 64 | return max([item.id for item in label_map.item]) 65 | 66 | 67 | def convert_label_map_to_categories(label_map, 68 | max_num_classes, 69 | use_display_name=True): 70 | """Given label map proto returns categories list compatible with eval. 71 | 72 | This function converts label map proto and returns a list of dicts, each of 73 | which has the following keys: 74 | 'id': (required) an integer id uniquely identifying this category. 75 | 'name': (required) string representing category name 76 | e.g., 'cat', 'dog', 'pizza'. 77 | 'keypoints': (optional) a dictionary of keypoint string 'label' to integer 78 | 'id'. 79 | We only allow class into the list if its id-label_id_offset is 80 | between 0 (inclusive) and max_num_classes (exclusive). 81 | If there are several items mapping to the same id in the label map, 82 | we will only keep the first one in the categories list. 83 | 84 | Args: 85 | label_map: a StringIntLabelMapProto or None. If None, a default categories 86 | list is created with max_num_classes categories. 87 | max_num_classes: maximum number of (consecutive) label indices to include. 88 | use_display_name: (boolean) choose whether to load 'display_name' field as 89 | category name. If False or if the display_name field does not exist, uses 90 | 'name' field as category names instead. 91 | 92 | Returns: 93 | categories: a list of dictionaries representing all possible categories. 
94 | """ 95 | categories = [] 96 | list_of_ids_already_added = [] 97 | if not label_map: 98 | label_id_offset = 1 99 | for class_id in range(max_num_classes): 100 | categories.append({ 101 | 'id': class_id + label_id_offset, 102 | 'name': 'category_{}'.format(class_id + label_id_offset) 103 | }) 104 | return categories 105 | for item in label_map.item: 106 | if not 0 < item.id <= max_num_classes: 107 | logging.info( 108 | 'Ignore item %d since it falls outside of requested ' 109 | 'label range.', item.id) 110 | continue 111 | if use_display_name and item.HasField('display_name'): 112 | name = item.display_name 113 | else: 114 | name = item.name 115 | if item.id not in list_of_ids_already_added: 116 | list_of_ids_already_added.append(item.id) 117 | category = {'id': item.id, 'name': name} 118 | if item.keypoints: 119 | keypoints = {} 120 | list_of_keypoint_ids = [] 121 | for kv in item.keypoints: 122 | if kv.id in list_of_keypoint_ids: 123 | raise ValueError('Duplicate keypoint ids are not allowed. ' 124 | 'Found {} more than once'.format(kv.id)) 125 | keypoints[kv.label] = kv.id 126 | list_of_keypoint_ids.append(kv.id) 127 | category['keypoints'] = keypoints 128 | categories.append(category) 129 | return categories 130 | 131 | 132 | def create_class_agnostic_category_index(): 133 | """Creates a category index with a single `object` class.""" 134 | return {1: {'id': 1, 'name': 'object'}} 135 | -------------------------------------------------------------------------------- /efficientdet/dataset/tfrecord_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""TFRecord related utilities.""" 16 | import tensorflow as tf 17 | 18 | 19 | def int64_feature(value): 20 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 21 | 22 | 23 | def int64_list_feature(value): 24 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 25 | 26 | 27 | def bytes_feature(value): 28 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 29 | 30 | 31 | def bytes_list_feature(value): 32 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 33 | 34 | 35 | def float_list_feature(value): 36 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 37 | 38 | 39 | def read_examples_list(path): 40 | """Read list of training or validation examples. 41 | 42 | The file is assumed to contain a single example per line where the first 43 | token in the line is an identifier that allows us to find the image and 44 | annotation xml for that example. 45 | 46 | For example, the line: 47 | xyz 3 48 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored). 49 | 50 | Args: 51 | path: absolute path to examples list file. 
52 | 
53 |   Returns:
54 |     list of example identifiers (strings).
55 |   """
56 |   with tf.io.gfile.GFile(path) as fid:
57 |     lines = fid.readlines()
58 |   return [line.strip().split(' ')[0] for line in lines]
59 | 
60 | 
61 | def recursive_parse_xml_to_dict(xml):
62 |   """Recursively parses XML contents to python dict.
63 | 
64 |   We assume that `object` tags are the only ones that can appear
65 |   multiple times at the same level of a tree.
66 | 
67 |   Args:
68 |     xml: xml tree obtained by parsing XML file contents using lxml.etree
69 | 
70 |   Returns:
71 |     Python dictionary holding XML contents.
72 |   """
73 |   if not len(xml):  # pylint: disable=g-explicit-length-test
74 |     return {xml.tag: xml.text if xml.text else ''}
75 |   result = {}
76 |   for child in xml:
77 |     child_result = recursive_parse_xml_to_dict(child)
78 |     if child.tag != 'object':
79 |       result[child.tag] = child_result[child.tag]
80 |     else:
81 |       if child.tag not in result:
82 |         result[child.tag] = []
83 |       result[child.tag].append(child_result[child.tag])
84 |   return {xml.tag: result}
85 | 
86 | 
87 | def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
88 |   """Opens all TFRecord shards for writing and adds them to an exit stack.
89 | 
90 |   Args:
91 |     exit_stack: A contextlib2.ExitStack used to automatically close the
92 |       TFRecords opened in this function.
93 |     base_path: The base path for all shards.
94 |     num_shards: The number of shards.
95 | 
96 |   Returns:
97 |     The list of opened TFRecords. Position k in the list corresponds to shard k.
98 |   """
99 |   tf_record_output_filenames = [
100 |       '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards)
101 |       for idx in range(num_shards)
102 |   ]
103 | 
104 |   tfrecords = [
105 |       exit_stack.enter_context(tf.io.TFRecordWriter(file_name))
106 |       for file_name in tf_record_output_filenames
107 |   ]
108 | 
109 |   return tfrecords
110 | 
--------------------------------------------------------------------------------
/efficientdet/det_model_fn_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | """Tests for det_model_fn.""" 16 | import tensorflow as tf 17 | import det_model_fn 18 | 19 | 20 | def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _=0): 21 | """A legacy focal loss that does not support label smoothing.""" 22 | with tf.name_scope('focal_loss'): 23 | positive_label_mask = tf.equal(targets, 1.0) 24 | cross_entropy = ( 25 | tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)) 26 | 27 | neg_logits = -1.0 * logits 28 | modulator = tf.exp(gamma * targets * neg_logits - 29 | gamma * tf.math.log1p(tf.exp(neg_logits))) 30 | loss = modulator * cross_entropy 31 | weighted_loss = tf.where(positive_label_mask, alpha * loss, 32 | (1.0 - alpha) * loss) 33 | weighted_loss /= normalizer 34 | return weighted_loss 35 | 36 | 37 | class FocalLossTest(tf.test.TestCase): 38 | 39 | def test_focal_loss(self): 40 | tf.random.set_seed(1111) 41 | y_pred = tf.random.uniform([4, 32, 32, 90]) 42 | y_true = tf.ones([4, 32, 32, 90]) 43 | alpha, gamma, n = 0.25, 1.5, 100 44 | legacy_output = legacy_focal_loss(y_pred, y_true, alpha, gamma, n) 45 | new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n) 46 | self.assertAllClose(legacy_output, new_output) 47 | 48 | def test_focal_loss_with_label_smoothing(self): 49 | tf.random.set_seed(1111) 50 | shape = [2, 2, 2, 2] 51 | y_pred = tf.random.uniform(shape) 52 | 53 | # A binary classification target [0.0, 1.0] becomes [.1, .9] 54 | # with smoothing .2 55 | y_true = tf.ones(shape) * [0.0, 1.0] 56 | y_true_presmoothed = tf.ones(shape) * [0.1, 0.9] 57 | 58 | alpha, gamma, n = 1, 0, 100 59 | presmoothed = det_model_fn.focal_loss(y_pred, y_true_presmoothed, alpha, 60 | gamma, n, 0) 61 | alpha, gamma, n = 0.9, 0, 100 62 | unsmoothed = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n, 0.2) 63 | 64 | self.assertAllClose(presmoothed, unsmoothed) 65 | 66 | 67 | if __name__ == '__main__': 68 | tf.test.main() 69 | -------------------------------------------------------------------------------- /efficientdet/g3doc/Det-AdvProp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/Det-AdvProp.png -------------------------------------------------------------------------------- /efficientdet/g3doc/coco_ids.yaml: -------------------------------------------------------------------------------- 1 | 0: background 2 | 1: person 3 | 2: bicycle 4 | 3: car 5 | 4: motorcycle 6 | 5: airplane 7 | 6: bus 8 | 7: train 9 | 8: truck 10 | 9: boat 11 | 10: traffic light 12 | 11: fire hydrant 13 | 13: stop sign 14 | 14: parking meter 15 | 15: bench 16 | 16: bird 17 | 17: cat 18 | 18: dog 19 | 19: horse 20 | 20: sheep 21 | 21: cow 22 | 22: elephant 23 | 23: bear 24 | 24: zebra 25 | 25: giraffe 26 | 27: backpack 27 | 28: umbrella 28 | 31: handbag 29 | 32: tie 30 | 33: suitcase 31 | 34: frisbee 32 | 35: skis 33 | 36: snowboard 34 | 37: sports ball 35 | 38: kite 36 | 39: baseball bat 37 | 40: baseball glove 38 | 41: skateboard 39 | 42: surfboard 40 | 43: tennis racket 41 | 44: bottle 42 | 46: wine glass 43 | 47: cup 44 | 48: fork 45 | 49: knife 46 | 50: spoon 47 | 51: bowl 48 | 52: banana 49 | 53: apple 50 | 54: sandwich 51 | 55: orange 52 | 56: broccoli 53 | 57: carrot 54 | 58: hot dog 55 | 59: pizza 56 | 60: donut 57 | 61: cake 58 | 62: chair 59 | 63: couch 60 | 64: potted plant 61 | 65: bed 62 | 67: dining table 
63 | 70: toilet
64 | 72: tv
65 | 73: laptop
66 | 74: mouse
67 | 75: remote
68 | 76: keyboard
69 | 77: cell phone
70 | 78: microwave
71 | 79: oven
72 | 80: toaster
73 | 81: sink
74 | 82: refrigerator
75 | 84: book
76 | 85: clock
77 | 86: vase
78 | 87: scissors
79 | 88: teddy bear
80 | 89: hair drier
81 | 90: toothbrush
82 | 
--------------------------------------------------------------------------------
/efficientdet/g3doc/faq.md:
--------------------------------------------------------------------------------
1 | # EfficientDet FAQ
2 | 
3 | 
7 | 
8 | [TOC]
9 | 
10 | ## 1. For Users
11 | 
12 | ### 1.1 How can I convert the saved model to tflite?
13 | 
14 | Unfortunately, there is no way to do that with the current public tensorflow
15 | release due to some issues in the tf converter. We have some internal fixes, which
16 | could potentially be available with the next TensorFlow release.
17 | 
18 | ### 1.2 Why do I see NaN during training and how can I debug it?
19 | 
20 | Because we use batch norm, which needs a reasonable batch size. If your batch size
21 | is too small, it may cause NaN. (We may add group norm to deal with this in the
22 | future.)
23 | 
24 | If you see NaN, you can check the following:
25 | 
26 | - Is my batch size too small? It usually needs to be >=8.
27 | - Should I clip my gradient? How about h.clip_gradients_norm=5.0?
28 | - Should I use smaller jitter? How about jitter_min=0.8 and jitter_max=1.2?
29 | 
30 | If you want to debug it, you can use these tools:
31 | 
32 | ```
33 | tf.compat.v1.add_check_numerics_ops()  # for TensorFlow 1.x
34 | tf.debugging.enable_check_numerics()  # for TensorFlow 2.x
35 | ```
36 | 
37 | ### 1.3 Why is my last class eval AP always zero?
38 | 
39 | The current code assumes class 0 is always reserved for background, so if you have K classes, you should set num_classes=K+1.
40 | 
41 | See [#391](https://github.com/google/automl/issues/391) and [#398](https://github.com/google/automl/issues/398) for more discussion.
42 | 
43 | ### 1.4 Why does my input pipeline have an assert failure?
44 | 
45 | This most likely means that your dataset has some images with many objects (more
46 | than the 100-instance limit for COCO); you should set --hparams="max_instances_per_image=200" or larger.
47 | 
48 | See [#93](https://github.com/google/automl/issues/93) for more discussion.
49 | 
50 | 
51 | ## 2. For Developers
52 | 
53 | ### 2.1 How can I format my code for PRs?
54 | 
55 | Please use [yapf](https://github.com/google/yapf) with option
56 | --style='{based_on_style: yapf}'. You can also save the
57 | following file to ~/.config/yapf/style:
58 | 
59 |     [style]
60 |     based_on_style = yapf
61 | 
62 | If you want to check the format with lint, please run:
63 | 
64 |     !pylint --rcfile=../.pylintrc your_file.py
65 | 
66 | ### 2.2 How can I run all tests?
67 | 
68 |     !export PYTHONPATH="`pwd`:$PYTHONPATH"
69 |     !find . -name "*_test.py" | parallel python &> /tmp/test.log \
70 |       && echo "All passed" || echo "Failed!
Search keyword FAILED in /tmp/test.log" 71 | -------------------------------------------------------------------------------- /efficientdet/g3doc/flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/flops.png -------------------------------------------------------------------------------- /efficientdet/g3doc/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/network.png -------------------------------------------------------------------------------- /efficientdet/g3doc/params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/params.png -------------------------------------------------------------------------------- /efficientdet/g3doc/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/g3doc/street.jpg -------------------------------------------------------------------------------- /efficientdet/hparams_config_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ====================================== 15 | """Tests for hparams_config.""" 16 | import os 17 | import tempfile 18 | from absl import logging 19 | import tensorflow.compat.v1 as tf 20 | import yaml 21 | 22 | import hparams_config 23 | 24 | 25 | class HparamsConfigTest(tf.test.TestCase): 26 | 27 | def test_config_override(self): 28 | c = hparams_config.Config({'a': 1, 'b': 2}) 29 | self.assertEqual(c.as_dict(), {'a': 1, 'b': 2}) 30 | 31 | c.update({'a': 10}) 32 | self.assertEqual(c.as_dict(), {'a': 10, 'b': 2}) 33 | 34 | c.b = 20 35 | self.assertEqual(c.as_dict(), {'a': 10, 'b': 20}) 36 | 37 | c.override('a=true,b=ss') 38 | self.assertEqual(c.as_dict(), {'a': True, 'b': 'ss'}) 39 | 40 | c.override('a=100,,,b=2.3,') # extra ',' is fine. 41 | self.assertEqual(c.as_dict(), {'a': 100, 'b': 2.3}) 42 | 43 | c.override('a=2x3,b=50') # a is a special format for image size. 44 | self.assertEqual(c.as_dict(), {'a': '2x3', 'b': 50}) 45 | 46 | # overrride string must be in the format of xx=yy. 
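    # For instance, the second item in the override below has no '=', so it is
    # expected to raise a ValueError.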
47 | with self.assertRaises(ValueError): 48 | c.override('a=true,invalid_string') 49 | 50 | def test_config_yaml(self): 51 | tmpdir = tempfile.gettempdir() 52 | yaml_file_path = os.path.join(tmpdir, 'x.yaml') 53 | with open(yaml_file_path, 'w') as f: 54 | f.write(""" 55 | x: 2 56 | y: 57 | z: 'test' 58 | """) 59 | c = hparams_config.Config(dict(x=234, y=2342)) 60 | c.override(yaml_file_path) 61 | self.assertEqual(c.as_dict(), {'x': 2, 'y': {'z': 'test'}}) 62 | 63 | yaml_file_path2 = os.path.join(tmpdir, 'y.yaml') 64 | c.save_to_yaml(yaml_file_path2) 65 | with open(yaml_file_path2, 'r') as f: 66 | config_dict = yaml.load(f, Loader=yaml.FullLoader) 67 | self.assertEqual(config_dict, {'x': 2, 'y': {'z': 'test'}}) 68 | 69 | def test_config_override_recursive(self): 70 | c = hparams_config.Config({'x': 1}) 71 | self.assertEqual(c.as_dict(), {'x': 1}) 72 | c.override('y.y0=2,y.y1=3', allow_new_keys=True) 73 | self.assertEqual(c.as_dict(), {'x': 1, 'y': {'y0': 2, 'y1': 3}}) 74 | c.update({'y': {'y0': 5, 'y1': {'y11': 100}}}) 75 | self.assertEqual(c.as_dict(), {'x': 1, 'y': {'y0': 5, 'y1': {'y11': 100}}}) 76 | self.assertEqual(c.y.y1.y11, 100) 77 | 78 | def test_config_override_list(self): 79 | c = hparams_config.Config({'x': [1.0, 2.0]}) 80 | self.assertEqual(c.as_dict(), {'x': [1.0, 2.0]}) 81 | c.override('x=3.0*4.0*5.0') 82 | self.assertEqual(c.as_dict(), {'x': [3.0, 4.0, 5.0]}) 83 | 84 | 85 | if __name__ == '__main__': 86 | logging.set_verbosity(logging.WARNING) 87 | tf.test.main() 88 | -------------------------------------------------------------------------------- /efficientdet/install_deps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | for line in $(cat efficientdet/requirements.txt) 18 | do 19 | pip install $line 20 | done 21 | -------------------------------------------------------------------------------- /efficientdet/iou_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ====================================== 15 | """Tests for iou_utils.""" 16 | from absl import logging 17 | import tensorflow as tf 18 | import iou_utils 19 | 20 | 21 | class IouUtilsTest(tf.test.TestCase): 22 | """IoU test class.""" 23 | 24 | def setUp(self): 25 | super(IouUtilsTest, self).setUp() 26 | self.pb = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], 27 | dtype=tf.float32) 28 | self.tb = tf.constant( 29 | [[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]], dtype=tf.float32) 30 | self.zeros = tf.constant([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=tf.float32) 31 | 32 | def test_iou(self): 33 | self.assertAllClose( 34 | iou_utils.iou_loss(self.pb, self.tb, 'iou'), [0.875, 1.]) 35 | 36 | def test_ciou(self): 37 | self.assertAllClose( 38 | iou_utils.iou_loss(self.pb, self.tb, 'ciou'), [0.99931306, 1.6415315]) 39 | 40 | def test_diou(self): 41 | self.assertAllClose( 42 | iou_utils.iou_loss(self.pb, self.tb, 'diou'), [0.9969512, 1.6243094]) 43 | 44 | def test_giou(self): 45 | self.assertAllClose( 46 | iou_utils.iou_loss(self.pb, self.tb, 'giou'), [1.075000, 1.933333]) 47 | 48 | def test_iou_zero_target(self): 49 | self.assertAllClose( 50 | iou_utils.iou_loss(self.pb, self.zeros, 'iou'), [0.0, 0.0]) 51 | self.assertAllClose( 52 | iou_utils.iou_loss(self.pb, self.zeros, 'ciou'), [0.0, 0.0]) 53 | self.assertAllClose( 54 | iou_utils.iou_loss(self.pb, self.zeros, 'diou'), [0.0, 0.0]) 55 | self.assertAllClose( 56 | iou_utils.iou_loss(self.pb, self.zeros, 'giou'), [0.0, 0.0]) 57 | 58 | def test_iou_multiple_anchors(self): 59 | pb = tf.tile(self.pb, [1, 2]) 60 | tb = tf.tile(self.tb, [1, 2]) 61 | self.assertAllClose(iou_utils.iou_loss(pb, tb, 'iou'), [1.75, 2.0]) 62 | 63 | def test_iou_multiple_anchors_mixed(self): 64 | pb = tf.concat([self.pb, self.zeros], axis=-1) 65 | tb = tf.concat([self.tb, self.zeros], axis=-1) 66 | self.assertAllClose(iou_utils.iou_loss(pb, tb, 'iou'), [0.875, 1.0]) 67 | 68 | def test_ciou_grad(self): 69 | pb = tf.concat([self.pb, self.zeros], axis=-1) 70 | tb = tf.concat([self.tb, self.zeros], axis=-1) 71 | with tf.GradientTape() as tape: 72 | tape.watch([pb, tb]) 73 | loss = iou_utils.iou_loss(pb, tb, 'ciou') 74 | grad = tape.gradient(loss, [tb, pb]) 75 | self.assertAlmostEqual(tf.reduce_sum(grad[0]).numpy(), 0.1476393) 76 | self.assertAlmostEqual(tf.reduce_sum(grad[1]).numpy(), -0.14763935) 77 | 78 | 79 | if __name__ == '__main__': 80 | logging.set_verbosity(logging.WARNING) 81 | tf.test.main() 82 | -------------------------------------------------------------------------------- /efficientdet/object_detection/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /efficientdet/object_detection/box_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Base box coder. 16 | 17 | Box coders convert between coordinate frames, namely image-centric 18 | (with (0,0) on the top left of image) and anchor-centric (with (0,0) being 19 | defined by a specific anchor). 20 | 21 | Users of a BoxCoder can call two methods: 22 | encode: which encodes a box with respect to a given anchor 23 | (or rather, a tensor of boxes wrt a corresponding tensor of anchors) and 24 | decode: which inverts this encoding with a decode operation. 25 | In both cases, the arguments are assumed to be in 1-1 correspondence already; 26 | it is not the job of a BoxCoder to perform matching. 27 | """ 28 | from abc import ABCMeta 29 | from abc import abstractmethod 30 | from abc import abstractproperty 31 | 32 | import tensorflow.compat.v1 as tf 33 | 34 | 35 | # Box coder types. 36 | FASTER_RCNN = 'faster_rcnn' 37 | KEYPOINT = 'keypoint' 38 | MEAN_STDDEV = 'mean_stddev' 39 | SQUARE = 'square' 40 | 41 | 42 | class BoxCoder(object): 43 | """Abstract base class for box coder.""" 44 | __metaclass__ = ABCMeta 45 | 46 | @abstractproperty 47 | def code_size(self): 48 | """Return the size of each code. 49 | 50 | This number is a constant and should agree with the output of the `encode` 51 | op (e.g. if rel_codes is the output of self.encode(...), then it should have 52 | shape [N, code_size()]). This abstractproperty should be overridden by 53 | implementations. 54 | 55 | Returns: 56 | an integer constant 57 | """ 58 | pass 59 | 60 | def encode(self, boxes, anchors): 61 | """Encode a box list relative to an anchor collection. 62 | 63 | Args: 64 | boxes: BoxList holding N boxes to be encoded 65 | anchors: BoxList of N anchors 66 | 67 | Returns: 68 | a tensor representing N relative-encoded boxes 69 | """ 70 | with tf.name_scope('Encode'): 71 | return self._encode(boxes, anchors) 72 | 73 | def decode(self, rel_codes, anchors): 74 | """Decode boxes that are encoded relative to an anchor collection. 75 | 76 | Args: 77 | rel_codes: a tensor representing N relative-encoded boxes 78 | anchors: BoxList of anchors 79 | 80 | Returns: 81 | boxlist: BoxList holding N boxes encoded in the ordinary way (i.e., 82 | with corners y_min, x_min, y_max, x_max) 83 | """ 84 | with tf.name_scope('Decode'): 85 | return self._decode(rel_codes, anchors) 86 | 87 | @abstractmethod 88 | def _encode(self, boxes, anchors): 89 | """Method to be overridden by implementations. 
90 | 91 | Args: 92 | boxes: BoxList holding N boxes to be encoded 93 | anchors: BoxList of N anchors 94 | 95 | Returns: 96 | a tensor representing N relative-encoded boxes 97 | """ 98 | pass 99 | 100 | @abstractmethod 101 | def _decode(self, rel_codes, anchors): 102 | """Method to be overridden by implementations. 103 | 104 | Args: 105 | rel_codes: a tensor representing N relative-encoded boxes 106 | anchors: BoxList of anchors 107 | 108 | Returns: 109 | boxlist: BoxList holding N boxes encoded in the ordinary way (i.e., 110 | with corners y_min, x_min, y_max, x_max) 111 | """ 112 | pass 113 | 114 | 115 | def batch_decode(encoded_boxes, box_coder, anchors): 116 | """Decode a batch of encoded boxes. 117 | 118 | This op takes a batch of encoded bounding boxes and transforms 119 | them to a batch of bounding boxes specified by their corners in 120 | the order of [y_min, x_min, y_max, x_max]. 121 | 122 | Args: 123 | encoded_boxes: a float32 tensor of shape [batch_size, num_anchors, 124 | code_size] representing the location of the objects. 125 | box_coder: a BoxCoder object. 126 | anchors: a BoxList of anchors used to encode `encoded_boxes`. 127 | 128 | Returns: 129 | decoded_boxes: a float32 tensor of shape [batch_size, num_anchors, 130 | coder_size] representing the corners of the objects in the order 131 | of [y_min, x_min, y_max, x_max]. 132 | 133 | Raises: 134 | ValueError: if batch sizes of the inputs are inconsistent, or if 135 | the number of anchors inferred from encoded_boxes and anchors are 136 | inconsistent. 137 | """ 138 | encoded_boxes.get_shape().assert_has_rank(3) 139 | if encoded_boxes.get_shape()[1].value != anchors.num_boxes_static(): 140 | raise ValueError('The number of anchors inferred from encoded_boxes' 141 | ' and anchors are inconsistent: shape[1] of encoded_boxes' 142 | ' %s should be equal to the number of anchors: %s.' % 143 | (encoded_boxes.get_shape()[1].value, 144 | anchors.num_boxes_static())) 145 | 146 | decoded_boxes = tf.stack([ 147 | box_coder.decode(boxes, anchors).get() 148 | for boxes in tf.unstack(encoded_boxes) 149 | ]) 150 | return decoded_boxes 151 | -------------------------------------------------------------------------------- /efficientdet/object_detection/faster_rcnn_box_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Faster RCNN box coder. 16 | 17 | Faster RCNN box coder follows the coding schema described below: 18 | ty = (y - ya) / ha 19 | tx = (x - xa) / wa 20 | th = log(h / ha) 21 | tw = log(w / wa) 22 | where x, y, w, h denote the box's center coordinates, width and height 23 | respectively. Similarly, xa, ya, wa, ha denote the anchor's center 24 | coordinates, width and height. 
tx, ty, tw and th denote the anchor-encoded 25 | center, width and height respectively. 26 | 27 | See http://arxiv.org/abs/1506.01497 for details. 28 | """ 29 | 30 | import tensorflow.compat.v1 as tf 31 | 32 | from object_detection import box_coder 33 | from object_detection import box_list 34 | 35 | EPSILON = 1e-8 36 | 37 | 38 | class FasterRcnnBoxCoder(box_coder.BoxCoder): 39 | """Faster RCNN box coder.""" 40 | 41 | def __init__(self, scale_factors=None): 42 | """Constructor for FasterRcnnBoxCoder. 43 | 44 | Args: 45 | scale_factors: List of 4 positive scalars to scale ty, tx, th and tw. 46 | If set to None, does not perform scaling. For Faster RCNN, 47 | the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0]. 48 | """ 49 | if scale_factors: 50 | assert len(scale_factors) == 4 51 | for scalar in scale_factors: 52 | assert scalar > 0 53 | self._scale_factors = scale_factors 54 | 55 | @property 56 | def code_size(self): 57 | return 4 58 | 59 | def _encode(self, boxes, anchors): 60 | """Encode a box collection with respect to anchor collection. 61 | 62 | Args: 63 | boxes: BoxList holding N boxes to be encoded. 64 | anchors: BoxList of anchors. 65 | 66 | Returns: 67 | a tensor representing N anchor-encoded boxes of the format 68 | [ty, tx, th, tw]. 69 | """ 70 | # Convert anchors to the center coordinate representation. 71 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes() 72 | ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes() 73 | # Avoid NaN in division and log below. 74 | ha = tf.maximum(EPSILON, ha) 75 | wa = tf.maximum(EPSILON, wa) 76 | h = tf.maximum(EPSILON, h) 77 | w = tf.maximum(EPSILON, w) 78 | 79 | tx = (xcenter - xcenter_a) / wa 80 | ty = (ycenter - ycenter_a) / ha 81 | tw = tf.log(w / wa) 82 | th = tf.log(h / ha) 83 | # Scales location targets as used in paper for joint training. 84 | if self._scale_factors: 85 | ty *= self._scale_factors[0] 86 | tx *= self._scale_factors[1] 87 | th *= self._scale_factors[2] 88 | tw *= self._scale_factors[3] 89 | return tf.transpose(tf.stack([ty, tx, th, tw])) 90 | 91 | def _decode(self, rel_codes, anchors): 92 | """Decode relative codes to boxes. 93 | 94 | Args: 95 | rel_codes: a tensor representing N anchor-encoded boxes. 96 | anchors: BoxList of anchors. 97 | 98 | Returns: 99 | boxes: BoxList holding N bounding boxes. 100 | """ 101 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes() 102 | 103 | ty, tx, th, tw = tf.unstack(tf.transpose(rel_codes)) 104 | if self._scale_factors: 105 | ty /= self._scale_factors[0] 106 | tx /= self._scale_factors[1] 107 | th /= self._scale_factors[2] 108 | tw /= self._scale_factors[3] 109 | w = tf.exp(tw) * wa 110 | h = tf.exp(th) * ha 111 | ycenter = ty * ha + ycenter_a 112 | xcenter = tx * wa + xcenter_a 113 | ymin = ycenter - h / 2. 114 | xmin = xcenter - w / 2. 115 | ymax = ycenter + h / 2. 116 | xmax = xcenter + w / 2. 117 | return box_list.BoxList(tf.transpose(tf.stack([ymin, xmin, ymax, xmax]))) 118 | -------------------------------------------------------------------------------- /efficientdet/object_detection/region_similarity_calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Region Similarity Calculators for BoxLists. 16 | 17 | Region Similarity Calculators compare a pairwise measure of similarity 18 | between the boxes in two BoxLists. 19 | """ 20 | from abc import ABCMeta 21 | from abc import abstractmethod 22 | 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | def area(boxlist, scope=None): 27 | """Computes area of boxes. 28 | 29 | Args: 30 | boxlist: BoxList holding N boxes 31 | scope: name scope. 32 | 33 | Returns: 34 | a tensor with shape [N] representing box areas. 35 | """ 36 | with tf.name_scope(scope, 'Area'): 37 | y_min, x_min, y_max, x_max = tf.split( 38 | value=boxlist.get(), num_or_size_splits=4, axis=1) 39 | return tf.squeeze((y_max - y_min) * (x_max - x_min), [1]) 40 | 41 | 42 | def intersection(boxlist1, boxlist2, scope=None): 43 | """Compute pairwise intersection areas between boxes. 44 | 45 | Args: 46 | boxlist1: BoxList holding N boxes 47 | boxlist2: BoxList holding M boxes 48 | scope: name scope. 49 | 50 | Returns: 51 | a tensor with shape [N, M] representing pairwise intersections 52 | """ 53 | with tf.name_scope(scope, 'Intersection'): 54 | y_min1, x_min1, y_max1, x_max1 = tf.split( 55 | value=boxlist1.get(), num_or_size_splits=4, axis=1) 56 | y_min2, x_min2, y_max2, x_max2 = tf.split( 57 | value=boxlist2.get(), num_or_size_splits=4, axis=1) 58 | all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2)) 59 | all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2)) 60 | intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin) 61 | all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2)) 62 | all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2)) 63 | intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin) 64 | return intersect_heights * intersect_widths 65 | 66 | 67 | def iou(boxlist1, boxlist2, scope=None): 68 | """Computes pairwise intersection-over-union between box collections. 69 | 70 | Args: 71 | boxlist1: BoxList holding N boxes 72 | boxlist2: BoxList holding M boxes 73 | scope: name scope. 74 | 75 | Returns: 76 | a tensor with shape [N, M] representing pairwise iou scores. 77 | """ 78 | with tf.name_scope(scope, 'IOU'): 79 | intersections = intersection(boxlist1, boxlist2) 80 | areas1 = area(boxlist1) 81 | areas2 = area(boxlist2) 82 | unions = ( 83 | tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections) 84 | return tf.where( 85 | tf.equal(intersections, 0.0), 86 | tf.zeros_like(intersections), tf.truediv(intersections, unions)) 87 | 88 | 89 | class RegionSimilarityCalculator(object): 90 | """Abstract base class for region similarity calculator.""" 91 | __metaclass__ = ABCMeta 92 | 93 | def compare(self, boxlist1, boxlist2, scope=None): 94 | """Computes matrix of pairwise similarity between BoxLists. 95 | 96 | This op (to be overridden) computes a measure of pairwise similarity between 97 | the boxes in the given BoxLists. Higher values indicate more similarity. 
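    For example, an IOU-based calculator scores identical boxes as 1.0 and
    disjoint boxes as 0.0.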
98 | 99 | Note that this method simply measures similarity and does not explicitly 100 | perform a matching. 101 | 102 | Args: 103 | boxlist1: BoxList holding N boxes. 104 | boxlist2: BoxList holding M boxes. 105 | scope: Op scope name. Defaults to 'Compare' if None. 106 | 107 | Returns: 108 | a (float32) tensor of shape [N, M] with pairwise similarity score. 109 | """ 110 | with tf.name_scope(scope, 'Compare', [boxlist1, boxlist2]) as scope: 111 | return self._compare(boxlist1, boxlist2) 112 | 113 | @abstractmethod 114 | def _compare(self, boxlist1, boxlist2): 115 | pass 116 | 117 | 118 | class IouSimilarity(RegionSimilarityCalculator): 119 | """Class to compute similarity based on Intersection over Union (IOU) metric. 120 | 121 | This class computes pairwise similarity between two BoxLists based on IOU. 122 | """ 123 | 124 | def _compare(self, boxlist1, boxlist2): 125 | """Compute pairwise IOU similarity between the two BoxLists. 126 | 127 | Args: 128 | boxlist1: BoxList holding N boxes. 129 | boxlist2: BoxList holding M boxes. 130 | 131 | Returns: 132 | A tensor with shape [N, M] representing pairwise iou scores. 133 | """ 134 | return iou(boxlist1, boxlist2) 135 | -------------------------------------------------------------------------------- /efficientdet/object_detection/shape_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utils used to manipulate tensor shapes.""" 16 | 17 | import tensorflow.compat.v1 as tf 18 | 19 | 20 | def assert_shape_equal(shape_a, shape_b): 21 | """Asserts that shape_a and shape_b are equal. 22 | 23 | If the shapes are static, raises a ValueError when the shapes 24 | mismatch. 25 | 26 | If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes 27 | mismatch. 28 | 29 | Args: 30 | shape_a: a list containing shape of the first tensor. 31 | shape_b: a list containing shape of the second tensor. 32 | 33 | Returns: 34 | Either a tf.no_op() when shapes are all static and a tf.assert_equal() op 35 | when the shapes are dynamic. 36 | 37 | Raises: 38 | ValueError: When shapes are both static and unequal. 39 | """ 40 | if (all(isinstance(dim, int) for dim in shape_a) and 41 | all(isinstance(dim, int) for dim in shape_b)): 42 | if shape_a != shape_b: 43 | raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b)) 44 | else: return tf.no_op() 45 | else: 46 | return tf.assert_equal(shape_a, shape_b) 47 | 48 | 49 | def combined_static_and_dynamic_shape(tensor): 50 | """Returns a list containing static and dynamic values for the dimensions. 51 | 52 | Returns a list of static and dynamic values for shape dimensions. This is 53 | useful to preserve static shapes when available in reshape operation. 54 | 55 | Args: 56 | tensor: A tensor of any type. 
57 | 58 | Returns: 59 | A list of size tensor.shape.ndims containing integers or a scalar tensor. 60 | """ 61 | static_tensor_shape = tensor.shape.as_list() 62 | dynamic_tensor_shape = tf.shape(tensor) 63 | combined_shape = [] 64 | for index, dim in enumerate(static_tensor_shape): 65 | if dim is not None: 66 | combined_shape.append(dim) 67 | else: 68 | combined_shape.append(dynamic_tensor_shape[index]) 69 | return combined_shape 70 | -------------------------------------------------------------------------------- /efficientdet/requirements.txt: -------------------------------------------------------------------------------- 1 | lxml>=4.6.1 2 | absl-py>=0.10.0 3 | matplotlib>=3.0.3 4 | numpy>=1.19.4,<1.24.0 5 | Pillow>=9.5.0 6 | PyYAML>=5.1 7 | six>=1.15.0 8 | tensorflow>=2.10.0,<2.16.0 9 | tensorflow-addons>=0.18.0 10 | tensorflow-hub>=0.11 11 | neural-structured-learning>=1.3.1 12 | Cython>=0.29.13 13 | git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI 14 | -------------------------------------------------------------------------------- /efficientdet/run_tflite.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Run TF Lite model.""" 16 | from absl import app 17 | from absl import flags 18 | 19 | from PIL import Image 20 | import tensorflow as tf 21 | 22 | import inference 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | 27 | def define_flags(): 28 | """Define flags.""" 29 | flags.DEFINE_string('tflite_path', None, 'Path of tflite file.') 30 | flags.DEFINE_string('sample_image', None, 'Sample image path') 31 | flags.DEFINE_string('output_image', None, 'Output image path') 32 | flags.DEFINE_string('image_size', '512x512', 'Image size "WxH".') 33 | 34 | 35 | def load_image(image_path, image_size): 36 | """Loads an image, and returns numpy.ndarray. 37 | 38 | Args: 39 | image_path: str, path to image. 40 | image_size: list of int, representing [width, height]. 41 | 42 | Returns: 43 | image_batch: numpy.ndarray of shape [1, H, W, C]. 44 | """ 45 | input_data = tf.io.gfile.GFile(image_path, 'rb').read() 46 | image = tf.io.decode_image(input_data, channels=3, dtype=tf.uint8) 47 | image = tf.image.resize( 48 | image, image_size, method='bilinear', antialias=True) 49 | return tf.expand_dims(tf.cast(image, tf.uint8), 0).numpy() 50 | 51 | 52 | def save_visualized_image(image, prediction, output_path): 53 | """Saves the visualized image with prediction. 54 | 55 | Args: 56 | image: numpy.ndarray of shape [H, W, C]. 57 | prediction: numpy.ndarray of shape [num_predictions, 7]. 58 | output_path: str, output image path. 
59 | """ 60 | output_image = inference.visualize_image_prediction( 61 | image, 62 | prediction, 63 | label_map='coco') 64 | Image.fromarray(output_image).save(output_path) 65 | 66 | 67 | class TFLiteRunner: 68 | """Wrapper to run TFLite model.""" 69 | 70 | def __init__(self, model_path): 71 | """Init. 72 | 73 | Args: 74 | model_path: str, path to tflite model. 75 | """ 76 | self.interpreter = tf.lite.Interpreter(model_path=model_path) 77 | self.interpreter.allocate_tensors() 78 | self.input_index = self.interpreter.get_input_details()[0]['index'] 79 | self.output_index = self.interpreter.get_output_details()[0]['index'] 80 | 81 | def run(self, image): 82 | """Run inference on a single images. 83 | 84 | Args: 85 | image: numpy.ndarray of shape [1, H, W, C]. 86 | 87 | Returns: 88 | prediction: numpy.ndarray of shape [1, num_detections, 7]. 89 | """ 90 | self.interpreter.set_tensor(self.input_index, image) 91 | self.interpreter.invoke() 92 | return self.interpreter.get_tensor(self.output_index) 93 | 94 | 95 | def main(_): 96 | image_size = [int(dim) for dim in FLAGS.image_size.split('x')] 97 | image = load_image(FLAGS.sample_image, image_size) 98 | 99 | runner = TFLiteRunner(FLAGS.tflite_path) 100 | prediction = runner.run(image) 101 | 102 | save_visualized_image(image[0], prediction[0], FLAGS.output_image) 103 | 104 | 105 | if __name__ == '__main__': 106 | define_flags() 107 | app.run(main) 108 | -------------------------------------------------------------------------------- /efficientdet/tensorrt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Simple tools for TensorRT. 
16 | 
17 | Example usage:
18 | 
19 | $ export ROOT=/tmp/d4
20 | $ python model_inspect.py --runmode=freeze --model_name=efficientdet-d4 \
21 |     --logdir=$ROOT  # --hparams=xyz.yaml
22 | $ python tensorrt.py --tf_savedmodel_dir=$ROOT/savedmodel \
23 |     --trt_savedmodel_dir=$ROOT/trtmodel
24 | """
25 | import time
26 | from absl import app
27 | from absl import flags
28 | import numpy as np
29 | import tensorflow.compat.v1 as tf
30 | # pylint: disable=g-direct-tensorflow-import
31 | from tensorflow.python.compiler.tensorrt import trt_convert as trt
32 | 
33 | flags.DEFINE_string('tf_savedmodel_dir', None, 'TensorFlow saved model dir.')
34 | flags.DEFINE_string('trt_savedmodel_dir', None, 'TensorRT saved model dir.')
35 | FLAGS = flags.FLAGS
36 | 
37 | 
38 | def convert2trt(tf_savedmodel_dir: str, trt_savedmodel_dir: str):
39 |   converter = trt.TrtGraphConverter(
40 |       input_saved_model_dir=tf_savedmodel_dir,
41 |       max_workspace_size_bytes=(2 << 20),
42 |       precision_mode='FP16',
43 |       maximum_cached_engines=1)
44 |   converter.convert()
45 |   converter.save(trt_savedmodel_dir)
46 | 
47 | 
48 | def benchmark(trt_savedmodel_dir: str, warmup_runs: int = 5, bm_runs: int = 20):
49 |   """Benchmark TRT latency for a given TRT saved model."""
50 |   with tf.Session() as sess:
51 |     # First load the Saved Model into the session
52 |     tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
53 |                                trt_savedmodel_dir)
54 |     graph = tf.get_default_graph()
55 |     input_shape = graph.get_tensor_by_name('input:0').shape
56 |     x = np.ones(input_shape).astype(np.float32)
57 |     ss = lambda i: '' if i == 0 else '_%d' % i
58 |     outputs = ['box_net/box-predict%s/BiasAdd:0' % ss(i) for i in range(1)]
59 |     outputs += ['class_net/class-predict%s/BiasAdd:0' % ss(i) for i in range(5)]
60 |     # Apply reduce_sum to avoid massive data move between GPU and CPU.
61 |     outputs = [tf.reduce_sum(graph.get_tensor_by_name(i)) for i in outputs]
62 | 
63 |     # warmup
64 |     for _ in range(warmup_runs):
65 |       sess.run(outputs, feed_dict={'input:0': x})
66 |     # benchmark
67 |     s = time.perf_counter()
68 |     for _ in range(bm_runs):
69 |       sess.run(outputs, feed_dict={'input:0': x})
70 |     e = time.perf_counter()
71 |     print('Benchmark latency=%.4f FPS=%.2f' %
72 |           ((e - s) / bm_runs, bm_runs / (e - s)))
73 | 
74 | 
75 | def main(_):
76 |   if FLAGS.tf_savedmodel_dir:
77 |     convert2trt(FLAGS.tf_savedmodel_dir, FLAGS.trt_savedmodel_dir)
78 |   benchmark(FLAGS.trt_savedmodel_dir)  # warmup_runs/bm_runs are not flags; use the defaults.
79 | 
80 | 
81 | if __name__ == '__main__':
82 |   flags.mark_flag_as_required('trt_savedmodel_dir')
83 |   tf.disable_v2_behavior()
84 |   app.run(main)
85 | 
--------------------------------------------------------------------------------
/efficientdet/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================== 16 | for file in `find $PWD/efficientdet -name '*.py'` 17 | do 18 | pylint --rcfile=.pylintrc $file 19 | done 20 | 21 | cd efficientdet 22 | for file in `find $PWD -name '*_test.py'` 23 | do 24 | PYTHONPATH=$PWD TF_CPP_MIN_LOG_LEVEL=1 python $file 25 | done -------------------------------------------------------------------------------- /efficientdet/test_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Test utilities.""" 16 | import os 17 | 18 | import tensorflow as tf 19 | from dataset import tfrecord_util 20 | 21 | 22 | def make_fake_tfrecord(temp_dir): 23 | """Makes fake TFRecord to test input.""" 24 | tfrecord_path = os.path.join(temp_dir, 'test.tfrecords') 25 | writer = tf.io.TFRecordWriter(tfrecord_path) 26 | encoded_jpg = tf.io.encode_jpeg(tf.ones([512, 512, 3], dtype=tf.uint8)) 27 | example = tf.train.Example( 28 | features=tf.train.Features( 29 | feature={ 30 | 'image/height': 31 | tfrecord_util.int64_feature(512), 32 | 'image/width': 33 | tfrecord_util.int64_feature(512), 34 | 'image/filename': 35 | tfrecord_util.bytes_feature('test_file_name.jpg'.encode( 36 | 'utf8')), 37 | 'image/source_id': 38 | tfrecord_util.bytes_feature('123456'.encode('utf8')), 39 | 'image/key/sha256': 40 | tfrecord_util.bytes_feature('qwdqwfw12345'.encode('utf8')), 41 | 'image/encoded': 42 | tfrecord_util.bytes_feature(encoded_jpg.numpy()), 43 | 'image/format': 44 | tfrecord_util.bytes_feature('jpeg'.encode('utf8')), 45 | 'image/object/bbox/xmin': 46 | tfrecord_util.float_list_feature([0.1]), 47 | 'image/object/bbox/xmax': 48 | tfrecord_util.float_list_feature([0.1]), 49 | 'image/object/bbox/ymin': 50 | tfrecord_util.float_list_feature([0.2]), 51 | 'image/object/bbox/ymax': 52 | tfrecord_util.float_list_feature([0.2]), 53 | 'image/object/class/text': 54 | tfrecord_util.bytes_list_feature(['test'.encode('utf8')]), 55 | 'image/object/class/label': 56 | tfrecord_util.int64_list_feature([1]), 57 | 'image/object/difficult': 58 | tfrecord_util.int64_list_feature([]), 59 | 'image/object/truncated': 60 | tfrecord_util.int64_list_feature([]), 61 | 'image/object/view': 62 | tfrecord_util.bytes_list_feature([]), 63 | })) 64 | writer.write(example.SerializeToString()) 65 | return tfrecord_path 66 | -------------------------------------------------------------------------------- /efficientdet/testdata/img1-d1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/testdata/img1-d1.jpg -------------------------------------------------------------------------------- /efficientdet/testdata/img1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/testdata/img1.jpg -------------------------------------------------------------------------------- /efficientdet/tf2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientdet/tf2/__init__.py -------------------------------------------------------------------------------- /efficientdet/tf2/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Eval libraries.""" 16 | from absl import app 17 | from absl import flags 18 | from absl import logging 19 | import tensorflow as tf 20 | 21 | import coco_metric 22 | import dataloader 23 | import hparams_config 24 | import utils 25 | from tf2 import anchors 26 | from tf2 import efficientdet_keras 27 | from tf2 import label_util 28 | from tf2 import postprocess 29 | from tf2 import util_keras 30 | 31 | # Cloud TPU Cluster Resolvers 32 | flags.DEFINE_string('tpu', None, 'The Cloud TPU name.') 33 | flags.DEFINE_string('gcp_project', None, 'Project name.') 34 | flags.DEFINE_string('tpu_zone', None, 'GCE zone name.') 35 | 36 | flags.DEFINE_integer('eval_samples', None, 'Number of eval samples.') 37 | flags.DEFINE_string('val_file_pattern', None, 38 | 'Glob for eval tfrecords, e.g. coco/val-*.tfrecord.') 39 | flags.DEFINE_string('val_json_file', None, 40 | 'Groudtruth, e.g. annotations/instances_val2017.json.') 41 | flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name to use.') 42 | flags.DEFINE_string('model_dir', None, 'Location of the checkpoint to run.') 43 | flags.DEFINE_integer('batch_size', 8, 'GLobal batch size.') 44 | flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file') 45 | FLAGS = flags.FLAGS 46 | 47 | 48 | def main(_): 49 | config = hparams_config.get_efficientdet_config(FLAGS.model_name) 50 | config.override(FLAGS.hparams) 51 | config.val_json_file = FLAGS.val_json_file 52 | config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS 53 | config.drop_remainder = False # eval all examples w/o drop. 
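  # Keeping the last partial batch (instead of dropping it) ensures every eval
  # example is scored even when eval_samples is not a multiple of batch_size.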
54 | config.image_size = utils.parse_image_size(config['image_size']) 55 | 56 | if config.strategy == 'tpu': 57 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( 58 | FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 59 | tf.config.experimental_connect_to_cluster(tpu_cluster_resolver) 60 | tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver) 61 | ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver) 62 | logging.info('All devices: %s', tf.config.list_logical_devices('TPU')) 63 | elif config.strategy == 'gpus': 64 | ds_strategy = tf.distribute.MirroredStrategy() 65 | logging.info('All devices: %s', tf.config.list_physical_devices('GPU')) 66 | else: 67 | if tf.config.list_physical_devices('GPU'): 68 | ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0') 69 | else: 70 | ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0') 71 | 72 | with ds_strategy.scope(): 73 | # Network 74 | model = efficientdet_keras.EfficientDetNet(config=config) 75 | model.build((None, *config.image_size, 3)) 76 | util_keras.restore_ckpt(model, 77 | tf.train.latest_checkpoint(FLAGS.model_dir), 78 | config.moving_average_decay, 79 | skip_mismatch=False) 80 | @tf.function 81 | def model_fn(images, labels): 82 | cls_outputs, box_outputs = model(images, training=False) 83 | detections = postprocess.generate_detections(config, 84 | cls_outputs, 85 | box_outputs, 86 | labels['image_scales'], 87 | labels['source_ids']) 88 | tf.numpy_function(evaluator.update_state, 89 | [labels['groundtruth_data'], 90 | postprocess.transform_detections(detections)], []) 91 | 92 | # Evaluator for AP calculation. 93 | label_map = label_util.get_label_map(config.label_map) 94 | evaluator = coco_metric.EvaluationMetric( 95 | filename=config.val_json_file, label_map=label_map) 96 | 97 | # dataset 98 | batch_size = FLAGS.batch_size # global batch size. 99 | ds = dataloader.InputReader( 100 | FLAGS.val_file_pattern, 101 | is_training=False, 102 | max_instances_per_image=config.max_instances_per_image)( 103 | config, batch_size=batch_size) 104 | if FLAGS.eval_samples: 105 | ds = ds.take((FLAGS.eval_samples + batch_size - 1) // batch_size) 106 | ds = ds_strategy.experimental_distribute_dataset(ds) 107 | 108 | # evaluate all images. 109 | eval_samples = FLAGS.eval_samples or 5000 110 | pbar = tf.keras.utils.Progbar((eval_samples + batch_size - 1) // batch_size) 111 | for i, (images, labels) in enumerate(ds): 112 | ds_strategy.run(model_fn, (images, labels)) 113 | pbar.update(i) 114 | 115 | # compute the final eval results. 116 | metrics = evaluator.result() 117 | metric_dict = {} 118 | for i, name in enumerate(evaluator.metric_names): 119 | metric_dict[name] = metrics[i] 120 | 121 | if label_map: 122 | for i, cid in enumerate(sorted(label_map.keys())): 123 | name = 'AP_/%s' % label_map[cid] 124 | metric_dict[name] = metrics[i + len(evaluator.metric_names)] 125 | print(FLAGS.model_name, metric_dict) 126 | 127 | 128 | if __name__ == '__main__': 129 | flags.mark_flag_as_required('val_file_pattern') 130 | flags.mark_flag_as_required('model_dir') 131 | logging.set_verbosity(logging.ERROR) 132 | app.run(main) 133 | -------------------------------------------------------------------------------- /efficientdet/tf2/fpn_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """BiFPN/QuFPN and other FPN configs. 16 | 17 | BiFPN is presented in the EfficientDet paper. 18 | QuFPN is proposed in https://github.com/google/automl/pull/580 19 | """ 20 | import itertools 21 | import hparams_config 22 | 23 | 24 | def bifpn_config(min_level, max_level, weight_method): 25 | """A dynamic bifpn config that can adapt to different min/max levels.""" 26 | p = hparams_config.Config() 27 | p.weight_method = weight_method or 'fastattn' 28 | 29 | # Node id starts from the input features and monotonically increase whenever 30 | # a new node is added. Here is an example for level P3 - P7: 31 | # P7 (4) P7" (12) 32 | # P6 (3) P6' (5) P6" (11) 33 | # P5 (2) P5' (6) P5" (10) 34 | # P4 (1) P4' (7) P4" (9) 35 | # P3 (0) P3" (8) 36 | # So output would be like: 37 | # [ 38 | # {'feat_level': 6, 'inputs_offsets': [3, 4]}, # for P6' 39 | # {'feat_level': 5, 'inputs_offsets': [2, 5]}, # for P5' 40 | # {'feat_level': 4, 'inputs_offsets': [1, 6]}, # for P4' 41 | # {'feat_level': 3, 'inputs_offsets': [0, 7]}, # for P3" 42 | # {'feat_level': 4, 'inputs_offsets': [1, 7, 8]}, # for P4" 43 | # {'feat_level': 5, 'inputs_offsets': [2, 6, 9]}, # for P5" 44 | # {'feat_level': 6, 'inputs_offsets': [3, 5, 10]}, # for P6" 45 | # {'feat_level': 7, 'inputs_offsets': [4, 11]}, # for P7" 46 | # ] 47 | num_levels = max_level - min_level + 1 48 | node_ids = {min_level + i: [i] for i in range(num_levels)} 49 | 50 | level_last_id = lambda level: node_ids[level][-1] 51 | level_all_ids = lambda level: node_ids[level] 52 | id_cnt = itertools.count(num_levels) 53 | 54 | p.nodes = [] 55 | for i in range(max_level - 1, min_level - 1, -1): 56 | # top-down path. 57 | p.nodes.append({ 58 | 'feat_level': i, 59 | 'inputs_offsets': [level_last_id(i), 60 | level_last_id(i + 1)] 61 | }) 62 | node_ids[i].append(next(id_cnt)) 63 | 64 | for i in range(min_level + 1, max_level + 1): 65 | # bottom-up path. 66 | p.nodes.append({ 67 | 'feat_level': i, 68 | 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)] 69 | }) 70 | node_ids[i].append(next(id_cnt)) 71 | 72 | return p 73 | 74 | 75 | def qufpn_config(min_level, max_level, weight_method=None): 76 | """A dynamic quad fpn config that can adapt to different min/max levels.""" 77 | # It extends the idea of BiFPN, and has four paths: 78 | # (up_down -> bottom_up) + (bottom_up -> up_down). 79 | # See test for an example for level 2 and 7. 
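  # As in bifpn_config above, node_ids maps each feature level to the list of
  # node indices created so far for that level (input features come first),
  # and level_first_id/level_last_id select the original vs. newest node when
  # wiring 'inputs_offsets'. For min_level=3 and max_level=7, the four paths
  # plus the final quad-add produce the 21-node graph checked in
  # fpn_configs_test.test_qufpn_dynamic_l3l7.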
80 | p = hparams_config.Config() 81 | p.weight_method = weight_method or 'fastattn' 82 | p.quad_method = 'fastattn' 83 | num_levels = max_level - min_level + 1 84 | node_ids = {min_level + i: [i] for i in range(num_levels)} 85 | level_last_id = lambda level: node_ids[level][-1] 86 | level_all_ids = lambda level: node_ids[level] 87 | level_first_id = lambda level: node_ids[level][0] 88 | id_cnt = itertools.count(num_levels) 89 | 90 | p.nodes = [] 91 | for i in range(max_level - 1, min_level - 1, -1): 92 | # top-down path 1. 93 | p.nodes.append({ 94 | 'feat_level': i, 95 | 'inputs_offsets': [level_last_id(i), 96 | level_last_id(i + 1)], 97 | 'weight_method': p.weight_method 98 | }) 99 | node_ids[i].append(next(id_cnt)) 100 | node_ids[max_level].append(node_ids[max_level][-1]) 101 | 102 | for i in range(min_level + 1, max_level): 103 | # bottom-up path 2. 104 | p.nodes.append({ 105 | 'feat_level': i, 106 | 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)], 107 | 'weight_method': p.weight_method 108 | }) 109 | node_ids[i].append(next(id_cnt)) 110 | 111 | i = max_level 112 | p.nodes.append({ 113 | 'feat_level': i, 114 | 'inputs_offsets': [level_first_id(i)] + [level_last_id(i - 1)], 115 | 'weight_method': p.weight_method 116 | }) 117 | node_ids[i].append(next(id_cnt)) 118 | node_ids[min_level].append(node_ids[min_level][-1]) 119 | 120 | for i in range(min_level + 1, max_level + 1, 1): 121 | # bottom-up path 3. 122 | p.nodes.append({ 123 | 'feat_level': i, 124 | 'inputs_offsets': [ 125 | level_first_id(i), 126 | level_last_id(i - 1) if i != min_level + 1 else level_first_id(i - 127 | 1) 128 | ], 129 | 'weight_method': p.weight_method 130 | }) 131 | node_ids[i].append(next(id_cnt)) 132 | node_ids[min_level].append(node_ids[min_level][-1]) 133 | 134 | for i in range(max_level - 1, min_level, -1): 135 | # top-down path 4. 136 | p.nodes.append({ 137 | 'feat_level': 138 | i, 139 | 'inputs_offsets': [node_ids[i][0]] + [node_ids[i][-1]] + 140 | [level_last_id(i + 1)], 141 | 'weight_method': 142 | p.weight_method 143 | }) 144 | node_ids[i].append(next(id_cnt)) 145 | i = min_level 146 | p.nodes.append({ 147 | 'feat_level': i, 148 | 'inputs_offsets': [node_ids[i][0]] + [level_last_id(i + 1)], 149 | 'weight_method': p.weight_method 150 | }) 151 | node_ids[i].append(next(id_cnt)) 152 | node_ids[max_level].append(node_ids[max_level][-1]) 153 | 154 | for i in range(max_level, min_level - 1, -1): 155 | # quad-add path. 156 | p.nodes.append({ 157 | 'feat_level': i, 158 | 'inputs_offsets': [node_ids[i][2], node_ids[i][4]], 159 | 'weight_method': p.quad_method 160 | }) 161 | node_ids[i].append(next(id_cnt)) 162 | 163 | return p 164 | 165 | 166 | def get_fpn_config(fpn_name, min_level, max_level, weight_method): 167 | """Get fpn related configuration.""" 168 | if not fpn_name: 169 | fpn_name = 'bifpn' 170 | name_to_config = { 171 | 'bifpn': bifpn_config(min_level, max_level, weight_method), 172 | 'qufpn': qufpn_config(min_level, max_level, weight_method), 173 | # legacy only: to be deprecated. 174 | 'bifpn_dyn': bifpn_config(min_level, max_level, weight_method), 175 | } 176 | return name_to_config[fpn_name] 177 | -------------------------------------------------------------------------------- /efficientdet/tf2/fpn_configs_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for fpn_configs.""" 16 | from absl import logging 17 | import tensorflow as tf 18 | from tf2 import fpn_configs 19 | 20 | 21 | class FpnConfigTest(tf.test.TestCase): 22 | 23 | def test_bifpn_l3l7(self): 24 | p1 = fpn_configs.bifpn_config(3, 7, None) 25 | # pyformat: disable 26 | self.assertEqual( 27 | p1.nodes, 28 | [ 29 | {'feat_level': 6, 'inputs_offsets': [3, 4]}, 30 | {'feat_level': 5, 'inputs_offsets': [2, 5]}, 31 | {'feat_level': 4, 'inputs_offsets': [1, 6]}, 32 | {'feat_level': 3, 'inputs_offsets': [0, 7]}, 33 | {'feat_level': 4, 'inputs_offsets': [1, 7, 8]}, 34 | {'feat_level': 5, 'inputs_offsets': [2, 6, 9]}, 35 | {'feat_level': 6, 'inputs_offsets': [3, 5, 10]}, 36 | {'feat_level': 7, 'inputs_offsets': [4, 11]}, 37 | ]) 38 | # pyformat: enable 39 | 40 | def test_bifpn_l2l7(self): 41 | p = fpn_configs.bifpn_config(2, 7, None) 42 | 43 | # pyformat: disable 44 | self.assertEqual( 45 | p.nodes, 46 | [ 47 | {'feat_level': 6, 'inputs_offsets': [4, 5]}, 48 | {'feat_level': 5, 'inputs_offsets': [3, 6]}, 49 | {'feat_level': 4, 'inputs_offsets': [2, 7]}, 50 | {'feat_level': 3, 'inputs_offsets': [1, 8]}, 51 | {'feat_level': 2, 'inputs_offsets': [0, 9]}, 52 | {'feat_level': 3, 'inputs_offsets': [1, 9, 10]}, 53 | {'feat_level': 4, 'inputs_offsets': [2, 8, 11]}, 54 | {'feat_level': 5, 'inputs_offsets': [3, 7, 12]}, 55 | {'feat_level': 6, 'inputs_offsets': [4, 6, 13]}, 56 | {'feat_level': 7, 'inputs_offsets': [5, 14]}, 57 | ]) 58 | # pyformat: enable 59 | 60 | def test_qufpn_dynamic_l3l7(self): 61 | p = fpn_configs.qufpn_config(3, 7, None) 62 | 63 | # pyformat: disable 64 | # pylint: disable=line-too-long 65 | self.assertEqual( 66 | p.nodes, 67 | [ 68 | {'feat_level': 6, 'inputs_offsets': [3, 4], 'weight_method': 'fastattn'}, 69 | {'feat_level': 5, 'inputs_offsets': [2, 5], 'weight_method': 'fastattn'}, 70 | {'feat_level': 4, 'inputs_offsets': [1, 6], 'weight_method': 'fastattn'}, 71 | {'feat_level': 3, 'inputs_offsets': [0, 7], 'weight_method': 'fastattn'}, 72 | {'feat_level': 4, 'inputs_offsets': [1, 7, 8], 'weight_method': 'fastattn'}, 73 | {'feat_level': 5, 'inputs_offsets': [2, 6, 9], 'weight_method': 'fastattn'}, 74 | {'feat_level': 6, 'inputs_offsets': [3, 5, 10], 'weight_method': 'fastattn'}, 75 | {'feat_level': 7, 'inputs_offsets': [4, 11], 'weight_method': 'fastattn'}, 76 | {'feat_level': 4, 'inputs_offsets': [1, 0], 'weight_method': 'fastattn'}, 77 | {'feat_level': 5, 'inputs_offsets': [2, 13], 'weight_method': 'fastattn'}, 78 | {'feat_level': 6, 'inputs_offsets': [3, 14], 'weight_method': 'fastattn'}, 79 | {'feat_level': 7, 'inputs_offsets': [4, 15], 'weight_method': 'fastattn'}, 80 | {'feat_level': 6, 'inputs_offsets': [3, 15, 16], 'weight_method': 'fastattn'}, 81 | {'feat_level': 5, 'inputs_offsets': [2, 14, 17], 'weight_method': 'fastattn'}, 82 | {'feat_level': 4, 'inputs_offsets': [1, 13, 18], 
'weight_method': 'fastattn'}, 83 | {'feat_level': 3, 'inputs_offsets': [0, 19], 'weight_method': 'fastattn'}, 84 | {'feat_level': 7, 'inputs_offsets': [12, 16], 'weight_method': 'fastattn'}, 85 | {'feat_level': 6, 'inputs_offsets': [11, 17], 'weight_method': 'fastattn'}, 86 | {'feat_level': 5, 'inputs_offsets': [10, 18], 'weight_method': 'fastattn'}, 87 | {'feat_level': 4, 'inputs_offsets': [9, 19], 'weight_method': 'fastattn'}, 88 | {'feat_level': 3, 'inputs_offsets': [8, 20], 'weight_method': 'fastattn'}, 89 | ]) 90 | # pylint: enable=line-too-long 91 | # pyformat: enable 92 | 93 | 94 | if __name__ == '__main__': 95 | logging.set_verbosity(logging.WARNING) 96 | tf.test.main() 97 | -------------------------------------------------------------------------------- /efficientdet/tf2/infer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A simple example on how to use keras model for inference.""" 16 | import os 17 | from absl import app 18 | from absl import flags 19 | from absl import logging 20 | import numpy as np 21 | from PIL import Image 22 | import tensorflow as tf 23 | 24 | import hparams_config 25 | import inference 26 | from tf2 import efficientdet_keras 27 | 28 | flags.DEFINE_string('image_path', None, 'Location of test image.') 29 | flags.DEFINE_string('output_dir', None, 'Directory of annotated output images.') 30 | flags.DEFINE_string('model_dir', None, 'Location of the checkpoint to run.') 31 | flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name to use.') 32 | flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file') 33 | flags.DEFINE_bool('debug', False, 'If true, run function in eager for debug.') 34 | flags.DEFINE_string('saved_model_dir', None, 'Saved model directory') 35 | FLAGS = flags.FLAGS 36 | 37 | 38 | def main(_): 39 | 40 | # pylint: disable=line-too-long 41 | # Prepare images and checkpoints: please run these commands in shell. 42 | # !mkdir tmp 43 | # !wget https://user-images.githubusercontent.com/11736571/77320690-099af300-6d37-11ea-9d86-24f14dc2d540.png -O tmp/img.png 44 | # !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-d0.tar.gz -O tmp/efficientdet-d0.tar.gz 45 | # !tar zxf tmp/efficientdet-d0.tar.gz -C tmp 46 | imgs = [np.array(Image.open(FLAGS.image_path))] 47 | # Create model config. 48 | config = hparams_config.get_efficientdet_config(FLAGS.model_name) 49 | config.is_training_bn = False 50 | config.image_size = '1920x1280' 51 | config.nms_configs.score_thresh = 0.4 52 | config.nms_configs.max_output_size = 100 53 | config.override(FLAGS.hparams) 54 | 55 | # Use 'mixed_float16' if running on GPUs. 
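  # A minimal sketch of that switch (float32 below is the safe default, since
  # float16 compute is typically only beneficial on recent GPUs):
  #   policy = tf.keras.mixed_precision.Policy('mixed_float16')
  #   tf.keras.mixed_precision.set_global_policy(policy)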
56 | policy = tf.keras.mixed_precision.Policy('float32') 57 | tf.keras.mixed_precision.set_global_policy(policy) 58 | tf.config.run_functions_eagerly(FLAGS.debug) 59 | 60 | # Create and run the model. 61 | model = efficientdet_keras.EfficientDetModel(config=config) 62 | model.build((None, None, None, 3)) 63 | model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir)) 64 | model.summary(expand_nested=True) 65 | 66 | class ExportModel(tf.Module): 67 | 68 | def __init__(self, model): 69 | super().__init__() 70 | self.model = model 71 | 72 | @tf.function 73 | def f(self, imgs): 74 | return self.model(imgs, training=False, post_mode='global') 75 | 76 | imgs = tf.convert_to_tensor(imgs, dtype=tf.uint8) 77 | export_model = ExportModel(model) 78 | if FLAGS.saved_model_dir: 79 | tf.saved_model.save( 80 | export_model, 81 | FLAGS.saved_model_dir, 82 | signatures=export_model.f.get_concrete_function( 83 | tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8))) 84 | export_model = tf.saved_model.load(FLAGS.saved_model_dir) 85 | 86 | boxes, scores, classes, valid_len = export_model.f(imgs) 87 | 88 | # Visualize results. 89 | for i, img in enumerate(imgs): 90 | length = valid_len[i] 91 | img = inference.visualize_image( 92 | img, 93 | boxes[i].numpy()[:length], 94 | classes[i].numpy().astype(np.int)[:length], 95 | scores[i].numpy()[:length], 96 | label_map=config.label_map, 97 | min_score_thresh=config.nms_configs.score_thresh, 98 | max_boxes_to_draw=config.nms_configs.max_output_size) 99 | output_image_path = os.path.join(FLAGS.output_dir, str(i) + '.jpg') 100 | Image.fromarray(img).save(output_image_path) 101 | print('writing annotated image to %s' % output_image_path) 102 | 103 | 104 | if __name__ == '__main__': 105 | flags.mark_flag_as_required('image_path') 106 | flags.mark_flag_as_required('output_dir') 107 | flags.mark_flag_as_required('model_dir') 108 | logging.set_verbosity(logging.ERROR) 109 | app.run(main) 110 | -------------------------------------------------------------------------------- /efficientdet/tf2/infer_lib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | r"""Inference test cases.""" 16 | import os 17 | import tempfile 18 | from absl import logging 19 | import tensorflow as tf 20 | import test_util 21 | from tf2 import efficientdet_keras 22 | from tf2 import infer_lib 23 | 24 | 25 | class InferenceTest(tf.test.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | tf.random.set_seed(111111) 30 | model = efficientdet_keras.EfficientDetModel('efficientdet-d0') 31 | self.tmp_path = tempfile.mkdtemp() 32 | model.build([1, 512, 512, 3]) 33 | model.save_weights(os.path.join(self.tmp_path, 'model')) 34 | 35 | lite_model = efficientdet_keras.EfficientDetModel('efficientdet-lite0') 36 | self.lite_tmp_path = tempfile.mkdtemp() 37 | lite_model.build([1, 512, 512, 3]) 38 | lite_model.save_weights(os.path.join(self.lite_tmp_path, 'model')) 39 | 40 | def test_export(self): 41 | saved_model_path = os.path.join(self.tmp_path, 'saved_model') 42 | driver = infer_lib.KerasDriver(self.tmp_path, False, 'efficientdet-d0') 43 | driver.export(saved_model_path) 44 | has_saved_model = tf.saved_model.contains_saved_model(saved_model_path) 45 | self.assertAllEqual(has_saved_model, True) 46 | driver = infer_lib.SavedModelDriver(saved_model_path, 'efficientdet-d0') 47 | fg_path = os.path.join(saved_model_path, 'efficientdet-d0_frozen.pb') 48 | driver = infer_lib.SavedModelDriver(fg_path, 'efficientdet-d0') 49 | 50 | def test_export_tflite_only_network(self): 51 | saved_model_path = os.path.join(self.lite_tmp_path, 'saved_model') 52 | driver = infer_lib.KerasDriver( 53 | self.lite_tmp_path, False, 'efficientdet-lite0', only_network=True) 54 | driver.export(saved_model_path, tflite='FP32') 55 | self.assertTrue( 56 | tf.io.gfile.exists(os.path.join(saved_model_path, 'fp32.tflite'))) 57 | tf.io.gfile.rmtree(saved_model_path) 58 | driver.export(saved_model_path, tflite='FP16') 59 | self.assertTrue( 60 | tf.io.gfile.exists(os.path.join(saved_model_path, 'fp16.tflite'))) 61 | tf.io.gfile.rmtree(saved_model_path) 62 | tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir()) 63 | driver.export( 64 | saved_model_path, 65 | tflite='INT8', 66 | file_pattern=[tfrecord_path], 67 | num_calibration_steps=1) 68 | self.assertTrue( 69 | tf.io.gfile.exists(os.path.join(saved_model_path, 'int8.tflite'))) 70 | 71 | def test_export_tflite_with_post_processing(self): 72 | saved_model_path = os.path.join(self.lite_tmp_path, 'saved_model') 73 | driver = infer_lib.KerasDriver( 74 | self.lite_tmp_path, False, 'efficientdet-lite0', only_network=False) 75 | driver.export(saved_model_path, tflite='FP32') 76 | self.assertTrue( 77 | tf.io.gfile.exists(os.path.join(saved_model_path, 'fp32.tflite'))) 78 | tf.io.gfile.rmtree(saved_model_path) 79 | tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir()) 80 | driver.export( 81 | saved_model_path, 82 | tflite='INT8', 83 | file_pattern=[tfrecord_path], 84 | num_calibration_steps=1) 85 | self.assertTrue( 86 | tf.io.gfile.exists(os.path.join(saved_model_path, 'int8.tflite'))) 87 | 88 | def test_infer_lib(self): 89 | driver = infer_lib.KerasDriver(self.tmp_path, False, 'efficientdet-d0') 90 | images = tf.ones((1, 512, 512, 3)) 91 | boxes, scores, classes, valid_lens = driver.serve(images) 92 | self.assertEqual(tf.reduce_mean(boxes), 163.09) 93 | self.assertEqual(tf.reduce_mean(scores), 0.01) 94 | self.assertEqual(tf.reduce_mean(classes), 1) 95 | self.assertEqual(tf.reduce_mean(valid_lens), 100) 96 | self.assertEqual(boxes.shape, (1, 100, 4)) 97 
| self.assertEqual(scores.shape, (1, 100)) 98 | self.assertEqual(classes.shape, (1, 100)) 99 | self.assertEqual(valid_lens.shape, (1,)) 100 | 101 | def test_infer_lib_without_ema(self): 102 | driver = infer_lib.KerasDriver( 103 | self.tmp_path, 104 | False, 105 | 'efficientdet-d0', 106 | model_params={'moving_average_decay': 0}) 107 | images = tf.ones((1, 512, 512, 3)) 108 | boxes, scores, classes, valid_lens = driver.serve(images) 109 | self.assertEqual(tf.reduce_mean(boxes), 163.09) 110 | self.assertEqual(tf.reduce_mean(scores), 0.01) 111 | self.assertEqual(tf.reduce_mean(classes), 1) 112 | self.assertEqual(tf.reduce_mean(valid_lens), 100) 113 | self.assertEqual(boxes.shape, (1, 100, 4)) 114 | self.assertEqual(scores.shape, (1, 100)) 115 | self.assertEqual(classes.shape, (1, 100)) 116 | self.assertEqual(valid_lens.shape, (1,)) 117 | 118 | def test_network_infer_lib(self): 119 | driver = infer_lib.KerasDriver( 120 | self.tmp_path, False, 'efficientdet-d0', only_network=True) 121 | images = tf.ones((1, 512, 512, 3)) 122 | class_outputs, box_outputs = driver.predict(images) 123 | self.assertLen(class_outputs, 5) 124 | self.assertLen(box_outputs, 5) 125 | 126 | def test_infer_lib_mixed_precision(self): 127 | driver = infer_lib.KerasDriver( 128 | self.tmp_path, 129 | False, 130 | 'efficientdet-d0', 131 | model_params={'mixed_precision': True}) 132 | images = tf.ones((1, 512, 512, 3)) 133 | boxes, scores, classes, valid_lens = driver.serve(images) 134 | policy = tf.keras.mixed_precision.global_policy() 135 | if policy.name == 'float32': 136 | self.assertEqual(tf.reduce_mean(boxes), 163.09) 137 | self.assertEqual(tf.reduce_mean(scores), 0.01) 138 | self.assertEqual(tf.reduce_mean(classes), 1) 139 | self.assertEqual(tf.reduce_mean(valid_lens), 100) 140 | elif policy.name == 'float16': 141 | pass 142 | elif policy.name == 'bfloat16': 143 | pass 144 | self.assertEqual(boxes.shape, (1, 100, 4)) 145 | self.assertEqual(scores.shape, (1, 100)) 146 | self.assertEqual(classes.shape, (1, 100)) 147 | self.assertEqual(valid_lens.shape, (1,)) 148 | 149 | 150 | if __name__ == '__main__': 151 | logging.set_verbosity(logging.WARNING) 152 | tf.test.main() 153 | -------------------------------------------------------------------------------- /efficientdet/tf2/inspector_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | r"""Tests for model inspect tool.""" 16 | import os 17 | import shutil 18 | import tempfile 19 | from absl import flags 20 | from absl import logging 21 | from absl.testing import flagsaver 22 | import numpy as np 23 | from PIL import Image 24 | import tensorflow as tf 25 | 26 | from tf2 import inspector 27 | FLAGS = flags.FLAGS 28 | 29 | 30 | class InspectorTest(tf.test.TestCase): 31 | """Model inspect tests.""" 32 | 33 | def setUp(self): 34 | super().setUp() 35 | self.tempdir = tempfile.mkdtemp() 36 | FLAGS.model_dir = '_' 37 | 38 | def tearDown(self): 39 | super().tearDown() 40 | shutil.rmtree(self.tempdir) 41 | 42 | @flagsaver.flagsaver(mode='dry') 43 | def test_dry(self): 44 | FLAGS.export_ckpt = os.path.join(self.tempdir, 'model') 45 | inspector.main(None) 46 | self.assertIsNot(tf.train.get_checkpoint_state(self.tempdir), None) 47 | 48 | @flagsaver.flagsaver(mode='infer', saved_model_dir=None) 49 | def test_infer(self): 50 | test_image = np.random.randint(0, 244, (640, 720, 3)).astype(np.uint8) 51 | FLAGS.input_image = os.path.join(self.tempdir, 'img.jpg') 52 | Image.fromarray(test_image).save(FLAGS.input_image) 53 | FLAGS.output_image_dir = self.tempdir 54 | inspector.main(None) 55 | self.assertTrue(tf.io.gfile.exists(os.path.join(self.tempdir, '0.jpg'))) 56 | 57 | @flagsaver.flagsaver(mode='benchmark', saved_model_dir=None) 58 | def test_benchmark(self): 59 | inspector.main(None) 60 | self.assertFalse(tf.io.gfile.exists(os.path.join(self.tempdir, '0.jpg'))) 61 | 62 | @flagsaver.flagsaver(mode='export', tflite='FP32') 63 | def test_export(self): 64 | FLAGS.saved_model_dir = os.path.join(self.tempdir, 'savedmodel') 65 | tflite_path = os.path.join(FLAGS.saved_model_dir, 'fp32.tflite') 66 | inspector.main(None) 67 | self.assertTrue(tf.saved_model.contains_saved_model(FLAGS.saved_model_dir)) 68 | self.assertTrue(tf.io.gfile.exists(tflite_path)) 69 | 70 | 71 | if __name__ == '__main__': 72 | logging.set_verbosity(logging.WARNING) 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /efficientdet/tf2/label_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A few predefined label id mapping.""" 16 | import tensorflow as tf 17 | import yaml 18 | import hparams_config 19 | 20 | coco = { 21 | # 0: 'background', 22 | 1: 'person', 23 | 2: 'bicycle', 24 | 3: 'car', 25 | 4: 'motorcycle', 26 | 5: 'airplane', 27 | 6: 'bus', 28 | 7: 'train', 29 | 8: 'truck', 30 | 9: 'boat', 31 | 10: 'traffic light', 32 | 11: 'fire hydrant', 33 | 13: 'stop sign', 34 | 14: 'parking meter', 35 | 15: 'bench', 36 | 16: 'bird', 37 | 17: 'cat', 38 | 18: 'dog', 39 | 19: 'horse', 40 | 20: 'sheep', 41 | 21: 'cow', 42 | 22: 'elephant', 43 | 23: 'bear', 44 | 24: 'zebra', 45 | 25: 'giraffe', 46 | 27: 'backpack', 47 | 28: 'umbrella', 48 | 31: 'handbag', 49 | 32: 'tie', 50 | 33: 'suitcase', 51 | 34: 'frisbee', 52 | 35: 'skis', 53 | 36: 'snowboard', 54 | 37: 'sports ball', 55 | 38: 'kite', 56 | 39: 'baseball bat', 57 | 40: 'baseball glove', 58 | 41: 'skateboard', 59 | 42: 'surfboard', 60 | 43: 'tennis racket', 61 | 44: 'bottle', 62 | 46: 'wine glass', 63 | 47: 'cup', 64 | 48: 'fork', 65 | 49: 'knife', 66 | 50: 'spoon', 67 | 51: 'bowl', 68 | 52: 'banana', 69 | 53: 'apple', 70 | 54: 'sandwich', 71 | 55: 'orange', 72 | 56: 'broccoli', 73 | 57: 'carrot', 74 | 58: 'hot dog', 75 | 59: 'pizza', 76 | 60: 'donut', 77 | 61: 'cake', 78 | 62: 'chair', 79 | 63: 'couch', 80 | 64: 'potted plant', 81 | 65: 'bed', 82 | 67: 'dining table', 83 | 70: 'toilet', 84 | 72: 'tv', 85 | 73: 'laptop', 86 | 74: 'mouse', 87 | 75: 'remote', 88 | 76: 'keyboard', 89 | 77: 'cell phone', 90 | 78: 'microwave', 91 | 79: 'oven', 92 | 80: 'toaster', 93 | 81: 'sink', 94 | 82: 'refrigerator', 95 | 84: 'book', 96 | 85: 'clock', 97 | 86: 'vase', 98 | 87: 'scissors', 99 | 88: 'teddy bear', 100 | 89: 'hair drier', 101 | 90: 'toothbrush', 102 | } 103 | 104 | voc = { 105 | # 0: 'background', 106 | 1: 'aeroplane', 107 | 2: 'bicycle', 108 | 3: 'bird', 109 | 4: 'boat', 110 | 5: 'bottle', 111 | 6: 'bus', 112 | 7: 'car', 113 | 8: 'cat', 114 | 9: 'chair', 115 | 10: 'cow', 116 | 11: 'diningtable', 117 | 12: 'dog', 118 | 13: 'horse', 119 | 14: 'motorbike', 120 | 15: 'person', 121 | 16: 'pottedplant', 122 | 17: 'sheep', 123 | 18: 'sofa', 124 | 19: 'train', 125 | 20: 'tvmonitor', 126 | } 127 | 128 | waymo = { 129 | # 0: 'background', 130 | 1: 'vehicle', 131 | 2: 'pedestrian', 132 | 3: 'cyclist', 133 | } 134 | 135 | 136 | def get_label_map(mapping): 137 | """Get label id map based on the name, filename, or dict.""" 138 | # case 1: if it is None or dict, just return it. 139 | if not mapping or isinstance(mapping, dict): 140 | return mapping 141 | 142 | if isinstance(mapping, hparams_config.Config): 143 | return mapping.as_dict() 144 | 145 | # case 2: if it is a yaml file, load it to a dict and return the dict. 146 | assert isinstance(mapping, str), 'mapping must be dict or str.' 147 | if mapping.endswith('.yaml'): 148 | with tf.io.gfile.GFile(mapping) as f: 149 | return yaml.load(f, Loader=yaml.FullLoader) 150 | 151 | # case 3: it is a name of a predefined dataset. 152 | return {'coco': coco, 'voc': voc, 'waymo': waymo}[mapping] 153 | -------------------------------------------------------------------------------- /efficientdet/tf2/segmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A demo script to show to train a segmentation model.""" 16 | from absl import app 17 | from absl import logging 18 | import tensorflow as tf 19 | import tensorflow_datasets as tfds 20 | 21 | import hparams_config 22 | from tf2 import efficientdet_keras 23 | 24 | 25 | def create_mask(pred_mask): 26 | pred_mask = tf.argmax(pred_mask, axis=-1) 27 | pred_mask = pred_mask[..., tf.newaxis] 28 | return pred_mask[0] 29 | 30 | 31 | def normalize(input_image, input_mask): 32 | input_image = tf.cast(input_image, tf.float32) / 255.0 33 | input_mask -= 1 34 | return input_image, input_mask 35 | 36 | 37 | def load_image_train(datapoint): 38 | """Load images for training.""" 39 | input_image = tf.image.resize(datapoint['image'], (512, 512)) 40 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 41 | 42 | if tf.random.uniform(()) > 0.5: 43 | input_image = tf.image.flip_left_right(input_image) 44 | input_mask = tf.image.flip_left_right(input_mask) 45 | 46 | input_image, input_mask = normalize(input_image, input_mask) 47 | 48 | return input_image, input_mask 49 | 50 | 51 | def load_image_test(datapoint): 52 | input_image = tf.image.resize(datapoint['image'], (512, 512)) 53 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 54 | 55 | input_image, input_mask = normalize(input_image, input_mask) 56 | 57 | return input_image, input_mask 58 | 59 | 60 | def main(_): 61 | dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True) 62 | train_examples = info.splits['train'].num_examples 63 | batch_size = 8 64 | steps_per_epoch = train_examples // batch_size 65 | 66 | train = dataset['train'].map( 67 | load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) 68 | test = dataset['test'].map(load_image_test) 69 | 70 | train_dataset = train.cache().shuffle(1000).batch(batch_size).repeat() 71 | train_dataset = train_dataset.prefetch( 72 | buffer_size=tf.data.experimental.AUTOTUNE) 73 | test_dataset = test.batch(batch_size) 74 | config = hparams_config.get_efficientdet_config('efficientdet-d0') 75 | config.heads = ['segmentation'] 76 | model = efficientdet_keras.EfficientDetNet(config=config) 77 | model.build((1, 512, 512, 3)) 78 | model.compile( 79 | optimizer='adam', 80 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 81 | metrics=['accuracy']) 82 | 83 | val_subsplits = 5 84 | val_steps = info.splits['test'].num_examples // batch_size // val_subsplits 85 | model.fit( 86 | train_dataset, 87 | epochs=20, 88 | steps_per_epoch=steps_per_epoch, 89 | validation_steps=val_steps, 90 | validation_data=test_dataset, 91 | callbacks=[]) 92 | 93 | model.save_weights( 94 | './testdata/segmentation') 95 | 96 | print(create_mask(model(tf.ones((1, 512, 512, 3)), False))) 97 | 98 | 99 | if __name__ == '__main__': 100 | logging.set_verbosity(logging.WARNING) 101 | app.run(main) 102 | -------------------------------------------------------------------------------- /efficientdet/tf2/tfmot.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A tool for model optimization.""" 16 | import functools 17 | 18 | import tensorflow_model_optimization as tfmot 19 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper 20 | from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs 21 | 22 | 23 | def quantize(layer, quantize_config=None): 24 | if quantize_config is None: 25 | quantize_config = default_8bit_quantize_configs.Default8BitOutputQuantizeConfig( 26 | ) 27 | return quantize_wrapper.QuantizeWrapper( 28 | layer, quantize_config=quantize_config) 29 | 30 | 31 | optimzation_methods = { 32 | 'prune': tfmot.sparsity.keras.prune_low_magnitude, 33 | 'quantize': quantize 34 | } 35 | 36 | 37 | def set_config(configs): 38 | for key in configs: 39 | if key == 'prune': 40 | optimzation_methods[key] = functools.partial( 41 | tfmot.sparsity.keras.prune_low_magnitude, **configs[key]) 42 | if key == 'quantize': 43 | optimzation_methods[key] = functools.partial(quantize, **configs[key]) 44 | 45 | 46 | def get_method(method): 47 | if method not in optimzation_methods: 48 | raise KeyError(f'only support {optimzation_methods.keys()}') 49 | return optimzation_methods[method] 50 | -------------------------------------------------------------------------------- /efficientdet/tf2/util_keras_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from absl import logging 16 | from absl.testing import parameterized 17 | import tensorflow as tf 18 | 19 | import utils 20 | from tf2 import util_keras 21 | 22 | 23 | class KerasUtilTest(tf.test.TestCase, parameterized.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | ('train_local', True, ''), ('eval_local', False, ''), 27 | ('train_tpu', True, 'tpu'), ('eval_tpu', False, 'tpu')) 28 | def test_batch_norm(self, is_training, strategy): 29 | inputs = tf.random.uniform([8, 40, 40, 3]) 30 | expect_results = utils.batch_norm_act( 31 | inputs, is_training, None, strategy=strategy) 32 | 33 | # Call batch norm layer with is_training parameter. 34 | bn_layer = util_keras.build_batch_norm(is_training, strategy=strategy) 35 | self.assertAllClose(expect_results, bn_layer(inputs, is_training)) 36 | 37 | 38 | if __name__ == '__main__': 39 | logging.set_verbosity(logging.WARNING) 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /efficientdet/tf2/wbf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """WBF for test-time augmentation.""" 16 | import tensorflow as tf 17 | 18 | 19 | def vectorized_iou(clusters, detection): 20 | """Calculates the ious for box with each element of clusters.""" 21 | x11, y11, x12, y12 = tf.split(clusters[:, 1:5], 4, axis=1) 22 | x21, y21, x22, y22 = tf.split(detection[1:5], 4) 23 | 24 | xa = tf.maximum(x11, x21) 25 | ya = tf.maximum(y11, y21) 26 | xb = tf.minimum(x12, x22) 27 | yb = tf.minimum(y12, y22) 28 | 29 | inter_area = tf.maximum((xb - xa), 0) * tf.maximum((yb - ya), 0) 30 | 31 | boxa_area = (x12 - x11) * (y12 - y11) 32 | boxb_area = (x22 - x21) * (y22 - y21) 33 | 34 | iou = inter_area / (boxa_area + boxb_area - inter_area) 35 | 36 | return iou 37 | 38 | 39 | def find_matching_cluster(clusters, detection): 40 | """Returns the index of the highest iou matching cluster for detection.""" 41 | if not clusters: 42 | return -1 43 | ious = vectorized_iou(tf.stack(clusters), detection) 44 | ious = tf.reshape(ious, [len(clusters)]) 45 | if tf.math.reduce_max(ious) < 0.55: 46 | # returns -1 if no iou is higher than 0.55. 
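    # For example, clusters=[[1, 1, 1, 3, 3, 1, 1]] and
    # detection=[1, 3, 3, 5, 5, 1, 1] only touch at a corner, so the max IoU
    # is 0 < 0.55 and the caller (ensemble_detections) starts a new cluster.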
47 | return -1 48 | return tf.argmax(ious) 49 | 50 | 51 | def weighted_average(samples, weights): 52 | return tf.math.reduce_sum(samples * weights) / tf.math.reduce_sum(weights) 53 | 54 | 55 | def average_detections(detections, num_models): 56 | """Takes a list of detections and returns the average, both in box co-ordinates and confidence.""" 57 | num_detections = len(detections) 58 | detections = tf.stack(detections) 59 | return [ 60 | detections[0][0], 61 | weighted_average(detections[:, 1], detections[:, 5]), 62 | weighted_average(detections[:, 2], detections[:, 5]), 63 | weighted_average(detections[:, 3], detections[:, 5]), 64 | weighted_average(detections[:, 4], detections[:, 5]), 65 | tf.math.reduce_mean(detections[:, 5]) * min(1, num_detections/num_models), 66 | detections[0][6], 67 | ] 68 | 69 | 70 | def ensemble_detections(params, detections, num_models): 71 | """Ensembles a group of detections by clustering the detections and returning the average of the clusters.""" 72 | all_clusters = [] 73 | 74 | for cid in range(params['num_classes']): 75 | indices = tf.where(tf.equal(detections[:, 6], cid)) 76 | if indices.shape[0] == 0: 77 | continue 78 | class_detections = tf.gather_nd(detections, indices) 79 | 80 | clusters = [] 81 | cluster_averages = [] 82 | for d in class_detections: 83 | cluster_index = find_matching_cluster(cluster_averages, d) 84 | if cluster_index == -1: 85 | clusters.append([d]) 86 | cluster_averages.append(average_detections([d], num_models)) 87 | else: 88 | clusters[cluster_index].append(d) 89 | cluster_averages[cluster_index] = average_detections( 90 | clusters[cluster_index], num_models) 91 | 92 | all_clusters.extend(cluster_averages) 93 | 94 | all_clusters.sort(reverse=True, key=lambda d: d[5]) 95 | return tf.stack(all_clusters) 96 | -------------------------------------------------------------------------------- /efficientdet/tf2/wbf_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Test for wbf.""" 16 | from absl import logging 17 | import tensorflow as tf 18 | 19 | from tf2 import wbf 20 | 21 | 22 | class WbfTest(tf.test.TestCase): 23 | 24 | def test_detection_iou_same(self): 25 | d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32) 26 | d2 = tf.constant([1, 1, 1, 3, 3, 1, 1], dtype=tf.float32) 27 | 28 | iou = wbf.vectorized_iou(d1, d2) 29 | 30 | self.assertAllClose(iou[0][0], 1.0) 31 | 32 | def test_detection_iou_corners(self): 33 | d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32) 34 | d2 = tf.constant([1, 2, 2, 4, 4, 1, 1], dtype=tf.float32) 35 | 36 | iou = wbf.vectorized_iou(d1, d2) 37 | 38 | self.assertAllClose(iou[0][0], 1.0 / 7.0) 39 | 40 | def test_detection_iou_ends(self): 41 | d1 = tf.constant([[1, 1, 1, 3, 2, 1, 1]], dtype=tf.float32) 42 | d2 = tf.constant([1, 2, 1, 4, 2, 1, 1], dtype=tf.float32) 43 | 44 | iou = wbf.vectorized_iou(d1, d2) 45 | 46 | self.assertAllClose(iou[0][0], 1.0 / 3.0) 47 | 48 | def test_detection_iou_none(self): 49 | d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32) 50 | d2 = tf.constant([1, 3, 3, 5, 5, 1, 1], dtype=tf.float32) 51 | 52 | iou = wbf.vectorized_iou(d1, d2) 53 | 54 | self.assertAllClose(iou[0][0], 0) 55 | 56 | def test_detection_iou_vector(self): 57 | vector_to_match = tf.constant( 58 | [ 59 | [1, 1, 1, 3, 3, 1, 1], 60 | [1, 2, 2, 4, 4, 1, 1], 61 | [1, 3, 3, 5, 5, 1, 1], 62 | ], 63 | dtype=tf.float32, 64 | ) 65 | 66 | detection = tf.constant([1, 1, 1, 3, 3, 1, 1], dtype=tf.float32) 67 | 68 | ious = wbf.vectorized_iou(vector_to_match, detection) 69 | self.assertAllClose(tf.reshape(ious, [3]), [1, 1.0 / 7.0, 0]) 70 | 71 | def test_find_matching_cluster_matches(self): 72 | matching_cluster = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) 73 | non_matching_cluster = tf.constant([1, 3, 3, 2, 2, 1, 1], dtype=tf.float32) 74 | 75 | box = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) 76 | 77 | cluster_index = wbf.find_matching_cluster( 78 | (matching_cluster, non_matching_cluster), box) 79 | 80 | self.assertAllClose(cluster_index, 0) 81 | 82 | cluster_index = wbf.find_matching_cluster( 83 | (non_matching_cluster, matching_cluster), box) 84 | 85 | self.assertAllClose(cluster_index, 1) 86 | 87 | def test_find_matching_cluster_best_overlap(self): 88 | overlaps = tf.constant([1, 1, 1, 11, 2, 1, 1], dtype=tf.float32) 89 | overlaps_better = tf.constant([1, 2, 1, 12, 2, 1, 1], dtype=tf.float32) 90 | 91 | box = tf.constant([1, 3, 1, 13, 2, 1, 1], dtype=tf.float32) 92 | 93 | cluster_index = wbf.find_matching_cluster((overlaps,), box) 94 | 95 | self.assertAllClose(cluster_index, 0) 96 | 97 | cluster_index = wbf.find_matching_cluster((overlaps, overlaps_better), box) 98 | 99 | self.assertAllClose(cluster_index, 1) 100 | 101 | def test_weighted_average(self): 102 | samples = tf.constant([1, 3], dtype=tf.float32) 103 | 104 | weights1 = tf.constant([0.5, 0.5], dtype=tf.float32) 105 | weighted_average1 = wbf.weighted_average(samples, weights1) 106 | 107 | self.assertAllClose(weighted_average1, 2) 108 | 109 | weights2 = tf.constant([1, 0], dtype=tf.float32) 110 | weighted_average2 = wbf.weighted_average(samples, weights2) 111 | 112 | self.assertAllClose(weighted_average2, 1) 113 | 114 | weights3 = tf.constant([1, 2], dtype=tf.float32) 115 | weighted_average3 = wbf.weighted_average(samples, weights3) 116 | 117 | self.assertAllClose(weighted_average3, 7.0 / 3.0) 118 | 119 | def test_average_detections(self): 120 | 
d1 = tf.constant([1, 1, 1, 2, 2, 0.3, 1], dtype=tf.float32) 121 | d2 = tf.constant([1, 3, 3, 4, 4, 0.7, 1], dtype=tf.float32) 122 | 123 | averaged_single_model = wbf.average_detections((d1, d2), 1) 124 | self.assertAllClose(averaged_single_model, [1, 2.4, 2.4, 3.4, 3.4, 0.5, 1]) 125 | 126 | averaged_multi_model = wbf.average_detections((d1, d2), 3) 127 | self.assertAllClose(averaged_multi_model, 128 | [1, 2.4, 2.4, 3.4, 3.4, 0.333333, 1]) 129 | 130 | averaged_single_detection = wbf.average_detections((d2,), 2) 131 | self.assertAllClose(averaged_single_detection, [1, 3, 3, 4, 4, 0.35, 1]) 132 | 133 | def test_ensemble_boxes(self): 134 | d1 = tf.constant([1, 2, 1, 10, 1, 0.75, 1], dtype=tf.float32) 135 | d2 = tf.constant([1, 3, 1, 10, 1, 0.75, 1], dtype=tf.float32) 136 | d3 = tf.constant([1, 3, 1, 10, 1, 1, 2], dtype=tf.float32) 137 | 138 | ensembled = wbf.ensemble_detections({'num_classes': 3}, 139 | tf.stack([d1, d2, d3]), 2) 140 | 141 | self.assertAllClose(ensembled, 142 | [[1, 2.5, 1, 10, 1, 0.75, 1], [1, 3, 1, 10, 1, 0.5, 2]]) 143 | 144 | 145 | if __name__ == '__main__': 146 | logging.set_verbosity(logging.WARNING) 147 | tf.test.main() 148 | -------------------------------------------------------------------------------- /efficientdet/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for utils.""" 16 | import os 17 | from absl import logging 18 | import tensorflow.compat.v1 as tf 19 | 20 | import utils 21 | 22 | 23 | class UtilsTest(tf.test.TestCase): 24 | 25 | def setUp(self): 26 | super(UtilsTest, self).setUp() 27 | self.model_dir = os.path.join(tf.test.get_temp_dir(), 'model_dir') 28 | 29 | def build_model(self): 30 | x = tf.Variable(1.0) 31 | y = tf.Variable(2.0) 32 | z = x + y 33 | return z 34 | 35 | def test_archive_ckpt(self): 36 | model_dir = os.path.join(tf.test.get_temp_dir(), 'model_dir') 37 | ckpt_path = os.path.join(model_dir, 'ckpt') 38 | self.build_model() 39 | saver = tf.train.Saver() 40 | with self.session() as sess: 41 | sess.run(tf.global_variables_initializer()) 42 | saver.save(sess, ckpt_path) 43 | 44 | # Save checkpoint if the new objective is better. 45 | self.assertTrue(utils.archive_ckpt('eval1', 0.1, ckpt_path)) 46 | logging.info(os.listdir(model_dir)) 47 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'archive'))) 48 | self.assertFalse(tf.io.gfile.exists(os.path.join(model_dir, 'backup'))) 49 | 50 | # Save checkpoint if the new objective is better. 51 | self.assertTrue(utils.archive_ckpt('eval2', 0.2, ckpt_path)) 52 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'archive'))) 53 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'backup'))) 54 | 55 | # Skip checkpoint if the new objective is worse. 
56 | self.assertFalse(utils.archive_ckpt('eval3', 0.1, ckpt_path)) 57 | 58 | # Save checkpoint if the new objective is better. 59 | self.assertTrue(utils.archive_ckpt('eval4', 0.3, ckpt_path)) 60 | 61 | # Save checkpoint if the new objective is equal. 62 | self.assertTrue(utils.archive_ckpt('eval5', 0.3, ckpt_path)) 63 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'archive'))) 64 | self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'backup'))) 65 | 66 | def test_image_size(self): 67 | self.assertEqual(utils.parse_image_size('1280x640'), (640, 1280)) 68 | self.assertEqual(utils.parse_image_size(1280), (1280, 1280)) 69 | self.assertEqual(utils.parse_image_size((1280, 640)), (1280, 640)) 70 | 71 | def test_get_feat_sizes(self): 72 | feats = utils.get_feat_sizes(640, 2) 73 | self.assertEqual(feats, [{ 74 | 'height': 640, 75 | 'width': 640 76 | }, { 77 | 'height': 320, 78 | 'width': 320 79 | }, { 80 | 'height': 160, 81 | 'width': 160 82 | }]) 83 | 84 | feats = utils.get_feat_sizes((640, 300), 2) 85 | self.assertEqual(feats, [{ 86 | 'height': 640, 87 | 'width': 300, 88 | }, { 89 | 'height': 320, 90 | 'width': 150, 91 | }, { 92 | 'height': 160, 93 | 'width': 75, 94 | }]) 95 | 96 | def test_precision_float16(self): 97 | def _model(inputs): 98 | x = tf.ones((4, 4, 4, 4), dtype='float32') 99 | conv = tf.keras.layers.Conv2D(filters=4, kernel_size=2, use_bias=False) 100 | a = tf.Variable(1.0) 101 | return tf.cast(a, inputs.dtype) * conv(x) * inputs 102 | 103 | x = tf.constant(2.0, dtype=tf.float32) # input can be any type. 104 | out = utils.build_model_with_precision('mixed_float16', _model, x) 105 | # Variables should be float32. 106 | for v in tf.global_variables(): 107 | self.assertIn(v.dtype, (tf.float32, tf.dtypes.as_dtype('float32_ref'))) 108 | self.assertIs(out.dtype, tf.float16) # output should be float16. 109 | 110 | 111 | class ActivationTest(tf.test.TestCase): 112 | 113 | def test_swish(self): 114 | features = tf.constant([.5, 10]) 115 | 116 | result = utils.activation_fn(features, 'swish') 117 | expected = features * tf.sigmoid(features) 118 | self.assertAllClose(result, expected) 119 | 120 | result = utils.activation_fn(features, 'swish_native') 121 | self.assertAllClose(result, expected) 122 | 123 | def test_hswish(self): 124 | features = tf.constant([.5, 10]) 125 | result = utils.activation_fn(features, 'hswish') 126 | self.assertAllClose(result, [0.29166667, 10.0]) 127 | 128 | def test_relu(self): 129 | features = tf.constant([.5, 10]) 130 | result = utils.activation_fn(features, 'relu') 131 | self.assertAllClose(result, [0.5, 10]) 132 | 133 | def test_relu6(self): 134 | features = tf.constant([.5, 10]) 135 | result = utils.activation_fn(features, 'relu6') 136 | self.assertAllClose(result, [0.5, 6]) 137 | 138 | def test_mish(self): 139 | features = tf.constant([.5, 10]) 140 | result = utils.activation_fn(features, 'mish') 141 | self.assertAllClose(result, [0.37524524, 10.0]) 142 | 143 | 144 | if __name__ == '__main__': 145 | logging.set_verbosity(logging.WARNING) 146 | tf.disable_eager_execution() 147 | tf.test.main() 148 | -------------------------------------------------------------------------------- /efficientdet/visualize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # Visualization library is mostly based on TensorFlow object detection API: 16 | # https://github.com/tensorflow/models/tree/master/research/object_detection 17 | -------------------------------------------------------------------------------- /efficientdet/visualize/static_shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Helper functions to access TensorShape values. 16 | 17 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. 18 | """ 19 | 20 | 21 | def get_dim_as_int(dim): 22 | """Utility to get v1 or v2 TensorShape dim as an int. 23 | 24 | Args: 25 | dim: The TensorShape dimension to get as an int 26 | 27 | Returns: 28 | None or an int. 29 | """ 30 | try: 31 | return dim.value 32 | except AttributeError: 33 | return dim 34 | 35 | 36 | def get_batch_size(tensor_shape): 37 | """Returns batch size from the tensor shape. 38 | 39 | Args: 40 | tensor_shape: A rank 4 TensorShape. 41 | 42 | Returns: 43 | An integer representing the batch size of the tensor. 44 | """ 45 | tensor_shape.assert_has_rank(rank=4) 46 | return get_dim_as_int(tensor_shape[0]) 47 | 48 | 49 | def get_height(tensor_shape): 50 | """Returns height from the tensor shape. 51 | 52 | Args: 53 | tensor_shape: A rank 4 TensorShape. 54 | 55 | Returns: 56 | An integer representing the height of the tensor. 57 | """ 58 | tensor_shape.assert_has_rank(rank=4) 59 | return get_dim_as_int(tensor_shape[1]) 60 | 61 | 62 | def get_width(tensor_shape): 63 | """Returns width from the tensor shape. 64 | 65 | Args: 66 | tensor_shape: A rank 4 TensorShape. 67 | 68 | Returns: 69 | An integer representing the width of the tensor. 70 | """ 71 | tensor_shape.assert_has_rank(rank=4) 72 | return get_dim_as_int(tensor_shape[2]) 73 | 74 | 75 | def get_depth(tensor_shape): 76 | """Returns depth from the tensor shape. 77 | 78 | Args: 79 | tensor_shape: A rank 4 TensorShape. 80 | 81 | Returns: 82 | An integer representing the depth of the tensor. 
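    For a [batch_size, height, width, depth] shape such as [8, 640, 640, 3],
    this is 3.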
83 | """ 84 | tensor_shape.assert_has_rank(rank=4) 85 | return get_dim_as_int(tensor_shape[3]) 86 | -------------------------------------------------------------------------------- /efficientnetv2/autoaugment_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for autoaugment.""" 16 | import tensorflow.compat.v1 as tf 17 | import autoaugment 18 | 19 | 20 | class AutoaugmentTest(tf.test.TestCase): 21 | 22 | def test_autoaugment(self): 23 | """Smoke test to be sure no syntax errors.""" 24 | image = tf.zeros((224, 224, 3), dtype=tf.uint8) 25 | aug_image = autoaugment.distort_image_with_autoaugment(image, 'v0') 26 | self.assertEqual((224, 224, 3), aug_image.shape) 27 | 28 | def test_randaug(self): 29 | """Smoke test to be sure no syntax errors.""" 30 | num_layers = 2 31 | magnitude = 15 32 | image = tf.zeros((224, 224, 3), dtype=tf.uint8) 33 | aug_image = autoaugment.distort_image_with_randaugment( 34 | image, num_layers, magnitude) 35 | self.assertEqual((224, 224, 3), aug_image.shape) 36 | 37 | 38 | if __name__ == '__main__': 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /efficientnetv2/cflags.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Command line flags, shared by all executable binaries.""" 16 | from absl import flags 17 | FLAGS = flags.FLAGS 18 | 19 | 20 | def define_flags(): 21 | """A shared function to define flags.""" 22 | flags.DEFINE_string('model_name', 'efficientnetv2-b0', 'model name.') 23 | flags.DEFINE_string('dataset_cfg', 'Imagenet', 'dataset config name.') 24 | flags.DEFINE_string('hparam_str', '', 'Comma separated k=v pairs of hparams.') 25 | flags.DEFINE_string('sweeps', '', 'Comma separated k=v pairs for sweeping.') 26 | flags.DEFINE_bool('use_tpu', True, 'If true, use TPU; otherwise use CPU/GPU.') 27 | flags.DEFINE_string('tpu_job_name', None, 'job name, default to tpu_worker.') 28 | # Cloud TPU Cluster Resolvers 29 | flags.DEFINE_string('tpu', None, 'address e.g. 
grpc://ip.address.of.tpu:8470') 30 | flags.DEFINE_string('gcp_project', None, 'Project name.') 31 | flags.DEFINE_string('tpu_zone', None, 'GCE zone') 32 | # Model specific flags 33 | flags.DEFINE_string('data_dir', None, 'The directory for training images.') 34 | flags.DEFINE_string('eval_name', None, 'Evaluation name.') 35 | flags.DEFINE_bool('archive_ckpt', True, 'If true, archive the best ckpt.') 36 | flags.DEFINE_string('model_dir', None, 'Dir for checkpoint and summaries.') 37 | flags.DEFINE_string('mode', 'train', 'One of {"train", "eval"}.') 38 | flags.DEFINE_bool('export_to_tpu', False, 'Export metagraph.') 39 | -------------------------------------------------------------------------------- /efficientnetv2/datasets_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests datasets.""" 16 | import tensorflow as tf 17 | import datasets 18 | 19 | 20 | class ImagenetInputTest(tf.test.TestCase): 21 | 22 | def test_imagenet(self): 23 | ds_class = datasets.get_dataset_class('imagenet') 24 | ds = ds_class( 25 | is_training=False, 26 | data_dir='null', 27 | cache=False, 28 | image_size=224, 29 | image_dtype=None, 30 | augname=None, 31 | mixup_alpha=0, 32 | ra_num_layers=2, 33 | ra_magnitude=20) 34 | params = {'batch_size': 2} 35 | for _, labels in ds.input_fn(params): 36 | label = labels['label'] 37 | self.assertAllClose(label[:, 0:4], [[0, 0, 0, 0], [0, 0, 0, 0]]) 38 | break 39 | 40 | def test_imagenet21k(self): 41 | ds_class = datasets.get_dataset_class('imagenet21k') 42 | ds = ds_class( 43 | is_training=False, 44 | data_dir='null', 45 | cache=False, 46 | image_size=224, 47 | image_dtype=None, 48 | augname=None, 49 | mixup_alpha=0, 50 | ra_num_layers=2, 51 | ra_magnitude=20) 52 | params = {'batch_size': 2} 53 | for _, labels in ds.input_fn(params): 54 | label = labels['label'] 55 | self.assertAllClose(label[:, 0:4], [[0, 0, 1, 1], [0, 0, 1, 1]]) 56 | break 57 | 58 | 59 | class DatasetConfigTest(tf.test.TestCase): 60 | 61 | def test_dataset_config(self): 62 | cfg = datasets.get_dataset_config('cifar10ft') 63 | self.assertEqual(cfg.data.ds_name, 'cifar10') 64 | 65 | 66 | if __name__ == '__main__': 67 | tf.test.main() 68 | -------------------------------------------------------------------------------- /efficientnetv2/effnetv2_configs_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for effnetv2_configs."""
16 | from absl import logging
17 | import tensorflow as tf
18 | import effnetv2_configs
19 | 
20 | 
21 | class EffnetV2ConfigsTest(tf.test.TestCase):
22 | 
23 |   def test_model_config(self):
24 |     cfg = effnetv2_configs.get_model_config('efficientnet-b0')
25 |     self.assertEqual(cfg.model.model_name, 'efficientnet-b0')
26 | 
27 |     cfg = effnetv2_configs.get_model_config('efficientnetv2-s')
28 |     self.assertEqual(cfg.model.model_name, 'efficientnetv2-s')
29 | 
30 | 
31 | if __name__ == '__main__':
32 |   logging.set_verbosity(logging.WARNING)
33 |   tf.test.main()
34 | 
35 | 
--------------------------------------------------------------------------------
/efficientnetv2/effnetv2_model_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | """Tests for effnetv2_model.""" 16 | from absl import logging 17 | from absl.testing import parameterized 18 | import tensorflow.compat.v1 as tf 19 | import effnetv2_model 20 | 21 | 22 | class EffNetV2ModelTest(tf.test.TestCase, parameterized.TestCase): 23 | 24 | @parameterized.named_parameters(('v1_b0', 'efficientnet-b0', 5330564), 25 | ('v1_b1', 'efficientnet-b1', 7856232), 26 | ('v1_b2', 'efficientnet-b2', 9177562), 27 | ('v1_b3', 'efficientnet-b3', 12314268), 28 | ('v1_b4', 'efficientnet-b4', 19466816), 29 | ('v1_b5', 'efficientnet-b5', 30562520), 30 | ('v1_b6', 'efficientnet-b6', 43265136)) 31 | def test_effnetv1(self, model_name, expected_params): 32 | images = tf.zeros((1, 224, 224, 3), dtype=tf.float32) 33 | model = effnetv2_model.EffNetV2Model(model_name) 34 | _ = model(images) 35 | self.assertEqual(model.count_params(), expected_params) 36 | 37 | @parameterized.named_parameters(('v1-b0', 'efficientnetv2-b0', 7200312), 38 | ('v1-b1', 'efficientnetv2-b1', 8212124), 39 | ('v1-b2', 'efficientnetv2-b2', 10178374), 40 | ('v1-b3', 'efficientnetv2-b3', 14467622), 41 | ('s', 'efficientnetv2-s', 21612360), 42 | ('m', 'efficientnetv2-m', 54431388), 43 | ('l', 'efficientnetv2-l', 119027848), 44 | ('xl', 'efficientnetv2-xl', 208896832)) 45 | def test_effnetv2(self, model_name, expected_params): 46 | images = tf.zeros((10, 224, 224, 3), dtype=tf.float32) 47 | model = effnetv2_model.EffNetV2Model(model_name) 48 | _ = model(images) 49 | self.assertEqual(model.count_params(), expected_params) 50 | 51 | 52 | if __name__ == '__main__': 53 | logging.set_verbosity(logging.WARNING) 54 | tf.test.main() 55 | -------------------------------------------------------------------------------- /efficientnetv2/g3doc/effnetv2-l-gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-l-gpu.png -------------------------------------------------------------------------------- /efficientnetv2/g3doc/effnetv2-m-gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-m-gpu.png -------------------------------------------------------------------------------- /efficientnetv2/g3doc/effnetv2-s-gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-s-gpu.png -------------------------------------------------------------------------------- /efficientnetv2/g3doc/effnetv2-s-relu6-gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/effnetv2-s-relu6-gpu.png -------------------------------------------------------------------------------- /efficientnetv2/g3doc/param_flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/param_flops.png -------------------------------------------------------------------------------- /efficientnetv2/g3doc/train_params.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/efficientnetv2/g3doc/train_params.png -------------------------------------------------------------------------------- /efficientnetv2/mlir.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A simple example on how to export MLIR.""" 16 | import os 17 | from absl import app 18 | from absl import flags 19 | from absl import logging 20 | import tensorflow as tf 21 | 22 | import effnetv2_model 23 | import utils 24 | # pylint: disable=g-direct-tensorflow-import 25 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph 26 | from tensorflow.lite.python.util import get_grappler_config 27 | from tensorflow.lite.python.util import run_graph_optimizations 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | def define_flags(): 33 | """Define all flags for binary run.""" 34 | flags.DEFINE_string('model_dir', None, 'Location of the checkpoint to run.') 35 | flags.DEFINE_string('model_name', 'efficientnetv2-b0', 'Model name to use.') 36 | flags.DEFINE_string('dataset_cfg', 'Imagenet', 'dataset config name.') 37 | flags.DEFINE_string('hparam_str', '', 'k=v,x=y pairs or yaml file.') 38 | flags.DEFINE_string('export_dir', None, 'Export or saved model directory') 39 | 40 | 41 | def main(_): 42 | """Export model to MLIR.""" 43 | model = effnetv2_model.get_model( 44 | FLAGS.model_name, 45 | FLAGS.hparam_str, 46 | include_top=True, 47 | pretrained=FLAGS.model_dir or True) 48 | # Use call (not build) to match the namescope: tensorflow issues/29576 49 | model(tf.ones([1, 224, 224, 3]), False) 50 | if FLAGS.model_dir: 51 | ckpt = FLAGS.model_dir 52 | if tf.io.gfile.isdir(ckpt): 53 | ckpt = tf.train.latest_checkpoint(FLAGS.model_dir) 54 | utils.restore_tf2_ckpt(model, ckpt, exclude_layers=('_head', 'optimizer')) 55 | model.summary() 56 | 57 | fff = tf.function(model).get_concrete_function( 58 | tf.TensorSpec([1, 224, 224, 3], tf.float32)) 59 | 60 | frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(fff) 61 | 62 | input_tensors = [ 63 | tensor for tensor in frozen_func.inputs if tensor.dtype != tf.resource 64 | ] 65 | output_tensors = frozen_func.outputs 66 | 67 | graph_def = run_graph_optimizations( 68 | graph_def, 69 | input_tensors, 70 | output_tensors, 71 | config=get_grappler_config([ 72 | 'pruning', 'function', 'constfold', 'shape', 'remap', 'memory', 73 | 'common_subgraph_elimination', 'arithmetic', 'loop', 'dependency', 74 | 'debug_stripper' 75 | ]), 76 | graph=frozen_func.graph) 77 | 78 | tf_mlir_graph = tf.mlir.experimental.convert_graph_def(graph_def) 79 | 80 | print('export model to 
{}.mlir'.format(FLAGS.model_name)) 81 | export_dir = FLAGS.export_dir 82 | if export_dir is None: 83 | export_dir = '.' 84 | os.makedirs(export_dir, exist_ok=True) 85 | outfile = open('{}/{}.mlir'.format(export_dir, FLAGS.model_name), 'wb') 86 | outfile.write(tf_mlir_graph.encode()) 87 | outfile.close() 88 | 89 | 90 | if __name__ == '__main__': 91 | logging.set_verbosity(logging.ERROR) 92 | define_flags() 93 | app.run(main) 94 | -------------------------------------------------------------------------------- /efficientnetv2/preprocessing_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for preprocessing.""" 16 | from absl import logging 17 | from absl.testing import parameterized 18 | import tensorflow as tf 19 | import preprocessing 20 | 21 | 22 | class PreprocessingTest(tf.test.TestCase, parameterized.TestCase): 23 | 24 | @parameterized.parameters('effnetv1_autoaug', 'effnetv1_randaug', None) 25 | def test_preprocessing_legacy(self, augname): 26 | image = tf.zeros((300, 300, 3), dtype=tf.float32) 27 | try: 28 | preprocessing.preprocess_image(image, 224, False, None, augname) 29 | except tf.errors.InvalidArgumentError as e: 30 | if 'ExtractJpegShape' not in str(e): 31 | raise e 32 | 33 | @parameterized.parameters('autoaug', 'randaug', 'ft', 'ft_autoaug', None) 34 | def test_preprocessing(self, augname): 35 | image = tf.zeros((300, 300, 3), dtype=tf.float32) 36 | preprocessing.preprocess_image(image, 224, True, None, augname) 37 | 38 | 39 | if __name__ == '__main__': 40 | logging.set_verbosity(logging.WARNING) 41 | tf.test.main() 42 | -------------------------------------------------------------------------------- /efficientnetv2/smoke_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for EfficientNetV2 train smoke tests.""" 16 | import sys 17 | import tempfile 18 | from absl import flags 19 | from absl.testing import flagsaver 20 | import tensorflow as tf 21 | import main as main_lib 22 | 23 | FLAGS = flags.FLAGS 24 | GPU_TEST = 'gpu_test' in sys.argv[0] 25 | TPU_TEST = 'test_tpu' in sys.argv[0] 26 | 27 | 28 | class EfficientNetV2Test(tf.test.TestCase): 29 | 30 | @classmethod 31 | def setUpClass(cls): 32 | super().setUpClass() 33 | FLAGS.tpu = '' 34 | FLAGS.model_dir = tempfile.mkdtemp() 35 | FLAGS.data_dir = 'null' 36 | cls.hparam_str = ( 37 | 'train.batch_size=2,eval.batch_size=2,train.epochs=0,train.min_steps=1,' 38 | 'train.stages=0,train.lr_base=0,data.splits.eval.num_images=6,') 39 | 40 | def _run_single_step_train_and_eval(self, hparam_str=''): 41 | """Single step run with TPUEstimator.""" 42 | FLAGS.hparam_str = self.hparam_str + hparam_str 43 | FLAGS.mode = 'train' 44 | main_lib.main(None) 45 | 46 | tf.compat.v1.reset_default_graph() 47 | FLAGS.mode = 'eval' 48 | main_lib.main(None) 49 | 50 | @flagsaver.flagsaver( 51 | use_tpu=False, model_name='efficientnetv2-s', dataset_cfg='ImageNet') 52 | def test_cpu_b0_model_single_step(self): 53 | self._run_single_step_train_and_eval() 54 | 55 | @flagsaver.flagsaver(use_tpu=True) 56 | def test_tpu_b0_model_bfloat_single_step(self): 57 | if TPU_TEST: 58 | self._run_single_step_train_and_eval('') 59 | else: 60 | self.skipTest('Skip because no TPU is available.') 61 | 62 | @flagsaver.flagsaver(use_tpu=False) 63 | def test_tpu_b0_model_single_step_gpu(self): 64 | if GPU_TEST: 65 | # Disables export as tflite does not support NCHW layout. 66 | self._run_single_step_train_and_eval('model.data_format=channels_first') 67 | else: 68 | self.skipTest('Skip because no GPU is available.') 69 | 70 | 71 | if __name__ == '__main__': 72 | tf.compat.v1.disable_eager_execution() 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /efficientnetv2/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for utils.""" 16 | 17 | import tensorflow as tf 18 | import utils 19 | 20 | 21 | class UtilsTest(tf.test.TestCase): 22 | 23 | def test_constant_lr(self): 24 | constant_schedule = utils.WarmupLearningRateSchedule( 25 | 1.0, lr_decay_type='constant', warmup_epochs=None) 26 | 27 | lr = constant_schedule(10) 28 | self.assertAllClose(lr, 1.0) 29 | 30 | def test_linear_lr(self): 31 | linear_schedule = utils.WarmupLearningRateSchedule( 32 | 1.0, total_steps=10, lr_decay_type='linear', warmup_epochs=None) 33 | 34 | lr = linear_schedule(0) 35 | self.assertAllClose(lr, 1.0) 36 | 37 | lr = linear_schedule(5) 38 | self.assertAllClose(lr, 0.5) 39 | 40 | lr = linear_schedule(10) 41 | self.assertAllClose(lr, 0.0) 42 | 43 | def test_cosine_lr(self): 44 | cosine_schedule = utils.WarmupLearningRateSchedule( 45 | 1.0, total_steps=10, lr_decay_type='cosine', warmup_epochs=None) 46 | 47 | lr = cosine_schedule(4) 48 | self.assertAllClose(lr, 0.654508) 49 | 50 | lr = cosine_schedule(5) 51 | self.assertAllClose(lr, 0.5) 52 | 53 | lr = cosine_schedule(6) 54 | self.assertAllClose(lr, 0.345491) 55 | 56 | def test_exponential_lr(self): 57 | exponential_schedule = utils.WarmupLearningRateSchedule( 58 | 1.0, 59 | total_steps=100, 60 | steps_per_epoch=10, 61 | decay_epochs=2, 62 | decay_factor=0.5, 63 | lr_decay_type='exponential', 64 | warmup_epochs=None) 65 | 66 | lr = exponential_schedule(5) 67 | self.assertAllClose(lr, 1.0) 68 | 69 | lr = exponential_schedule(25) 70 | self.assertAllClose(lr, 0.5) 71 | 72 | lr = exponential_schedule(70) 73 | self.assertAllClose(lr, 0.125) 74 | 75 | def test_warmup(self): 76 | warmup_schedule = utils.WarmupLearningRateSchedule( 77 | 1.0, 78 | total_steps=100, 79 | steps_per_epoch=10, 80 | warmup_epochs=2, 81 | lr_decay_type='constant') 82 | 83 | lr = warmup_schedule(5) 84 | self.assertAllClose(lr, 0.25) 85 | 86 | lr = warmup_schedule(35) 87 | self.assertAllClose(lr, 1.0) 88 | 89 | 90 | if __name__ == '__main__': 91 | tf.test.main() 92 | -------------------------------------------------------------------------------- /hero/example_commands.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Example command for installing dependencies. 4 | # JAX installation is environment-specific (CPU, GPU, TPU). Check the official JAX installation guide at https://docs.jax.dev/en/latest/installation.html. 5 | # Examples: 6 | # CPU: `pip install -U jax` 7 | # GPU: `pip install -U "jax[cuda12]"` (Replace `cuda12`; ensure CUDA/cuDNN/drivers are correct) 8 | # TPU: `pip install -U "jax[tpu]"` 9 | pip install -r requirements.txt 10 | 11 | # Example command for local run. 12 | # Add "export JAX_DISABLE_JIT=True;" to disable `jit` for easier debugging. 13 | # Change "TransformerLMTest" to other experiment config names in `config_lib.py` to run other experiments. 14 | EXP=local_test_1; rm -rf /tmp/${EXP}; python main.py --experiment_config TransformerLMTest --experiment_dir /tmp/${EXP} --verbosity=-1 15 | 16 | # Example command for checking learning curves with tensorboard. 17 | tensorboard --logdir /tmp/${EXP} 18 | -------------------------------------------------------------------------------- /hero/fn_lib.py: -------------------------------------------------------------------------------- 1 | """The functions used in the program search. 2 | 3 | This library implements the functions that will be used in the program search. 
4 | Having a shared library of functions will make it easier to share the discovered
5 | algorithms between different tasks. For example, if we want to evaluate one
6 | algorithm on both language and vision tasks.
7 | """
8 | 
9 | from typing import Sequence, Optional
10 | 
11 | import jax
12 | import jax.nn
13 | 
14 | import jax.numpy as jnp
15 | 
16 | import core
17 | 
18 | 
19 | @jax.jit
20 | def interpolate(x, y, weight):
21 |   return core.tree_add(
22 |       core.tree_mult(1.0 - weight, x), core.tree_mult(weight, y))
23 | 
24 | 
25 | @jax.jit
26 | def global_norm(tree):
27 |   leaves = jax.tree_leaves(tree)
28 |   norm = jnp.sqrt(sum([jnp.vdot(x, x) for x in leaves]))
29 |   return norm
30 | 
31 | 
32 | @jax.jit
33 | def tree_dot(tree1, tree2):
34 |   tree_result = jax.tree_map(jnp.vdot, tree1, tree2)
35 |   return sum(jax.tree_leaves(tree_result))
36 | 
37 | 
38 | @jax.jit
39 | def tree_cosine_sim(tree1, tree2):
40 |   tree_result = jax.tree_map(jnp.vdot, tree1, tree2)
41 |   dot_result = sum(jax.tree_leaves(tree_result))
42 |   norm1 = global_norm(tree1)
43 |   norm2 = global_norm(tree2)
44 |   return dot_result / (norm1 * norm2)
45 | 
46 | 
47 | @jax.jit
48 | def clip_by_global_norm(tree, clip_norm):
49 |   l2_g = global_norm(tree)
50 |   g_factor = jnp.minimum(1.0, clip_norm / l2_g)
51 |   return core.tree_mult(g_factor, tree)
52 | 
53 | 
54 | def get_math_fns(allowed_fns: Optional[Sequence[str]] = None):
55 |   """Get the dictionary containing the math functions."""
56 | 
57 |   fn_dict = {}
58 | 
59 |   noarg_fn_dict = dict(
60 |       get_pi=lambda: jnp.pi, get_e=lambda: jnp.e, get_eps=lambda: 1e-8)
61 | 
62 |   for k, v in noarg_fn_dict.items():
63 |     fn_dict[k] = core.Function(v, 0, [])
64 | 
65 |   def nonneg(f):
66 |     def g(x):
67 |       return f(jnp.fabs(x))
68 |     return g
69 | 
70 |   def map_to_tree(f):
71 |     def g(x):
72 |       return jax.tree_map(f, x)
73 |     return g
74 | 
75 |   unary_fn_dict = dict(
76 |       abs=jnp.abs,
77 |       cos=jnp.cos,
78 |       sin=jnp.sin,
79 |       tan=jnp.tan,
80 |       arcsin=jnp.arcsin,
81 |       arccos=jnp.arccos,
82 |       arctan=jnp.arctan,
83 |       sinh=jnp.sinh,
84 |       cosh=jnp.cosh,
85 |       tanh=jnp.tanh,
86 |       arcsinh=jnp.arcsinh,
87 |       arccosh=jnp.arccosh,
88 |       arctanh=jnp.arctanh,
89 |       exp=jnp.exp,
90 |       exp2=jnp.exp2,
91 |       exp10=lambda x: jnp.power(10, x),
92 |       expm1=jnp.expm1,
93 |       log=nonneg(jnp.log),
94 |       log10=nonneg(jnp.log10),
95 |       log2=nonneg(jnp.log2),
96 |       log1p=lambda x: jnp.log(jnp.fabs(1 + x)),
97 |       square=jnp.square,
98 |       sqrt=nonneg(jnp.sqrt),
99 |       cube=lambda x: jnp.power(x, 3),
100 |       cbrt=jnp.cbrt,
101 |       sign=jnp.sign,
102 |       reciprocal=jnp.reciprocal,
103 |       norm=jnp.linalg.norm,
104 |       invert=jnp.invert,
105 |       negative=jnp.negative)
106 | 
107 |   for k, v in unary_fn_dict.items():
108 |     fn_dict[k] = core.Function(map_to_tree(v), 1, [core.is_numeric])
109 | 
110 |   fn_dict['global_norm'] = core.Function(global_norm, 1, [core.is_numeric])
111 | 
112 |   fn_dict['interpolate'] = core.Function(
113 |       interpolate, 3,
114 |       [core.ExampleAnnotation(1.0).check, core.is_numeric, core.is_numeric])
115 | 
116 |   fn_dict['dot'] = core.Function(
117 |       tree_dot, 2, [core.is_numeric, core.is_numeric])
118 | 
119 |   fn_dict['cosine_sim'] = core.Function(
120 |       tree_cosine_sim, 2, [core.is_numeric, core.is_numeric])
121 | 
122 |   fn_dict['clip_by_global_norm'] = core.Function(
123 |       clip_by_global_norm, 2,
124 |       [core.is_numeric, core.ExampleAnnotation(1.0).check])
125 | 
126 |   fn_dict['power'] = core.Function(
127 |       jnp.power, 2, [core.is_numeric, core.is_numeric])
128 | 
129 |   if allowed_fns is not None:
130 |     new_fn_dict = {}
131 |     allowed_fns_set = 
set(allowed_fns) 132 | for key in fn_dict: 133 | if key in allowed_fns_set: 134 | new_fn_dict[key] = fn_dict[key] 135 | else: 136 | new_fn_dict = fn_dict 137 | 138 | return new_fn_dict 139 | 140 | -------------------------------------------------------------------------------- /hero/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Chen Liang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Simply a language model.""" 15 | 16 | import os 17 | import re 18 | from typing import Sequence 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | import config_lib 24 | import data_lib 25 | import model_lib 26 | 27 | _EXPERIMENT_CONFIG = flags.DEFINE_string( 28 | 'experiment_config', 'TransformerLMTest', 'Name of the experiment config.') 29 | 30 | _SHARDING_CONFIG = flags.DEFINE_string( 31 | 'sharding_config', 'GSPMDSharding', 'Name of the sharding config.') 32 | 33 | _EXPERIMENT_DIR = flags.DEFINE_string( 34 | 'experiment_dir', '/tmp/simply_lm/', 'Path to save the experiment data.') 35 | 36 | _MESH_SHAPE = flags.DEFINE_list( 37 | 'mesh_shape', 38 | None, 39 | 'Shape for the mesh, comma separated integers, e.g. 1,265,1', 40 | ) 41 | 42 | _DCN_MESH_SHAPE = flags.DEFINE_list( 43 | 'dcn_mesh_shape', 44 | None, 45 | 'Shape for the dcn mesh, comma separated integers, e.g. 
2,1,1', 46 | ) 47 | 48 | 49 | def main(argv: Sequence[str]) -> None: 50 | del argv 51 | if mesh_shape := _MESH_SHAPE.value: 52 | mesh_shape = [int(i) for i in mesh_shape] 53 | if dcn_mesh_shape := _DCN_MESH_SHAPE.value: 54 | dcn_mesh_shape = [int(i) for i in dcn_mesh_shape] 55 | config = config_lib.ExperimentConfigRegistry.get_config( 56 | _EXPERIMENT_CONFIG.value) 57 | sharding_config = config_lib.ShardingConfigRegistry.get_config( 58 | _SHARDING_CONFIG.value) 59 | logging.info('config: %s', config) 60 | logging.info('sharding_config: %s', sharding_config) 61 | logging.info('mesh_shape: %s', mesh_shape) 62 | logging.info('dcn_mesh_shape: %s', dcn_mesh_shape) 63 | model_lib.run_experiment( 64 | config=config, sharding_config=sharding_config, 65 | mesh_shape=mesh_shape, 66 | dcn_mesh_shape=dcn_mesh_shape, 67 | create_dataset=data_lib.create_dataset, 68 | experiment_dir=_EXPERIMENT_DIR.value) 69 | 70 | 71 | _TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+') 72 | 73 | if __name__ == '__main__': 74 | app.run(main) 75 | -------------------------------------------------------------------------------- /hero/requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | orbax 3 | seqio 4 | einops 5 | absl-py 6 | clu 7 | t5 8 | -------------------------------------------------------------------------------- /hero/vb100864_openmix_v1.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/hero/vb100864_openmix_v1.model -------------------------------------------------------------------------------- /hero/vb32000_t5_cc.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/hero/vb32000_t5_cc.model -------------------------------------------------------------------------------- /lion/fig/ablation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/ablation.png -------------------------------------------------------------------------------- /lion/fig/alg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/alg.png -------------------------------------------------------------------------------- /lion/fig/basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/basic.png -------------------------------------------------------------------------------- /lion/fig/diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/diffusion.png -------------------------------------------------------------------------------- /lion/fig/ft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/ft.png -------------------------------------------------------------------------------- /lion/fig/i1k.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/i1k.png -------------------------------------------------------------------------------- /lion/fig/imagen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/imagen.png -------------------------------------------------------------------------------- /lion/fig/jft-ft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/jft-ft.png -------------------------------------------------------------------------------- /lion/fig/jft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/jft.png -------------------------------------------------------------------------------- /lion/fig/lit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/lit.png -------------------------------------------------------------------------------- /lion/fig/llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/llm.png -------------------------------------------------------------------------------- /lion/fig/lm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/lm.png -------------------------------------------------------------------------------- /lion/fig/retrieval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/automl/6a54c8741e7c3265d4547c4f35f47a0391122dc5/lion/fig/retrieval.png -------------------------------------------------------------------------------- /lion/lion_optax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Optax implementation of the Lion optimizer.""" 16 | 17 | from typing import Any, Callable, NamedTuple, Optional, Union 18 | 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | import optax 23 | 24 | 25 | def _scale_by_learning_rate( 26 | learning_rate: optax.ScalarOrSchedule, flip_sign=True): 27 | m = -1 if flip_sign else 1 28 | if callable(learning_rate): 29 | return optax.scale_by_schedule(lambda count: m * learning_rate(count)) 30 | return optax.scale(m * learning_rate) 31 | 32 | 33 | def lion( 34 | learning_rate: optax.ScalarOrSchedule, 35 | b1: float = 0.9, 36 | b2: float = 0.99, 37 | mu_dtype: Optional[Any] = None, 38 | weight_decay: float = 0.0, 39 | mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, 40 | ) -> optax.GradientTransformation: 41 | """Lion. 42 | 43 | Args: 44 | learning_rate: A fixed global scaling factor. 45 | b1: Exponential decay rate to combine the gradient and the moment. 46 | b2: Exponential decay rate to track the moment of past gradients. 47 | mu_dtype: Optional `dtype` to be used for the first order accumulator; if 48 | `None` then the `dtype` is inferred from `params` and `updates`. 49 | weight_decay: Strength of the weight decay regularization. Note that this 50 | weight decay is multiplied with the learning rate. This is consistent 51 | with other frameworks such as PyTorch, but different from 52 | (Loshchilov et al, 2019) where the weight decay is only multiplied with 53 | the "schedule multiplier", but not the base learning rate. 54 | mask: A tree with same structure as (or a prefix of) the params PyTree, 55 | or a Callable that returns such a pytree given the params/updates. 56 | The leaves should be booleans, `True` for leaves/subtrees you want to 57 | apply the weight decay to, and `False` for those you want to skip. Note 58 | that the Adam gradient transformations are applied to all parameters. 59 | 60 | Returns: 61 | The corresponding `GradientTransformation`. 62 | """ 63 | return optax.chain( 64 | scale_by_lion( 65 | b1=b1, b2=b2, mu_dtype=mu_dtype), 66 | optax.add_decayed_weights(weight_decay, mask), 67 | _scale_by_learning_rate(learning_rate), 68 | ) 69 | 70 | 71 | def update_moment(updates, moments, decay, order): 72 | """Compute the exponential moving average of the `order`-th moment.""" 73 | return jax.tree_util.tree_map( 74 | lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) 75 | 76 | 77 | class ScaleByLionState(NamedTuple): 78 | """State for the Lion algorithm.""" 79 | count: chex.Array # shape=(), dtype=jnp.int32. 80 | mu: optax.Updates 81 | 82 | 83 | def scale_by_lion( 84 | b1: float = 0.9, 85 | b2: float = 0.99, 86 | mu_dtype: Optional[Any] = None, 87 | ) -> optax.GradientTransformation: 88 | """Rescale updates according to the Lion algorithm. 89 | 90 | Args: 91 | b1: rate for combining moment and the current grad. 92 | b2: decay rate for the exponentially weighted average of grads. 93 | mu_dtype: optional `dtype` to be used for the first order accumulator; if 94 | `None` then the `dtype is inferred from `params` and `updates`. 95 | 96 | Returns: 97 | A `GradientTransformation` object. 
98 | """ 99 | 100 | def init_fn(params): 101 | mu = jax.tree_util.tree_map( # moment 102 | lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) 103 | return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=mu) 104 | 105 | def update_fn(updates, state, params=None): 106 | del params 107 | mu = update_moment(updates, state.mu, b2, 1) 108 | mu = jax.tree_map(lambda x: x.astype(mu_dtype), mu) 109 | count_inc = optax.safe_int32_increment(state.count) 110 | updates = jax.tree_util.tree_map( 111 | lambda g, m: jnp.sign((1. - b1) * g + b1 * m), updates, state.mu) 112 | return updates, ScaleByLionState(count=count_inc, mu=mu) 113 | 114 | return optax.GradientTransformation(init_fn, update_fn) 115 | -------------------------------------------------------------------------------- /lion/lion_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """PyTorch implementation of the Lion optimizer.""" 16 | import torch 17 | from torch.optim.optimizer import Optimizer 18 | 19 | 20 | class Lion(Optimizer): 21 | r"""Implements Lion algorithm.""" 22 | 23 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): 24 | """Initialize the hyperparameters. 25 | 26 | Args: 27 | params (iterable): iterable of parameters to optimize or dicts defining 28 | parameter groups 29 | lr (float, optional): learning rate (default: 1e-4) 30 | betas (Tuple[float, float], optional): coefficients used for computing 31 | running averages of gradient and its square (default: (0.9, 0.99)) 32 | weight_decay (float, optional): weight decay coefficient (default: 0) 33 | """ 34 | 35 | if not 0.0 <= lr: 36 | raise ValueError('Invalid learning rate: {}'.format(lr)) 37 | if not 0.0 <= betas[0] < 1.0: 38 | raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) 39 | if not 0.0 <= betas[1] < 1.0: 40 | raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) 41 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 42 | super().__init__(params, defaults) 43 | 44 | @torch.no_grad() 45 | def step(self, closure=None): 46 | """Performs a single optimization step. 47 | 48 | Args: 49 | closure (callable, optional): A closure that reevaluates the model 50 | and returns the loss. 51 | 52 | Returns: 53 | the loss. 
54 | """ 55 | loss = None 56 | if closure is not None: 57 | with torch.enable_grad(): 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | # Perform stepweight decay 66 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 67 | 68 | grad = p.grad 69 | state = self.state[p] 70 | # State initialization 71 | if len(state) == 0: 72 | # Exponential moving average of gradient values 73 | state['exp_avg'] = torch.zeros_like(p) 74 | 75 | exp_avg = state['exp_avg'] 76 | beta1, beta2 = group['betas'] 77 | 78 | # Weight update 79 | update = exp_avg * beta1 + grad * (1 - beta1) 80 | 81 | p.add_(update.sign_(), alpha=-group['lr']) 82 | 83 | # Decay the momentum running average coefficient 84 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 85 | 86 | return loss 87 | -------------------------------------------------------------------------------- /lion/lion_tf1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF1 implementation of the Lion optimizer.""" 16 | from typing import Optional, Union, Callable 17 | 18 | import tensorflow.compat.v1 as tf 19 | from tensorflow.python.ops import resource_variable_ops 20 | 21 | VType = Union[Callable, float, tf.Tensor] 22 | 23 | 24 | class Lion(tf.compat.v1.train.Optimizer): 25 | """Optimizer that implements the discovered algorithm in automl-hero.""" 26 | 27 | def __init__(self, 28 | learning_rate: VType = 0.0001, 29 | beta1: VType = 0.9, 30 | beta2: VType = 0.99, 31 | wd: Optional[VType] = 0.0, 32 | use_locking=False, 33 | name="Lion"): 34 | r"""Construct a new Lion optimizer. 35 | 36 | Args: 37 | learning_rate: A Tensor or a floating point value. The learning rate. 38 | beta1: A float value or a constant float tensor. The rate to combine 39 | the gradient and the moment estimate. 40 | beta2: A float value or a constant float tensor. The exponential decay 41 | rate for the moment estimate. 42 | wd: Optional[A float value or a constant float tensor]. 43 | The decoupled weight decay. 44 | use_locking: If True use locks for update operations. 45 | name: Optional name for the operations created when applying gradients. 46 | """ 47 | super(Lion, self).__init__(use_locking, name) 48 | self._lr = learning_rate 49 | self._beta1 = beta1 50 | self._beta2 = beta2 51 | self._wd = None if isinstance(wd, float) and wd < 0 else wd 52 | 53 | # Tensor versions of the constructor arguments, created in _prepare(). 54 | self._lr_t = None 55 | self._beta1_t = None 56 | self._beta2_t = None 57 | self._wd_t = None 58 | 59 | def _create_slots(self, var_list): 60 | # Create slots for the moment. 
61 | for v in var_list: 62 | self._zeros_slot(v, "m", self._name) 63 | 64 | def _prepare(self): 65 | lr = self._call_if_callable(self._lr) 66 | beta1 = self._call_if_callable(self._beta1) 67 | beta2 = self._call_if_callable(self._beta2) 68 | wd = self._call_if_callable(self._wd) 69 | 70 | self._lr_t = tf.convert_to_tensor(lr, name="learning_rate") 71 | self._beta1_t = tf.convert_to_tensor(beta1, name="beta1") 72 | self._beta2_t = tf.convert_to_tensor(beta2, name="beta2") 73 | if wd is not None: 74 | self._wd_t = tf.convert_to_tensor(wd, name="weight_decay") 75 | 76 | def _apply_dense_shared(self, grad, var): 77 | m = self.get_slot(var, "m") 78 | 79 | lr_t = tf.cast(self._lr_t, dtype=var.dtype) 80 | beta1_t = tf.cast(self._beta1_t, dtype=var.dtype) 81 | beta2_t = tf.cast(self._beta2_t, dtype=var.dtype) 82 | if self._wd_t is None: 83 | weight_decay_t = None 84 | else: 85 | weight_decay_t = tf.cast(self._wd_t, dtype=var.dtype) 86 | 87 | updates_grad = tf.sign(m * beta1_t + grad * (1. - beta1_t)) 88 | if weight_decay_t is not None: 89 | updates_grad = updates_grad + var * weight_decay_t 90 | 91 | var_update = tf.assign_sub( 92 | var, lr_t * updates_grad, use_locking=self._use_locking) 93 | with tf.control_dependencies([var_update]): 94 | m_update = tf.assign(m, m * beta2_t + grad * (1. - beta2_t)) 95 | return tf.group(*[var_update, m_update]) 96 | 97 | def _apply_dense(self, grad, var): 98 | return self._apply_dense_shared(grad, var) 99 | 100 | def _resource_apply_dense(self, grad, var): 101 | return self._apply_dense_shared(grad, var) 102 | 103 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 104 | m = self.get_slot(var, "m") 105 | lr_t = tf.cast(self._lr_t, var.dtype.base_dtype) 106 | beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype) 107 | beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype) 108 | wd_t = tf.cast(self._wd_t, var.dtype.base_dtype) 109 | 110 | m_update = tf.assign(m, m * beta1_t, use_locking=self._use_locking) 111 | with tf.control_dependencies([m_update]): 112 | m_update = scatter_add(m, indices, grad * (1. - beta1_t)) 113 | with tf.control_dependencies([m_update]): 114 | var_update = tf.assign_sub( 115 | var, 116 | lr_t * (tf.sign(m) + var * wd_t), 117 | use_locking=self._use_locking) 118 | with tf.control_dependencies([var_update]): 119 | m_update = scatter_add(m, indices, grad * (beta1_t - 1.)) 120 | with tf.control_dependencies([m_update]): 121 | m_update = tf.assign( 122 | m, m * beta2_t / beta1_t, use_locking=self._use_locking) 123 | with tf.control_dependencies([m_update]): 124 | m_update = scatter_add(m, indices, grad * (1. - beta2_t)) 125 | return tf.group(*[var_update, m_update]) 126 | 127 | def _apply_sparse(self, grad, var): 128 | return self._apply_sparse_shared( 129 | grad.values, 130 | var, 131 | grad.indices, 132 | lambda x, i, v: tf.scatter_add( 133 | x, 134 | i, 135 | v, 136 | use_locking=self._use_locking)) 137 | 138 | def _resource_scatter_add(self, x, i, v): 139 | with tf.control_dependencies( 140 | [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 141 | return x.value() 142 | 143 | def _resource_apply_sparse(self, grad, var, indices): 144 | return self._apply_sparse_shared(grad, var, indices, 145 | self._resource_scatter_add) 146 | -------------------------------------------------------------------------------- /lion/lion_tf2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF2 implementation of the Lion optimizer."""
16 | 
17 | import tensorflow.compat.v2 as tf
18 | 
19 | 
20 | class Lion(tf.keras.optimizers.legacy.Optimizer):
21 |   r"""Optimizer that implements the Lion algorithm."""
22 | 
23 |   def __init__(self,
24 |                learning_rate=0.0001,
25 |                beta_1=0.9,
26 |                beta_2=0.99,
27 |                wd=0,
28 |                name='lion',
29 |                **kwargs):
30 |     """Construct a new Lion optimizer."""
31 | 
32 |     super(Lion, self).__init__(name, **kwargs)
33 |     self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
34 |     self._set_hyper('beta_1', beta_1)
35 |     self._set_hyper('beta_2', beta_2)
36 |     self._set_hyper('wd', wd)
37 | 
38 |   def _create_slots(self, var_list):
39 |     # Create a slot for the momentum: Lion tracks a single moving average
40 |     # of the gradients, unlike Adam, which keeps first and second moments.
41 |     for var in var_list:
42 |       self.add_slot(var, 'm')
43 | 
44 |   def _prepare_local(self, var_device, var_dtype, apply_state):
45 |     super(Lion, self)._prepare_local(var_device, var_dtype, apply_state)
46 | 
47 |     beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
48 |     beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
49 |     wd_t = tf.identity(self._get_hyper('wd', var_dtype))
50 |     lr = apply_state[(var_device, var_dtype)]['lr_t']
51 |     apply_state[(var_device, var_dtype)].update(
52 |         dict(
53 |             lr=lr,
54 |             beta_1_t=beta_1_t,
55 |             one_minus_beta_1_t=1 - beta_1_t,
56 |             beta_2_t=beta_2_t,
57 |             one_minus_beta_2_t=1 - beta_2_t,
58 |             wd_t=wd_t))
59 | 
60 |   @tf.function(jit_compile=True)
61 |   def _resource_apply_dense(self, grad, var, apply_state=None):
62 |     var_device, var_dtype = var.device, var.dtype.base_dtype
63 |     coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
64 |                     self._fallback_apply_state(var_device, var_dtype))
65 | 
66 |     m = self.get_slot(var, 'm')
67 |     var_t = var.assign_sub(
68 |         coefficients['lr_t'] *
69 |         (tf.math.sign(m * coefficients['beta_1_t'] +
70 |                       grad * coefficients['one_minus_beta_1_t']) +
71 |          var * coefficients['wd_t']))
72 |     with tf.control_dependencies([var_t]):
73 |       m.assign(m * coefficients['beta_2_t'] +
74 |                grad * coefficients['one_minus_beta_2_t'])
75 | 
76 |   @tf.function(jit_compile=True)
77 |   def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
78 |     var_device, var_dtype = var.device, var.dtype.base_dtype
79 |     coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
80 |                     self._fallback_apply_state(var_device, var_dtype))
81 | 
82 |     m = self.get_slot(var, 'm')
83 |     m_t = m.assign(m * coefficients['beta_1_t'])
84 |     m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
85 |     m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
86 |     var_t = var.assign_sub(coefficients['lr'] *
87 |                            (tf.math.sign(m_t) + var * coefficients['wd_t']))
88 | 
89 |     with tf.control_dependencies([var_t]):
90 |       m_t = 
m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices)) 91 | m_t = m_t.assign(m_t * coefficients['beta_2_t'] / 92 | coefficients['beta_1_t']) 93 | m_scaled_g_values = grad * coefficients['one_minus_beta_2_t'] 94 | m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices)) 95 | 96 | def get_config(self): 97 | config = super(Lion, self).get_config() 98 | config.update({ 99 | 'learning_rate': self._serialize_hyperparameter('learning_rate'), 100 | 'beta_1': self._serialize_hyperparameter('beta_1'), 101 | 'beta_2': self._serialize_hyperparameter('beta_2'), 102 | 'wd': self._serialize_hyperparameter('wd'), 103 | }) 104 | return config 105 | --------------------------------------------------------------------------------
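Usage note: the three lion_* files above expose the same update rule through Optax, PyTorch, and Keras interfaces. Below is a minimal sketch of one way to drive the Optax variant from lion_optax.py on a toy parameter tree; the parameter tree, loss function, hyperparameter values, and import path are assumptions made for illustration, not code from the repository.

import jax
import jax.numpy as jnp
import optax

from lion_optax import lion  # assumes lion_optax.py is importable on the Python path

# Toy parameters and a simple quadratic loss, purely for demonstration.
params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}

def loss_fn(p):
  return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

# Illustrative hyperparameters; tune learning rate and weight decay per task.
optimizer = lion(learning_rate=1e-4, b1=0.9, b2=0.99, weight_decay=0.1)
opt_state = optimizer.init(params)

grads = jax.grad(loss_fn)(params)
# `params` is passed to `update` so the decoupled weight decay can be applied.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

The PyTorch class (Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0)) and the TF2 class (Lion(learning_rate=1e-4)) plug into their frameworks' standard optimizer loops in the same way.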