├── .gitignore
├── efficientdet
│   ├── horovod_estimator
│   │   ├── horovod
│   │   ├── utis.py
│   │   ├── __init__.py
│   │   └── estimator.py
│   ├── g3doc
│   │   ├── flops.png
│   │   ├── params.png
│   │   ├── network.png
│   │   ├── efficientdet-d0_train.png
│   │   └── efficientdet-d0_val.png
│   ├── testdata
│   │   ├── img1.jpg
│   │   └── img1-d1.jpg
│   ├── scripts
│   │   └── train.sh
│   ├── __init__.py
│   ├── aug
│   │   ├── __init__.py
│   │   └── autoaugment_test.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── backbone_factory.py
│   │   ├── efficientnet_lite_builder_test.py
│   │   ├── efficientnet_builder_test.py
│   │   ├── efficientnet_lite_builder.py
│   │   ├── efficientnet_model_test.py
│   │   └── efficientnet_builder.py
│   ├── visualize
│   │   ├── __init__.py
│   │   ├── static_shape.py
│   │   └── standard_fields.py
│   ├── object_detection
│   │   ├── __init__.py
│   │   ├── shape_utils.py
│   │   ├── faster_rcnn_box_coder.py
│   │   ├── region_similarity_calculator.py
│   │   ├── box_coder.py
│   │   ├── box_list.py
│   │   ├── tf_example_decoder.py
│   │   ├── argmax_matcher.py
│   │   ├── matcher.py
│   │   └── target_assigner.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── tfrecord_util.py
│   │   ├── create_pascal_tfrecord_test.py
│   │   ├── label_map_util.py
│   │   ├── create_coco_tfrecord_test.py
│   │   └── create_pascal_tfrecord.py
│   ├── coco_metric.py
│   ├── hparams_config.py
│   └── normalization_v2.py
├── README.md
├── CONTRIBUTING.md
└── LICENSE

/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 | *.pyc
3 | .idea
4 | 
--------------------------------------------------------------------------------
/efficientdet/horovod_estimator/horovod:
--------------------------------------------------------------------------------
1 | /Users/liupeng/code/horovod
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Brain AutoML
2 | 
3 | This repository contains a list of AutoML related models and libraries.
4 | -------------------------------------------------------------------------------- /efficientdet/g3doc/flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnslarcher/automl/master/efficientdet/g3doc/flops.png -------------------------------------------------------------------------------- /efficientdet/g3doc/params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnslarcher/automl/master/efficientdet/g3doc/params.png -------------------------------------------------------------------------------- /efficientdet/g3doc/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnslarcher/automl/master/efficientdet/g3doc/network.png -------------------------------------------------------------------------------- /efficientdet/testdata/img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnslarcher/automl/master/efficientdet/testdata/img1.jpg -------------------------------------------------------------------------------- /efficientdet/testdata/img1-d1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnslarcher/automl/master/efficientdet/testdata/img1-d1.jpg -------------------------------------------------------------------------------- /efficientdet/g3doc/efficientdet-d0_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnslarcher/automl/master/efficientdet/g3doc/efficientdet-d0_train.png -------------------------------------------------------------------------------- /efficientdet/g3doc/efficientdet-d0_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mnslarcher/automl/master/efficientdet/g3doc/efficientdet-d0_val.png -------------------------------------------------------------------------------- /efficientdet/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | num_gpus=$1 3 | 4 | #export NCCL_P2P_DISABLE=1 5 | export PYTHONPATH=`pwd`:$PYTHONPATH 6 | 7 | mpirun -np $num_gpus -H localhost:$num_gpus \ 8 | --allow-run-as-root -bind-to none -map-by slot -x NCCL_DEBUG=INFO \ 9 | -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib \ 10 | --mca btl_smcuda_use_cuda_ipc 0 \ 11 | python main.py --training_file_pattern=/share/liupeng/data/cv/coco/coco_tfrecord/train* \ 12 | --model_name=$MODEL \ 13 | --model_dir=/tmp/$MODEL \ 14 | --hparams="use_bfloat16=false" \ 15 | --use_tpu=False \ 16 | --train_batch_size 4 17 | -------------------------------------------------------------------------------- /efficientdet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /efficientdet/aug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /efficientdet/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /efficientdet/visualize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | # Visualization library is mostly based on TensorFlow object detection API:
16 | # https://github.com/tensorflow/models/tree/master/research/object_detection
17 | 
--------------------------------------------------------------------------------
/efficientdet/object_detection/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # Object detection data loaders and libraries are mostly based on RetinaNet:
16 | # https://github.com/tensorflow/tpu/tree/master/models/official/retinanet
17 | 
--------------------------------------------------------------------------------
/efficientdet/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # This library is mostly based on tensorflow object detection API
16 | # https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_coco_tf_record.py
17 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
25 | ## Community Guidelines
26 | 
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
29 | 
--------------------------------------------------------------------------------
/efficientdet/dataset/README.md:
--------------------------------------------------------------------------------
1 | This folder provides tools for converting raw COCO/Pascal data to TFRecord.
2 | 
3 | ### 1. Convert COCO validation set to tfrecord:
4 | 
5 |     # Download coco data.
6 |     !wget http://images.cocodataset.org/zips/val2017.zip
7 |     !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
8 |     !unzip val2017.zip
9 |     !unzip annotations_trainval2017.zip
10 | 
11 |     # convert coco data to tfrecord.
12 |     !mkdir tfrecord
13 |     !PYTHONPATH=".:$PYTHONPATH" python dataset/create_coco_tfrecord.py \
14 |       --image_dir=val2017 \
15 |       --caption_annotations_file=annotations/captions_val2017.json \
16 |       --output_file_prefix=tfrecord/val \
17 |       --num_shards=32
18 | 
19 | ### 2. Convert Pascal VOC 2012 to tfrecord:
20 | 
21 |     # Download and convert pascal data.
22 |     !wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
23 |     !tar xf VOCtrainval_11-May-2012.tar
24 |     !mkdir tfrecord
25 |     !PYTHONPATH=".:$PYTHONPATH" python dataset/create_pascal_tfrecord.py \
26 |       --data_dir=VOCdevkit --year=VOC2012 --output_path=tfrecord/pascal
27 | 
28 | Attention: source_id (or image_id) needs to be an integer due to the official COCO library requirements.
--------------------------------------------------------------------------------
/efficientdet/backbone/backbone_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Backbone network factory."""
16 | 
17 | from backbone import efficientnet_builder
18 | from backbone import efficientnet_lite_builder
19 | 
20 | 
21 | def get_model_builder(model_name):
22 |   """Get the model_builder module for a given model name."""
23 |   if model_name.startswith('efficientnet-lite'):
24 |     return efficientnet_lite_builder
25 |   elif model_name.startswith('efficientnet-b'):
26 |     return efficientnet_builder
27 |   else:
28 |     raise ValueError('Unknown model name {}'.format(model_name))
29 | 
--------------------------------------------------------------------------------
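For orientation, here is a minimal sketch of how this factory is typically used. The `build_model_base` call mirrors the pattern exercised in `efficientnet_builder_test.py` further down in this dump; the endpoint names come from that test's comments, so treat any detail beyond the test as an assumption rather than documented API.

```python
import tensorflow.compat.v1 as tf

from backbone import backbone_factory

tf.disable_v2_behavior()

# Resolve the builder module for a model name; unknown names raise ValueError.
builder = backbone_factory.get_model_builder('efficientnet-b0')

# Build the base network (same call pattern as efficientnet_builder_test.py)
# and inspect the multi-scale feature endpoints (reduction_1 .. reduction_5).
images = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
_, endpoints = builder.build_model_base(
    images, model_name='efficientnet-b0', training=False)
print(sorted(endpoints.keys()))
```

/efficientdet/aug/autoaugment_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.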
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for Autoaugment.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | 23 | from aug import autoaugment 24 | 25 | 26 | class AutoaugmentTest(tf.test.TestCase): 27 | 28 | def test_autoaugment_policy(self): 29 | # A very simple test to verify no syntax error. 30 | image = tf.placeholder(tf.uint8, shape=[640, 640, 3]) 31 | bboxes = tf.placeholder(tf.float32, shape=[4, 4]) 32 | autoaugment.distort_image_with_autoaugment(image, bboxes, 'test') 33 | 34 | 35 | if __name__ == '__main__': 36 | tf.disable_v2_behavior() 37 | tf.test.main() 38 | 39 | -------------------------------------------------------------------------------- /efficientdet/horovod_estimator/utis.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | try: 4 | import horovod.tensorflow as hvd 5 | except ImportError: 6 | hvd = None 7 | import tensorflow.compat.v1 as tf 8 | 9 | def is_rank0(): 10 | if hvd is not None: 11 | return hvd.rank() == 0 12 | else: 13 | return True 14 | 15 | 16 | global IS_HVD_INIT 17 | IS_HVD_INIT = False 18 | 19 | 20 | def hvd_try_init(): 21 | global IS_HVD_INIT 22 | if not IS_HVD_INIT and hvd is not None: 23 | hvd.init() 24 | IS_HVD_INIT = True 25 | 26 | tf.get_logger().propagate = False 27 | if hvd.rank() == 0: 28 | tf.logging.set_verbosity('INFO') 29 | else: 30 | tf.logging.set_verbosity('WARN') 31 | 32 | 33 | def hvd_info(msg): 34 | hvd_try_init() 35 | if hvd is not None: 36 | head = 'hvd rank{}/{} in {}'.format(hvd.rank(), hvd.size(), socket.gethostname()) 37 | else: 38 | head = '{}'.format(socket.gethostname()) 39 | tf.logging.info('{}: {}'.format(head, msg)) 40 | 41 | 42 | def hvd_info_rank0(msg, with_head=True): 43 | hvd_try_init() 44 | if is_rank0(): 45 | if with_head: 46 | if hvd is not None: 47 | head = 'hvd only rank{}/{} in {}'.format(hvd.rank(), hvd.size(), socket.gethostname()) 48 | else: 49 | head = '{}'.format(socket.gethostname()) 50 | tf.logging.info('{}: {}'.format(head, msg)) 51 | else: 52 | tf.logging.info(msg) 53 | 54 | -------------------------------------------------------------------------------- /efficientdet/horovod_estimator/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import horovod.tensorflow as hvd 3 | except ImportError: 4 | hvd = None 5 | 6 | import glob 7 | import os 8 | from multiprocessing import Pool, cpu_count 9 | 10 | import numpy as np 11 | import tensorflow.compat.v1 as tf 12 | 13 | from .estimator import HorovodEstimator 14 | from .utis import hvd_info, hvd_info_rank0, hvd_try_init 15 | 16 | 17 | def _count_per_file(filename): 18 | c = 0 19 | for _ in tf.python_io.tf_record_iterator(filename): 20 | c += 1 21 | 22 | return c 23 | 24 | 25 | def 
get_record_num(filenames): 26 | pool = Pool(cpu_count()) 27 | c_list = pool.map(_count_per_file, filenames) 28 | total_count = np.sum(np.array(c_list)) 29 | return total_count 30 | 31 | 32 | def get_filenames(data_dir: str, filename_regexp: str, show_result=True): 33 | filenames = glob.glob(os.path.join(data_dir, filename_regexp)) 34 | if show_result: 35 | hvd_info_rank0('find {} files in {}, such as {}'.format(len(filenames), data_dir, filenames[0:5])) 36 | return filenames 37 | 38 | 39 | def _idx_a_minus_b(a, b): 40 | a_splits = a.split('/') 41 | b_splits = b.split('/') 42 | for i in range(min(len(a_splits), len(b_splits))): 43 | if a_splits[i] != b_splits[i]: 44 | break 45 | 46 | return len('/'.join(a_splits[0:i])) 47 | 48 | 49 | def show_model(): 50 | prev = None 51 | for var in tf.global_variables(): 52 | # if var.name.split('/')[-1] in ['beta:0', 'moving_mean:0', 'moving_variance:0']: 53 | # continue 54 | 55 | if prev is None: 56 | print('{} - {}'.format(var.name, var.shape.as_list())) 57 | else: 58 | idx = _idx_a_minus_b(var.name, prev.name) 59 | short_name = var.name[idx:] 60 | if short_name.startswith('/bn'): 61 | short_name = '/bn' 62 | print('{}{} - {}'.format(' ' * idx, short_name, var.shape.as_list())) 63 | prev = var 64 | -------------------------------------------------------------------------------- /efficientdet/object_detection/shape_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utils used to manipulate tensor shapes.""" 16 | 17 | import tensorflow.compat.v1 as tf 18 | 19 | 20 | def assert_shape_equal(shape_a, shape_b): 21 | """Asserts that shape_a and shape_b are equal. 22 | 23 | If the shapes are static, raises a ValueError when the shapes 24 | mismatch. 25 | 26 | If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes 27 | mismatch. 28 | 29 | Args: 30 | shape_a: a list containing shape of the first tensor. 31 | shape_b: a list containing shape of the second tensor. 32 | 33 | Returns: 34 | Either a tf.no_op() when shapes are all static and a tf.assert_equal() op 35 | when the shapes are dynamic. 36 | 37 | Raises: 38 | ValueError: When shapes are both static and unequal. 39 | """ 40 | if (all(isinstance(dim, int) for dim in shape_a) and 41 | all(isinstance(dim, int) for dim in shape_b)): 42 | if shape_a != shape_b: 43 | raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b)) 44 | else: return tf.no_op() 45 | else: 46 | return tf.assert_equal(shape_a, shape_b) 47 | 48 | 49 | def combined_static_and_dynamic_shape(tensor): 50 | """Returns a list containing static and dynamic values for the dimensions. 51 | 52 | Returns a list of static and dynamic values for shape dimensions. This is 53 | useful to preserve static shapes when available in reshape operation. 
54 | 
55 |   Args:
56 |     tensor: A tensor of any type.
57 | 
58 |   Returns:
59 |     A list of size tensor.shape.ndims containing integers or a scalar tensor.
60 |   """
61 |   static_tensor_shape = tensor.shape.as_list()
62 |   dynamic_tensor_shape = tf.shape(tensor)
63 |   combined_shape = []
64 |   for index, dim in enumerate(static_tensor_shape):
65 |     if dim is not None:
66 |       combined_shape.append(dim)
67 |     else:
68 |       combined_shape.append(dynamic_tensor_shape[index])
69 |   return combined_shape
70 | 
--------------------------------------------------------------------------------
/efficientdet/visualize/static_shape.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Helper functions to access TensorShape values.
17 | 
18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth].
19 | """
20 | 
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 | 
25 | 
26 | def get_dim_as_int(dim):
27 |   """Utility to get v1 or v2 TensorShape dim as an int.
28 | 
29 |   Args:
30 |     dim: The TensorShape dimension to get as an int
31 | 
32 |   Returns:
33 |     None or an int.
34 |   """
35 |   try:
36 |     return dim.value
37 |   except AttributeError:
38 |     return dim
39 | 
40 | 
41 | def get_batch_size(tensor_shape):
42 |   """Returns batch size from the tensor shape.
43 | 
44 |   Args:
45 |     tensor_shape: A rank 4 TensorShape.
46 | 
47 |   Returns:
48 |     An integer representing the batch size of the tensor.
49 |   """
50 |   tensor_shape.assert_has_rank(rank=4)
51 |   return get_dim_as_int(tensor_shape[0])
52 | 
53 | 
54 | def get_height(tensor_shape):
55 |   """Returns height from the tensor shape.
56 | 
57 |   Args:
58 |     tensor_shape: A rank 4 TensorShape.
59 | 
60 |   Returns:
61 |     An integer representing the height of the tensor.
62 |   """
63 |   tensor_shape.assert_has_rank(rank=4)
64 |   return get_dim_as_int(tensor_shape[1])
65 | 
66 | 
67 | def get_width(tensor_shape):
68 |   """Returns width from the tensor shape.
69 | 
70 |   Args:
71 |     tensor_shape: A rank 4 TensorShape.
72 | 
73 |   Returns:
74 |     An integer representing the width of the tensor.
75 |   """
76 |   tensor_shape.assert_has_rank(rank=4)
77 |   return get_dim_as_int(tensor_shape[2])
78 | 
79 | 
80 | def get_depth(tensor_shape):
81 |   """Returns depth from the tensor shape.
82 | 
83 |   Args:
84 |     tensor_shape: A rank 4 TensorShape.
85 | 
86 |   Returns:
87 |     An integer representing the depth of the tensor.
88 |   """
89 |   tensor_shape.assert_has_rank(rank=4)
90 |   return get_dim_as_int(tensor_shape[3])
91 | 
--------------------------------------------------------------------------------
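The `static_shape` getters above and `shape_utils.combined_static_and_dynamic_shape` cover the common case of reshaping a tensor whose batch dimension is only known at run time. A small illustrative sketch (the placeholder shape is made up for the example):

```python
import tensorflow.compat.v1 as tf

from object_detection import shape_utils
from visualize import static_shape

tf.disable_v2_behavior()

# A batch of images whose batch size is unknown at graph construction time.
images = tf.placeholder(tf.float32, shape=[None, 480, 640, 3])

# Static dims stay Python ints; the unknown batch dim becomes a scalar tensor,
# so the reshape below preserves as much static shape information as possible.
n, h, w, c = shape_utils.combined_static_and_dynamic_shape(images)
flattened = tf.reshape(images, [n, h * w, c])  # static shape: [?, 307200, 3]

# The static_shape getters read fixed dimensions from a rank-4 TensorShape.
assert static_shape.get_height(images.shape) == 480
assert static_shape.get_width(images.shape) == 640
```

/efficientdet/backbone/efficientnet_lite_builder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research.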
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for efficientnet_lite_builder.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow.compat.v1 as tf 23 | 24 | from backbone import efficientnet_lite_builder 25 | 26 | 27 | class EfficientnetBuilderTest(tf.test.TestCase): 28 | 29 | def _test_model_params(self, 30 | model_name, 31 | input_size, 32 | expected_params, 33 | override_params=None, 34 | features_only=False, 35 | pooled_features_only=False): 36 | images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32) 37 | efficientnet_lite_builder.build_model( 38 | images, 39 | model_name=model_name, 40 | override_params=override_params, 41 | training=True, 42 | features_only=features_only, 43 | pooled_features_only=pooled_features_only) 44 | num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 45 | 46 | self.assertEqual(num_params, expected_params) 47 | 48 | def test_efficientnet_b0(self): 49 | self._test_model_params( 50 | 'efficientnet-lite0', 224, expected_params=4652008) 51 | 52 | def test_efficientnet_b1(self): 53 | self._test_model_params( 54 | 'efficientnet-lite1', 240, expected_params=5416680) 55 | 56 | def test_efficientnet_b2(self): 57 | self._test_model_params( 58 | 'efficientnet-lite2', 260, expected_params=6092072) 59 | 60 | def test_efficientnet_b3(self): 61 | self._test_model_params( 62 | 'efficientnet-lite3', 280, expected_params=8197096) 63 | 64 | def test_efficientnet_b4(self): 65 | self._test_model_params( 66 | 'efficientnet-lite4', 300, expected_params=13006568) 67 | 68 | 69 | if __name__ == '__main__': 70 | tf.disable_v2_behavior() 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- /efficientdet/dataset/tfrecord_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | r"""TFRecord related utilities."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | 
20 | from six.moves import range
21 | import tensorflow.compat.v1 as tf
22 | 
23 | 
24 | def int64_feature(value):
25 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
26 | 
27 | 
28 | def int64_list_feature(value):
29 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
30 | 
31 | 
32 | def bytes_feature(value):
33 |   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
34 | 
35 | 
36 | def bytes_list_feature(value):
37 |   return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
38 | 
39 | 
40 | def float_list_feature(value):
41 |   return tf.train.Feature(float_list=tf.train.FloatList(value=value))
42 | 
43 | 
44 | def read_examples_list(path):
45 |   """Read list of training or validation examples.
46 | 
47 |   The file is assumed to contain a single example per line where the first
48 |   token in the line is an identifier that allows us to find the image and
49 |   annotation xml for that example.
50 | 
51 |   For example, the line:
52 |   xyz 3
53 |   would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored).
54 | 
55 |   Args:
56 |     path: absolute path to examples list file.
57 | 
58 |   Returns:
59 |     list of example identifiers (strings).
60 |   """
61 |   with tf.gfile.GFile(path) as fid:
62 |     lines = fid.readlines()
63 |   return [line.strip().split(' ')[0] for line in lines]
64 | 
65 | 
66 | def recursive_parse_xml_to_dict(xml):
67 |   """Recursively parses XML contents to python dict.
68 | 
69 |   We assume that `object` tags are the only ones that can appear
70 |   multiple times at the same level of a tree.
71 | 
72 |   Args:
73 |     xml: xml tree obtained by parsing XML file contents using lxml.etree
74 | 
75 |   Returns:
76 |     Python dictionary holding XML contents.
77 |   """
78 |   if not xml:
79 |     return {xml.tag: xml.text}
80 |   result = {}
81 |   for child in xml:
82 |     child_result = recursive_parse_xml_to_dict(child)
83 |     if child.tag != 'object':
84 |       result[child.tag] = child_result[child.tag]
85 |     else:
86 |       if child.tag not in result:
87 |         result[child.tag] = []
88 |       result[child.tag].append(child_result[child.tag])
89 |   return {xml.tag: result}
90 | 
91 | 
92 | def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
93 |   """Opens all TFRecord shards for writing and adds them to an exit stack.
94 | 
95 |   Args:
96 |     exit_stack: A contextlib2.ExitStack used to automatically close the
97 |       TFRecords opened in this function.
98 |     base_path: The base path for all shards
99 |     num_shards: The number of shards
100 | 
101 |   Returns:
102 |     The list of opened TFRecords. Position k in the list corresponds to shard k.
103 |   """
104 |   tf_record_output_filenames = [
105 |       '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards)
106 |       for idx in range(num_shards)
107 |   ]
108 | 
109 |   tfrecords = [
110 |       exit_stack.enter_context(tf.python_io.TFRecordWriter(file_name))
111 |       for file_name in tf_record_output_filenames
112 |   ]
113 | 
114 |   return tfrecords
115 | 
--------------------------------------------------------------------------------
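A quick sketch of how these feature helpers compose into a `tf.train.Example`. The `image/*` keys follow the convention used by `create_pascal_tfrecord_test.py` later in this dump; the output path and dummy values are made up:

```python
import tensorflow.compat.v1 as tf

from dataset import tfrecord_util

# Compose a tf.train.Example from the helpers above; values are dummies.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': tfrecord_util.int64_feature(256),
    'image/width': tfrecord_util.int64_feature(256),
    'image/encoded': tfrecord_util.bytes_feature(b'...jpeg bytes...'),
    'image/object/bbox/xmin': tfrecord_util.float_list_feature([0.25]),
    'image/object/class/text': tfrecord_util.bytes_list_feature([b'person']),
}))

# Serialize it into a single-record TFRecord file.
with tf.python_io.TFRecordWriter('/tmp/sample.tfrecord') as writer:
    writer.write(example.SerializeToString())
```

/efficientdet/object_detection/faster_rcnn_box_coder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.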
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Faster RCNN box coder. 16 | 17 | Faster RCNN box coder follows the coding schema described below: 18 | ty = (y - ya) / ha 19 | tx = (x - xa) / wa 20 | th = log(h / ha) 21 | tw = log(w / wa) 22 | where x, y, w, h denote the box's center coordinates, width and height 23 | respectively. Similarly, xa, ya, wa, ha denote the anchor's center 24 | coordinates, width and height. tx, ty, tw and th denote the anchor-encoded 25 | center, width and height respectively. 26 | 27 | See http://arxiv.org/abs/1506.01497 for details. 28 | """ 29 | 30 | import tensorflow.compat.v1 as tf 31 | 32 | from object_detection import box_coder 33 | from object_detection import box_list 34 | 35 | EPSILON = 1e-8 36 | 37 | 38 | class FasterRcnnBoxCoder(box_coder.BoxCoder): 39 | """Faster RCNN box coder.""" 40 | 41 | def __init__(self, scale_factors=None): 42 | """Constructor for FasterRcnnBoxCoder. 43 | 44 | Args: 45 | scale_factors: List of 4 positive scalars to scale ty, tx, th and tw. 46 | If set to None, does not perform scaling. For Faster RCNN, 47 | the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0]. 48 | """ 49 | if scale_factors: 50 | assert len(scale_factors) == 4 51 | for scalar in scale_factors: 52 | assert scalar > 0 53 | self._scale_factors = scale_factors 54 | 55 | @property 56 | def code_size(self): 57 | return 4 58 | 59 | def _encode(self, boxes, anchors): 60 | """Encode a box collection with respect to anchor collection. 61 | 62 | Args: 63 | boxes: BoxList holding N boxes to be encoded. 64 | anchors: BoxList of anchors. 65 | 66 | Returns: 67 | a tensor representing N anchor-encoded boxes of the format 68 | [ty, tx, th, tw]. 69 | """ 70 | # Convert anchors to the center coordinate representation. 71 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes() 72 | ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes() 73 | # Avoid NaN in division and log below. 74 | ha += EPSILON 75 | wa += EPSILON 76 | h += EPSILON 77 | w += EPSILON 78 | 79 | tx = (xcenter - xcenter_a) / wa 80 | ty = (ycenter - ycenter_a) / ha 81 | tw = tf.log(w / wa) 82 | th = tf.log(h / ha) 83 | # Scales location targets as used in paper for joint training. 84 | if self._scale_factors: 85 | ty *= self._scale_factors[0] 86 | tx *= self._scale_factors[1] 87 | th *= self._scale_factors[2] 88 | tw *= self._scale_factors[3] 89 | return tf.transpose(tf.stack([ty, tx, th, tw])) 90 | 91 | def _decode(self, rel_codes, anchors): 92 | """Decode relative codes to boxes. 93 | 94 | Args: 95 | rel_codes: a tensor representing N anchor-encoded boxes. 96 | anchors: BoxList of anchors. 97 | 98 | Returns: 99 | boxes: BoxList holding N bounding boxes. 
100 | """ 101 | ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes() 102 | 103 | ty, tx, th, tw = tf.unstack(tf.transpose(rel_codes)) 104 | if self._scale_factors: 105 | ty /= self._scale_factors[0] 106 | tx /= self._scale_factors[1] 107 | th /= self._scale_factors[2] 108 | tw /= self._scale_factors[3] 109 | w = tf.exp(tw) * wa 110 | h = tf.exp(th) * ha 111 | ycenter = ty * ha + ycenter_a 112 | xcenter = tx * wa + xcenter_a 113 | ymin = ycenter - h / 2. 114 | xmin = xcenter - w / 2. 115 | ymax = ycenter + h / 2. 116 | xmax = xcenter + w / 2. 117 | return box_list.BoxList(tf.transpose(tf.stack([ymin, xmin, ymax, xmax]))) 118 | -------------------------------------------------------------------------------- /efficientdet/dataset/create_pascal_tfrecord_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Test for create_pascal_tfrecord.py.""" 16 | 17 | import os 18 | 19 | import numpy as np 20 | import PIL.Image 21 | import six 22 | import tensorflow.compat.v1 as tf 23 | 24 | from dataset import create_pascal_tfrecord 25 | 26 | 27 | class CreatePascalTFRecordTest(tf.test.TestCase): 28 | 29 | def _assertProtoEqual(self, proto_field, expectation): 30 | """Helper function to assert if a proto field equals some value. 31 | 32 | Args: 33 | proto_field: The protobuf field to compare. 34 | expectation: The expected value of the protobuf field. 
35 | """ 36 | proto_list = [p for p in proto_field] 37 | self.assertListEqual(proto_list, expectation) 38 | 39 | def test_dict_to_tf_example(self): 40 | image_file_name = '2012_12.jpg' 41 | image_data = np.random.rand(256, 256, 3) 42 | save_path = os.path.join(self.get_temp_dir(), image_file_name) 43 | image = PIL.Image.fromarray(image_data, 'RGB') 44 | image.save(save_path) 45 | 46 | data = { 47 | 'folder': '', 48 | 'filename': image_file_name, 49 | 'size': { 50 | 'height': 256, 51 | 'width': 256, 52 | }, 53 | 'object': [ 54 | { 55 | 'difficult': 1, 56 | 'bndbox': { 57 | 'xmin': 64, 58 | 'ymin': 64, 59 | 'xmax': 192, 60 | 'ymax': 192, 61 | }, 62 | 'name': 'person', 63 | 'truncated': 0, 64 | 'pose': '', 65 | }, 66 | ], 67 | } 68 | 69 | label_map_dict = { 70 | 'background': 0, 71 | 'person': 1, 72 | 'notperson': 2, 73 | } 74 | 75 | example = create_pascal_tfrecord.dict_to_tf_example( 76 | data, self.get_temp_dir(), label_map_dict, image_subdirectory='') 77 | self._assertProtoEqual( 78 | example.features.feature['image/height'].int64_list.value, [256]) 79 | self._assertProtoEqual( 80 | example.features.feature['image/width'].int64_list.value, [256]) 81 | self._assertProtoEqual( 82 | example.features.feature['image/filename'].bytes_list.value, 83 | [six.b(image_file_name)]) 84 | self._assertProtoEqual( 85 | example.features.feature['image/source_id'].bytes_list.value, 86 | [six.b(str(1))]) 87 | self._assertProtoEqual( 88 | example.features.feature['image/format'].bytes_list.value, 89 | [six.b('jpeg')]) 90 | self._assertProtoEqual( 91 | example.features.feature['image/object/bbox/xmin'].float_list.value, 92 | [0.25]) 93 | self._assertProtoEqual( 94 | example.features.feature['image/object/bbox/ymin'].float_list.value, 95 | [0.25]) 96 | self._assertProtoEqual( 97 | example.features.feature['image/object/bbox/xmax'].float_list.value, 98 | [0.75]) 99 | self._assertProtoEqual( 100 | example.features.feature['image/object/bbox/ymax'].float_list.value, 101 | [0.75]) 102 | self._assertProtoEqual( 103 | example.features.feature['image/object/class/text'].bytes_list.value, 104 | [six.b('person')]) 105 | self._assertProtoEqual( 106 | example.features.feature['image/object/class/label'].int64_list.value, 107 | [1]) 108 | self._assertProtoEqual( 109 | example.features.feature['image/object/difficult'].int64_list.value, 110 | [1]) 111 | self._assertProtoEqual( 112 | example.features.feature['image/object/truncated'].int64_list.value, 113 | [0]) 114 | self._assertProtoEqual( 115 | example.features.feature['image/object/view'].bytes_list.value, 116 | [six.b('')]) 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.test.main() 121 | -------------------------------------------------------------------------------- /efficientdet/backbone/efficientnet_builder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for efficientnet_builder.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow.compat.v1 as tf 23 | 24 | from backbone import efficientnet_builder 25 | 26 | 27 | class EfficientnetBuilderTest(tf.test.TestCase): 28 | 29 | def _test_model_params(self, 30 | model_name, 31 | input_size, 32 | expected_params, 33 | override_params=None, 34 | features_only=False, 35 | pooled_features_only=False): 36 | images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32) 37 | efficientnet_builder.build_model( 38 | images, 39 | model_name=model_name, 40 | override_params=override_params, 41 | training=True, 42 | features_only=features_only, 43 | pooled_features_only=pooled_features_only) 44 | num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 45 | self.assertEqual(num_params, expected_params) 46 | 47 | def test_efficientnet_b0(self): 48 | self._test_model_params('efficientnet-b0', 224, expected_params=5288548) 49 | 50 | def test_efficientnet_b1(self): 51 | self._test_model_params('efficientnet-b1', 240, expected_params=7794184) 52 | 53 | def test_efficientnet_b2(self): 54 | self._test_model_params('efficientnet-b2', 260, expected_params=9109994) 55 | 56 | def test_efficientnet_b3(self): 57 | self._test_model_params('efficientnet-b3', 300, expected_params=12233232) 58 | 59 | def test_efficientnet_b4(self): 60 | self._test_model_params('efficientnet-b4', 380, expected_params=19341616) 61 | 62 | def test_efficientnet_b5(self): 63 | self._test_model_params('efficientnet-b5', 456, expected_params=30389784) 64 | 65 | def test_efficientnet_b6(self): 66 | self._test_model_params('efficientnet-b6', 528, expected_params=43040704) 67 | 68 | def test_efficientnet_b7(self): 69 | self._test_model_params('efficientnet-b7', 600, expected_params=66347960) 70 | 71 | def test_efficientnet_b0_with_customized_num_classes(self): 72 | self._test_model_params( 73 | 'efficientnet-b0', 74 | 224, 75 | expected_params=4135648, 76 | override_params={'num_classes': 100}) 77 | 78 | def test_efficientnet_b0_with_features_only(self): 79 | self._test_model_params( 80 | 'efficientnet-b0', 224, features_only=True, expected_params=3595388) 81 | 82 | def test_efficientnet_b0_with_pooled_features_only(self): 83 | self._test_model_params( 84 | 'efficientnet-b0', 85 | 224, 86 | pooled_features_only=True, 87 | expected_params=4007548) 88 | 89 | def test_efficientnet_b0_fails_if_both_features_requested(self): 90 | with self.assertRaises(AssertionError): 91 | efficientnet_builder.build_model( 92 | None, 93 | model_name='efficientnet-b0', 94 | training=True, 95 | features_only=True, 96 | pooled_features_only=True) 97 | 98 | def test_efficientnet_b0_base(self): 99 | # Creates a base model using the model configuration. 100 | images = tf.zeros((1, 224, 224, 3), dtype=tf.float32) 101 | _, endpoints = efficientnet_builder.build_model_base( 102 | images, model_name='efficientnet-b0', training=True) 103 | 104 | # reduction_1 to reduction_5 should be in endpoints 105 | self.assertIn('reduction_1', endpoints) 106 | self.assertIn('reduction_5', endpoints) 107 | # reduction_5 should be the last one: no reduction_6. 
108 | self.assertNotIn('reduction_6', endpoints) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.disable_v2_behavior() 113 | tf.test.main() 114 | -------------------------------------------------------------------------------- /efficientdet/object_detection/region_similarity_calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Region Similarity Calculators for BoxLists. 16 | 17 | Region Similarity Calculators compare a pairwise measure of similarity 18 | between the boxes in two BoxLists. 19 | """ 20 | from abc import ABCMeta 21 | from abc import abstractmethod 22 | 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | def area(boxlist, scope=None): 27 | """Computes area of boxes. 28 | 29 | Args: 30 | boxlist: BoxList holding N boxes 31 | scope: name scope. 32 | 33 | Returns: 34 | a tensor with shape [N] representing box areas. 35 | """ 36 | with tf.name_scope(scope, 'Area'): 37 | y_min, x_min, y_max, x_max = tf.split( 38 | value=boxlist.get(), num_or_size_splits=4, axis=1) 39 | return tf.squeeze((y_max - y_min) * (x_max - x_min), [1]) 40 | 41 | 42 | def intersection(boxlist1, boxlist2, scope=None): 43 | """Compute pairwise intersection areas between boxes. 44 | 45 | Args: 46 | boxlist1: BoxList holding N boxes 47 | boxlist2: BoxList holding M boxes 48 | scope: name scope. 49 | 50 | Returns: 51 | a tensor with shape [N, M] representing pairwise intersections 52 | """ 53 | with tf.name_scope(scope, 'Intersection'): 54 | y_min1, x_min1, y_max1, x_max1 = tf.split( 55 | value=boxlist1.get(), num_or_size_splits=4, axis=1) 56 | y_min2, x_min2, y_max2, x_max2 = tf.split( 57 | value=boxlist2.get(), num_or_size_splits=4, axis=1) 58 | all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2)) 59 | all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2)) 60 | intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin) 61 | all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2)) 62 | all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2)) 63 | intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin) 64 | return intersect_heights * intersect_widths 65 | 66 | 67 | def iou(boxlist1, boxlist2, scope=None): 68 | """Computes pairwise intersection-over-union between box collections. 69 | 70 | Args: 71 | boxlist1: BoxList holding N boxes 72 | boxlist2: BoxList holding M boxes 73 | scope: name scope. 74 | 75 | Returns: 76 | a tensor with shape [N, M] representing pairwise iou scores. 
77 | """ 78 | with tf.name_scope(scope, 'IOU'): 79 | intersections = intersection(boxlist1, boxlist2) 80 | areas1 = area(boxlist1) 81 | areas2 = area(boxlist2) 82 | unions = ( 83 | tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections) 84 | return tf.where( 85 | tf.equal(intersections, 0.0), 86 | tf.zeros_like(intersections), tf.truediv(intersections, unions)) 87 | 88 | 89 | class RegionSimilarityCalculator(object): 90 | """Abstract base class for region similarity calculator.""" 91 | __metaclass__ = ABCMeta 92 | 93 | def compare(self, boxlist1, boxlist2, scope=None): 94 | """Computes matrix of pairwise similarity between BoxLists. 95 | 96 | This op (to be overridden) computes a measure of pairwise similarity between 97 | the boxes in the given BoxLists. Higher values indicate more similarity. 98 | 99 | Note that this method simply measures similarity and does not explicitly 100 | perform a matching. 101 | 102 | Args: 103 | boxlist1: BoxList holding N boxes. 104 | boxlist2: BoxList holding M boxes. 105 | scope: Op scope name. Defaults to 'Compare' if None. 106 | 107 | Returns: 108 | a (float32) tensor of shape [N, M] with pairwise similarity score. 109 | """ 110 | with tf.name_scope(scope, 'Compare', [boxlist1, boxlist2]) as scope: 111 | return self._compare(boxlist1, boxlist2) 112 | 113 | @abstractmethod 114 | def _compare(self, boxlist1, boxlist2): 115 | pass 116 | 117 | 118 | class IouSimilarity(RegionSimilarityCalculator): 119 | """Class to compute similarity based on Intersection over Union (IOU) metric. 120 | 121 | This class computes pairwise similarity between two BoxLists based on IOU. 122 | """ 123 | 124 | def _compare(self, boxlist1, boxlist2): 125 | """Compute pairwise IOU similarity between the two BoxLists. 126 | 127 | Args: 128 | boxlist1: BoxList holding N boxes. 129 | boxlist2: BoxList holding M boxes. 130 | 131 | Returns: 132 | A tensor with shape [N, M] representing pairwise iou scores. 133 | """ 134 | return iou(boxlist1, boxlist2) 135 | -------------------------------------------------------------------------------- /efficientdet/dataset/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Label map utility functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import logging 22 | from six.moves import range 23 | 24 | 25 | def _validate_label_map(label_map): 26 | """Checks if a label map is valid. 27 | 28 | Args: 29 | label_map: StringIntLabelMap to validate. 30 | 31 | Raises: 32 | ValueError: if label map is invalid. 
33 | """ 34 | for item in label_map.item: 35 | if item.id < 0: 36 | raise ValueError('Label map ids should be >= 0.') 37 | if (item.id == 0 and item.name != 'background' and 38 | item.display_name != 'background'): 39 | raise ValueError('Label map id 0 is reserved for the background label') 40 | 41 | 42 | def create_category_index(categories): 43 | """Creates dictionary of COCO compatible categories keyed by category id. 44 | 45 | Args: 46 | categories: a list of dicts, each of which has the following keys: 47 | 'id': (required) an integer id uniquely identifying this category. 48 | 'name': (required) string representing category name 49 | e.g., 'cat', 'dog', 'pizza'. 50 | 51 | Returns: 52 | category_index: a dict containing the same entries as categories, but keyed 53 | by the 'id' field of each category. 54 | """ 55 | category_index = {} 56 | for cat in categories: 57 | category_index[cat['id']] = cat 58 | return category_index 59 | 60 | 61 | def get_max_label_map_index(label_map): 62 | """Get maximum index in label map. 63 | 64 | Args: 65 | label_map: a StringIntLabelMapProto 66 | 67 | Returns: 68 | an integer 69 | """ 70 | return max([item.id for item in label_map.item]) 71 | 72 | 73 | def convert_label_map_to_categories(label_map, 74 | max_num_classes, 75 | use_display_name=True): 76 | """Given label map proto returns categories list compatible with eval. 77 | 78 | This function converts label map proto and returns a list of dicts, each of 79 | which has the following keys: 80 | 'id': (required) an integer id uniquely identifying this category. 81 | 'name': (required) string representing category name 82 | e.g., 'cat', 'dog', 'pizza'. 83 | 'keypoints': (optional) a dictionary of keypoint string 'label' to integer 84 | 'id'. 85 | We only allow class into the list if its id-label_id_offset is 86 | between 0 (inclusive) and max_num_classes (exclusive). 87 | If there are several items mapping to the same id in the label map, 88 | we will only keep the first one in the categories list. 89 | 90 | Args: 91 | label_map: a StringIntLabelMapProto or None. If None, a default categories 92 | list is created with max_num_classes categories. 93 | max_num_classes: maximum number of (consecutive) label indices to include. 94 | use_display_name: (boolean) choose whether to load 'display_name' field as 95 | category name. If False or if the display_name field does not exist, uses 96 | 'name' field as category names instead. 97 | 98 | Returns: 99 | categories: a list of dictionaries representing all possible categories. 
100 | """ 101 | categories = [] 102 | list_of_ids_already_added = [] 103 | if not label_map: 104 | label_id_offset = 1 105 | for class_id in range(max_num_classes): 106 | categories.append({ 107 | 'id': class_id + label_id_offset, 108 | 'name': 'category_{}'.format(class_id + label_id_offset) 109 | }) 110 | return categories 111 | for item in label_map.item: 112 | if not 0 < item.id <= max_num_classes: 113 | logging.info( 114 | 'Ignore item %d since it falls outside of requested ' 115 | 'label range.', item.id) 116 | continue 117 | if use_display_name and item.HasField('display_name'): 118 | name = item.display_name 119 | else: 120 | name = item.name 121 | if item.id not in list_of_ids_already_added: 122 | list_of_ids_already_added.append(item.id) 123 | category = {'id': item.id, 'name': name} 124 | if item.keypoints: 125 | keypoints = {} 126 | list_of_keypoint_ids = [] 127 | for kv in item.keypoints: 128 | if kv.id in list_of_keypoint_ids: 129 | raise ValueError('Duplicate keypoint ids are not allowed. ' 130 | 'Found {} more than once'.format(kv.id)) 131 | keypoints[kv.label] = kv.id 132 | list_of_keypoint_ids.append(kv.id) 133 | category['keypoints'] = keypoints 134 | categories.append(category) 135 | return categories 136 | 137 | 138 | def create_class_agnostic_category_index(): 139 | """Creates a category index with a single `object` class.""" 140 | return {1: {'id': 1, 'name': 'object'}} 141 | -------------------------------------------------------------------------------- /efficientdet/object_detection/box_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Base box coder. 16 | 17 | Box coders convert between coordinate frames, namely image-centric 18 | (with (0,0) on the top left of image) and anchor-centric (with (0,0) being 19 | defined by a specific anchor). 20 | 21 | Users of a BoxCoder can call two methods: 22 | encode: which encodes a box with respect to a given anchor 23 | (or rather, a tensor of boxes wrt a corresponding tensor of anchors) and 24 | decode: which inverts this encoding with a decode operation. 25 | In both cases, the arguments are assumed to be in 1-1 correspondence already; 26 | it is not the job of a BoxCoder to perform matching. 27 | """ 28 | from abc import ABCMeta 29 | from abc import abstractmethod 30 | from abc import abstractproperty 31 | 32 | import tensorflow.compat.v1 as tf 33 | 34 | 35 | # Box coder types. 36 | FASTER_RCNN = 'faster_rcnn' 37 | KEYPOINT = 'keypoint' 38 | MEAN_STDDEV = 'mean_stddev' 39 | SQUARE = 'square' 40 | 41 | 42 | class BoxCoder(object): 43 | """Abstract base class for box coder.""" 44 | __metaclass__ = ABCMeta 45 | 46 | @abstractproperty 47 | def code_size(self): 48 | """Return the size of each code. 
49 | 50 | This number is a constant and should agree with the output of the `encode` 51 | op (e.g. if rel_codes is the output of self.encode(...), then it should have 52 | shape [N, code_size()]). This abstractproperty should be overridden by 53 | implementations. 54 | 55 | Returns: 56 | an integer constant 57 | """ 58 | pass 59 | 60 | def encode(self, boxes, anchors): 61 | """Encode a box list relative to an anchor collection. 62 | 63 | Args: 64 | boxes: BoxList holding N boxes to be encoded 65 | anchors: BoxList of N anchors 66 | 67 | Returns: 68 | a tensor representing N relative-encoded boxes 69 | """ 70 | with tf.name_scope('Encode'): 71 | return self._encode(boxes, anchors) 72 | 73 | def decode(self, rel_codes, anchors): 74 | """Decode boxes that are encoded relative to an anchor collection. 75 | 76 | Args: 77 | rel_codes: a tensor representing N relative-encoded boxes 78 | anchors: BoxList of anchors 79 | 80 | Returns: 81 | boxlist: BoxList holding N boxes encoded in the ordinary way (i.e., 82 | with corners y_min, x_min, y_max, x_max) 83 | """ 84 | with tf.name_scope('Decode'): 85 | return self._decode(rel_codes, anchors) 86 | 87 | @abstractmethod 88 | def _encode(self, boxes, anchors): 89 | """Method to be overridden by implementations. 90 | 91 | Args: 92 | boxes: BoxList holding N boxes to be encoded 93 | anchors: BoxList of N anchors 94 | 95 | Returns: 96 | a tensor representing N relative-encoded boxes 97 | """ 98 | pass 99 | 100 | @abstractmethod 101 | def _decode(self, rel_codes, anchors): 102 | """Method to be overridden by implementations. 103 | 104 | Args: 105 | rel_codes: a tensor representing N relative-encoded boxes 106 | anchors: BoxList of anchors 107 | 108 | Returns: 109 | boxlist: BoxList holding N boxes encoded in the ordinary way (i.e., 110 | with corners y_min, x_min, y_max, x_max) 111 | """ 112 | pass 113 | 114 | 115 | def batch_decode(encoded_boxes, box_coder, anchors): 116 | """Decode a batch of encoded boxes. 117 | 118 | This op takes a batch of encoded bounding boxes and transforms 119 | them to a batch of bounding boxes specified by their corners in 120 | the order of [y_min, x_min, y_max, x_max]. 121 | 122 | Args: 123 | encoded_boxes: a float32 tensor of shape [batch_size, num_anchors, 124 | code_size] representing the location of the objects. 125 | box_coder: a BoxCoder object. 126 | anchors: a BoxList of anchors used to encode `encoded_boxes`. 127 | 128 | Returns: 129 | decoded_boxes: a float32 tensor of shape [batch_size, num_anchors, 130 | coder_size] representing the corners of the objects in the order 131 | of [y_min, x_min, y_max, x_max]. 132 | 133 | Raises: 134 | ValueError: if batch sizes of the inputs are inconsistent, or if 135 | the number of anchors inferred from encoded_boxes and anchors are 136 | inconsistent. 137 | """ 138 | encoded_boxes.get_shape().assert_has_rank(3) 139 | if encoded_boxes.get_shape()[1].value != anchors.num_boxes_static(): 140 | raise ValueError('The number of anchors inferred from encoded_boxes' 141 | ' and anchors are inconsistent: shape[1] of encoded_boxes' 142 | ' %s should be equal to the number of anchors: %s.' 
% 143 | (encoded_boxes.get_shape()[1].value, 144 | anchors.num_boxes_static())) 145 | 146 | decoded_boxes = tf.stack([ 147 | box_coder.decode(boxes, anchors).get() 148 | for boxes in tf.unstack(encoded_boxes) 149 | ]) 150 | return decoded_boxes 151 | -------------------------------------------------------------------------------- /efficientdet/object_detection/box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Bounding Box List definition. 16 | 17 | BoxList represents a list of bounding boxes as tensorflow 18 | tensors, where each bounding box is represented as a row of 4 numbers, 19 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes 20 | within a given list correspond to a single image. See also 21 | box_list_ops.py for common box related operations (such as area, iou, etc). 22 | 23 | Optionally, users can add additional related fields (such as weights). 24 | We assume the following things to be true about fields: 25 | * they correspond to boxes in the box_list along the 0th dimension 26 | * they have inferable rank at graph construction time 27 | * all dimensions except for possibly the 0th can be inferred 28 | (i.e., not None) at graph construction time. 29 | 30 | Some other notes: 31 | * Following tensorflow conventions, we use height, width ordering, 32 | and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering 33 | * Tensors are always provided as (flat) [N, 4] tensors. 34 | """ 35 | 36 | import tensorflow.compat.v1 as tf 37 | 38 | 39 | class BoxList(object): 40 | """Box collection.""" 41 | 42 | def __init__(self, boxes): 43 | """Constructs box collection. 44 | 45 | Args: 46 | boxes: a tensor of shape [N, 4] representing box corners 47 | 48 | Raises: 49 | ValueError: if invalid dimensions for bbox data or if bbox data is not in 50 | float32 format. 51 | """ 52 | if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4: 53 | raise ValueError('Invalid dimensions for box data.') 54 | if boxes.dtype != tf.float32: 55 | raise ValueError('Invalid tensor type: should be tf.float32') 56 | self.data = {'boxes': boxes} 57 | 58 | def num_boxes(self): 59 | """Returns number of boxes held in collection. 60 | 61 | Returns: 62 | a tensor representing the number of boxes held in the collection. 63 | """ 64 | return tf.shape(self.data['boxes'])[0] 65 | 66 | def num_boxes_static(self): 67 | """Returns number of boxes held in collection. 68 | 69 | This number is inferred at graph construction time rather than run-time. 70 | 71 | Returns: 72 | Number of boxes held in collection (integer) or None if this is not 73 | inferable at graph construction time. 
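      For example (a sketch, under TF1 graph construction): a BoxList built
      from tf.zeros([5, 4]) returns 5 here, while one built from a
      tf.placeholder with shape [None, 4] returns None.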
74 | """ 75 | return self.data['boxes'].get_shape()[0].value 76 | 77 | def get_all_fields(self): 78 | """Returns all fields.""" 79 | return self.data.keys() 80 | 81 | def get_extra_fields(self): 82 | """Returns all non-box fields (i.e., everything not named 'boxes').""" 83 | return [k for k in self.data.keys() if k != 'boxes'] 84 | 85 | def add_field(self, field, field_data): 86 | """Add field to box list. 87 | 88 | This method can be used to add related box data such as 89 | weights/labels, etc. 90 | 91 | Args: 92 | field: a string key to access the data via `get` 93 | field_data: a tensor containing the data to store in the BoxList 94 | """ 95 | self.data[field] = field_data 96 | 97 | def has_field(self, field): 98 | return field in self.data 99 | 100 | def get(self): 101 | """Convenience function for accessing box coordinates. 102 | 103 | Returns: 104 | a tensor with shape [N, 4] representing box coordinates. 105 | """ 106 | return self.get_field('boxes') 107 | 108 | def set(self, boxes): 109 | """Convenience function for setting box coordinates. 110 | 111 | Args: 112 | boxes: a tensor of shape [N, 4] representing box corners 113 | 114 | Raises: 115 | ValueError: if invalid dimensions for bbox data 116 | """ 117 | if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4: 118 | raise ValueError('Invalid dimensions for box data.') 119 | self.data['boxes'] = boxes 120 | 121 | def get_field(self, field): 122 | """Accesses a box collection and associated fields. 123 | 124 | This function returns specified field with object; if no field is specified, 125 | it returns the box coordinates. 126 | 127 | Args: 128 | field: this optional string parameter can be used to specify 129 | a related field to be accessed. 130 | 131 | Returns: 132 | a tensor representing the box collection or an associated field. 133 | 134 | Raises: 135 | ValueError: if invalid field 136 | """ 137 | if not self.has_field(field): 138 | raise ValueError('field ' + str(field) + ' does not exist') 139 | return self.data[field] 140 | 141 | def set_field(self, field, value): 142 | """Sets the value of a field. 143 | 144 | Updates the field of a box_list with a given value. 145 | 146 | Args: 147 | field: (string) name of the field to set value. 148 | value: the value to assign to the field. 149 | 150 | Raises: 151 | ValueError: if the box_list does not have specified field. 152 | """ 153 | if not self.has_field(field): 154 | raise ValueError('field %s does not exist' % field) 155 | self.data[field] = value 156 | 157 | def get_center_coordinates_and_sizes(self, scope=None): 158 | """Computes the center coordinates, height and width of the boxes. 159 | 160 | Args: 161 | scope: name scope of the function. 162 | 163 | Returns: 164 | a list of 4 1-D tensors [ycenter, xcenter, height, width]. 165 | """ 166 | with tf.name_scope(scope, 'get_center_coordinates_and_sizes'): 167 | box_corners = self.get() 168 | ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners)) 169 | width = xmax - xmin 170 | height = ymax - ymin 171 | ycenter = ymin + height / 2. 172 | xcenter = xmin + width / 2. 173 | return [ycenter, xcenter, height, width] 174 | 175 | def transpose_coordinates(self, scope=None): 176 | """Transpose the coordinate representation in a boxlist. 177 | 178 | Args: 179 | scope: name scope of the function. 
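      For example, a box stored as [ymin, xmin, ymax, xmax] = [0.1, 0.2, 0.3,
      0.4] becomes [0.2, 0.1, 0.4, 0.3]; applying the method twice restores
      the original ordering.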
180 | """ 181 | with tf.name_scope(scope, 'transpose_coordinates'): 182 | y_min, x_min, y_max, x_max = tf.split( 183 | value=self.get(), num_or_size_splits=4, axis=1) 184 | self.set(tf.concat([x_min, y_min, x_max, y_max], 1)) 185 | 186 | def as_tensor_dict(self, fields=None): 187 | """Retrieves specified fields as a dictionary of tensors. 188 | 189 | Args: 190 | fields: (optional) list of fields to return in the dictionary. 191 | If None (default), all fields are returned. 192 | 193 | Returns: 194 | tensor_dict: A dictionary of tensors specified by fields. 195 | 196 | Raises: 197 | ValueError: if specified field is not contained in boxlist. 198 | """ 199 | tensor_dict = {} 200 | if fields is None: 201 | fields = self.get_all_fields() 202 | for field in fields: 203 | if not self.has_field(field): 204 | raise ValueError('boxlist must contain all specified fields') 205 | tensor_dict[field] = self.get_field(field) 206 | return tensor_dict 207 | -------------------------------------------------------------------------------- /efficientdet/object_detection/tf_example_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tensorflow Example proto decoder for object detection. 16 | 17 | A decoder to decode string tensors containing serialized tensorflow.Example 18 | protos for object detection. 
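
Typical usage (a minimal sketch; `serialized_example` is assumed to be a
scalar string tensor holding one serialized tf.Example):

  decoder = TfExampleDecoder()
  tensors = decoder.decode(serialized_example)
  boxes = tensors['groundtruth_boxes']  # [N, 4], [ymin, xmin, ymax, xmax] order.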
19 | """ 20 | 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | def _get_source_id_from_encoded_image(parsed_tensors): 25 | return tf.strings.as_string( 26 | tf.strings.to_hash_bucket_fast(parsed_tensors['image/encoded'], 27 | 2**63 - 1)) 28 | 29 | 30 | class TfExampleDecoder(object): 31 | """Tensorflow Example proto decoder.""" 32 | 33 | def __init__(self, include_mask=False, regenerate_source_id=False): 34 | self._include_mask = include_mask 35 | self._regenerate_source_id = regenerate_source_id 36 | self._keys_to_features = { 37 | 'image/encoded': tf.FixedLenFeature((), tf.string), 38 | 'image/source_id': tf.FixedLenFeature((), tf.string, ''), 39 | 'image/height': tf.FixedLenFeature((), tf.int64, -1), 40 | 'image/width': tf.FixedLenFeature((), tf.int64, -1), 41 | 'image/object/bbox/xmin': tf.VarLenFeature(tf.float32), 42 | 'image/object/bbox/xmax': tf.VarLenFeature(tf.float32), 43 | 'image/object/bbox/ymin': tf.VarLenFeature(tf.float32), 44 | 'image/object/bbox/ymax': tf.VarLenFeature(tf.float32), 45 | 'image/object/class/label': tf.VarLenFeature(tf.int64), 46 | 'image/object/area': tf.VarLenFeature(tf.float32), 47 | 'image/object/is_crowd': tf.VarLenFeature(tf.int64), 48 | } 49 | if include_mask: 50 | self._keys_to_features.update({ 51 | 'image/object/mask': 52 | tf.VarLenFeature(tf.string), 53 | }) 54 | 55 | def _decode_image(self, parsed_tensors): 56 | """Decodes the image and set its static shape.""" 57 | image = tf.io.decode_image(parsed_tensors['image/encoded'], channels=3) 58 | image.set_shape([None, None, 3]) 59 | return image 60 | 61 | def _decode_boxes(self, parsed_tensors): 62 | """Concat box coordinates in the format of [ymin, xmin, ymax, xmax].""" 63 | xmin = parsed_tensors['image/object/bbox/xmin'] 64 | xmax = parsed_tensors['image/object/bbox/xmax'] 65 | ymin = parsed_tensors['image/object/bbox/ymin'] 66 | ymax = parsed_tensors['image/object/bbox/ymax'] 67 | return tf.stack([ymin, xmin, ymax, xmax], axis=-1) 68 | 69 | def _decode_masks(self, parsed_tensors): 70 | """Decode a set of PNG masks to the tf.float32 tensors.""" 71 | def _decode_png_mask(png_bytes): 72 | mask = tf.squeeze( 73 | tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1) 74 | mask = tf.cast(mask, dtype=tf.float32) 75 | mask.set_shape([None, None]) 76 | return mask 77 | 78 | height = parsed_tensors['image/height'] 79 | width = parsed_tensors['image/width'] 80 | masks = parsed_tensors['image/object/mask'] 81 | return tf.cond( 82 | tf.greater(tf.size(masks), 0), 83 | lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32), 84 | lambda: tf.zeros([0, height, width], dtype=tf.float32)) 85 | 86 | def _decode_areas(self, parsed_tensors): 87 | xmin = parsed_tensors['image/object/bbox/xmin'] 88 | xmax = parsed_tensors['image/object/bbox/xmax'] 89 | ymin = parsed_tensors['image/object/bbox/ymin'] 90 | ymax = parsed_tensors['image/object/bbox/ymax'] 91 | return tf.cond( 92 | tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0), 93 | lambda: parsed_tensors['image/object/area'], 94 | lambda: (xmax - xmin) * (ymax - ymin)) 95 | 96 | def decode(self, serialized_example): 97 | """Decode the serialized example. 98 | 99 | Args: 100 | serialized_example: a single serialized tf.Example string. 101 | 102 | Returns: 103 | decoded_tensors: a dictionary of tensors with the following fields: 104 | - image: a uint8 tensor of shape [None, None, 3]. 105 | - source_id: a string scalar tensor. 106 | - height: an integer scalar tensor. 107 | - width: an integer scalar tensor. 
108 | - groundtruth_classes: a int64 tensor of shape [None]. 109 | - groundtruth_is_crowd: a bool tensor of shape [None]. 110 | - groundtruth_area: a float32 tensor of shape [None]. 111 | - groundtruth_boxes: a float32 tensor of shape [None, 4]. 112 | - groundtruth_instance_masks: a float32 tensor of shape 113 | [None, None, None]. 114 | - groundtruth_instance_masks_png: a string tensor of shape [None]. 115 | """ 116 | parsed_tensors = tf.io.parse_single_example( 117 | serialized_example, self._keys_to_features) 118 | for k in parsed_tensors: 119 | if isinstance(parsed_tensors[k], tf.SparseTensor): 120 | if parsed_tensors[k].dtype == tf.string: 121 | parsed_tensors[k] = tf.sparse_tensor_to_dense( 122 | parsed_tensors[k], default_value='') 123 | else: 124 | parsed_tensors[k] = tf.sparse_tensor_to_dense( 125 | parsed_tensors[k], default_value=0) 126 | 127 | image = self._decode_image(parsed_tensors) 128 | boxes = self._decode_boxes(parsed_tensors) 129 | areas = self._decode_areas(parsed_tensors) 130 | 131 | decode_image_shape = tf.logical_or( 132 | tf.equal(parsed_tensors['image/height'], -1), 133 | tf.equal(parsed_tensors['image/width'], -1)) 134 | image_shape = tf.cast(tf.shape(image), dtype=tf.int64) 135 | 136 | parsed_tensors['image/height'] = tf.where(decode_image_shape, 137 | image_shape[0], 138 | parsed_tensors['image/height']) 139 | parsed_tensors['image/width'] = tf.where(decode_image_shape, image_shape[1], 140 | parsed_tensors['image/width']) 141 | 142 | is_crowds = tf.cond( 143 | tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0), 144 | lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool), 145 | lambda: tf.zeros_like(parsed_tensors['image/object/class/label'], dtype=tf.bool)) # pylint: disable=line-too-long 146 | if self._regenerate_source_id: 147 | source_id = _get_source_id_from_encoded_image(parsed_tensors) 148 | else: 149 | source_id = tf.cond( 150 | tf.greater(tf.strings.length(parsed_tensors['image/source_id']), 151 | 0), lambda: parsed_tensors['image/source_id'], 152 | lambda: _get_source_id_from_encoded_image(parsed_tensors)) 153 | if self._include_mask: 154 | masks = self._decode_masks(parsed_tensors) 155 | 156 | decoded_tensors = { 157 | 'image': image, 158 | 'source_id': source_id, 159 | 'height': parsed_tensors['image/height'], 160 | 'width': parsed_tensors['image/width'], 161 | 'groundtruth_classes': parsed_tensors['image/object/class/label'], 162 | 'groundtruth_is_crowd': is_crowds, 163 | 'groundtruth_area': areas, 164 | 'groundtruth_boxes': boxes, 165 | } 166 | if self._include_mask: 167 | decoded_tensors.update({ 168 | 'groundtruth_instance_masks': masks, 169 | 'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'], 170 | }) 171 | return decoded_tensors 172 | -------------------------------------------------------------------------------- /efficientdet/horovod_estimator/estimator.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tensorflow_estimator import estimator 3 | except ImportError: 4 | from tensorflow import estimator 5 | 6 | import tensorflow.compat.v1 as tf 7 | from tensorflow.python.framework import ops 8 | from tensorflow.python.platform import tf_logging as logging 9 | from tensorflow.python.training import basic_session_run_hooks 10 | from tensorflow.python.training import training 11 | from tensorflow.python.training import warm_starting_util 12 | from tensorflow.python.training.monitored_session import USE_DEFAULT, Scaffold, MonitoredSession, 
ChiefSessionCreator 13 | 14 | from .utis import hvd_info_rank0, is_rank0 15 | 16 | estimator.Estimator._assert_members_are_not_overridden = lambda _: None 17 | 18 | 19 | def MonitoredTrainingSession(master='', 20 | is_chief=True, 21 | checkpoint_dir=None, 22 | scaffold=None, 23 | hooks=None, 24 | chief_only_hooks=None, 25 | save_checkpoint_secs=USE_DEFAULT, 26 | save_summaries_steps=USE_DEFAULT, 27 | save_summaries_secs=USE_DEFAULT, 28 | config=None, 29 | stop_grace_period_secs=120, 30 | log_step_count_steps=100, 31 | save_checkpoint_steps=USE_DEFAULT, 32 | summary_dir=None): 33 | if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT: 34 | save_summaries_steps = 100 35 | save_summaries_secs = None 36 | elif save_summaries_secs == USE_DEFAULT: 37 | save_summaries_secs = None 38 | elif save_summaries_steps == USE_DEFAULT: 39 | save_summaries_steps = None 40 | 41 | if (save_checkpoint_steps == USE_DEFAULT and 42 | save_checkpoint_secs == USE_DEFAULT): 43 | save_checkpoint_steps = None 44 | save_checkpoint_secs = 600 45 | elif save_checkpoint_secs == USE_DEFAULT: 46 | save_checkpoint_secs = None 47 | elif save_checkpoint_steps == USE_DEFAULT: 48 | save_checkpoint_steps = None 49 | 50 | scaffold = scaffold or Scaffold() 51 | 52 | all_hooks = [] 53 | if is_chief and chief_only_hooks: 54 | all_hooks.extend(chief_only_hooks) 55 | 56 | session_creator = ChiefSessionCreator( 57 | scaffold=scaffold, 58 | checkpoint_dir=checkpoint_dir, 59 | master=master, 60 | config=config) 61 | 62 | summary_dir = summary_dir or checkpoint_dir 63 | if summary_dir: 64 | if (save_summaries_steps and save_summaries_steps > 0) or ( 65 | save_summaries_secs and save_summaries_secs > 0): 66 | all_hooks.append( 67 | basic_session_run_hooks.SummarySaverHook( 68 | scaffold=scaffold, 69 | save_steps=save_summaries_steps, 70 | save_secs=save_summaries_secs, 71 | output_dir=summary_dir)) 72 | 73 | if checkpoint_dir: 74 | if (save_checkpoint_secs and save_checkpoint_secs > 0) or ( 75 | save_checkpoint_steps and save_checkpoint_steps > 0): 76 | all_hooks.append( 77 | basic_session_run_hooks.CheckpointSaverHook( 78 | checkpoint_dir, 79 | save_steps=save_checkpoint_steps, 80 | save_secs=save_checkpoint_secs, 81 | scaffold=scaffold)) 82 | 83 | if hooks: 84 | all_hooks.extend(hooks) 85 | 86 | hvd_info_rank0('all hooks {}'.format(all_hooks)) 87 | return MonitoredSession( 88 | session_creator=session_creator, 89 | hooks=all_hooks, 90 | stop_grace_period_secs=stop_grace_period_secs) 91 | 92 | 93 | class HorovodEstimator(tf.estimator.Estimator): 94 | def __init__(self, model_fn, model_dir, config=None, params=None, warm_start_from=None): 95 | super(HorovodEstimator, self).__init__(model_fn=model_fn, model_dir=model_dir, config=config, params=params, 96 | warm_start_from=warm_start_from) 97 | 98 | def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners): 99 | """Train a model with the given Estimator Spec.""" 100 | if self._warm_start_settings: 101 | logging.info('Warm-starting with WarmStartSettings: %s' % (self._warm_start_settings,)) 102 | warm_starting_util.warm_start(*self._warm_start_settings) 103 | # Check if the user created a loss summary, and add one if they didn't. 104 | # We assume here that the summary is called 'loss'. If it is not, we will 105 | # make another one with the name 'loss' to ensure it shows up in the right 106 | # graph in TensorBoard. 
107 | # if not any([x.op.name == 'loss' for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]): 108 | # summary.scalar('loss', estimator_spec.loss) 109 | ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss) 110 | worker_hooks.extend(hooks) 111 | # worker_hooks.extend([ 112 | # training.NanTensorHook(estimator_spec.loss) 113 | # ]) 114 | 115 | worker_hooks.extend(estimator_spec.training_hooks) 116 | 117 | if not (estimator_spec.scaffold.saver or 118 | ops.get_collection(ops.GraphKeys.SAVERS)): 119 | ops.add_to_collection( 120 | ops.GraphKeys.SAVERS, 121 | training.Saver( 122 | sharded=True, 123 | max_to_keep=self._config.keep_checkpoint_max, 124 | keep_checkpoint_every_n_hours=( 125 | self._config.keep_checkpoint_every_n_hours), 126 | defer_build=True, 127 | save_relative_paths=True)) 128 | 129 | chief_hooks = [] 130 | all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks) 131 | saver_hooks = [h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)] 132 | if (self._config.save_checkpoints_secs or 133 | self._config.save_checkpoints_steps): 134 | if not saver_hooks: 135 | chief_hooks = [ 136 | training.CheckpointSaverHook( 137 | self._model_dir, 138 | save_secs=self._config.save_checkpoints_secs, 139 | save_steps=self._config.save_checkpoints_steps, 140 | scaffold=estimator_spec.scaffold) 141 | ] 142 | saver_hooks = [chief_hooks[0]] 143 | if saving_listeners: 144 | if not saver_hooks: 145 | raise ValueError( 146 | 'There should be a CheckpointSaverHook to use saving_listeners. ' 147 | 'Please set one of the RunConfig.save_checkpoints_steps or ' 148 | 'RunConfig.save_checkpoints_secs.') 149 | else: 150 | # It is expected to have one CheckpointSaverHook. If multiple, we pick 151 | # up the first one to add listener. 152 | saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access 153 | 154 | if is_rank0(): 155 | log_step_count_steps = self._config.log_step_count_steps 156 | checkpoint_dir = self.model_dir 157 | chief_only_hooks = (tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)) 158 | else: 159 | log_step_count_steps = None 160 | checkpoint_dir = None 161 | chief_only_hooks = None 162 | 163 | with MonitoredTrainingSession( 164 | master=self._config.master, 165 | is_chief=is_rank0(), 166 | checkpoint_dir=checkpoint_dir, 167 | scaffold=estimator_spec.scaffold, 168 | hooks=worker_hooks, 169 | chief_only_hooks=chief_only_hooks, 170 | save_checkpoint_secs=0, # Saving is handled by a hook. 171 | save_summaries_steps=self._config.save_summary_steps, 172 | config=self._session_config, 173 | log_step_count_steps=log_step_count_steps) as mon_sess: 174 | loss = None 175 | while not mon_sess.should_stop(): 176 | _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) 177 | return loss 178 | -------------------------------------------------------------------------------- /efficientdet/backbone/efficientnet_lite_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Model builder for EfficientNet Lite models.
16 | 
17 | efficientnet-litex (x=0,1,2,3,4) checkpoints are located in:
18 | https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/lite/efficientnet-litex.tar.gz
19 | """
20 | 
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 | 
25 | import os
26 | from absl import logging
27 | import tensorflow.compat.v1 as tf
28 | 
29 | import utils
30 | from backbone import efficientnet_builder
31 | from backbone import efficientnet_model
32 | 
33 | # Lite models use Inception-style MEAN and STDDEV for better post-quantization.
34 | MEAN_RGB = [127.0, 127.0, 127.0]
35 | STDDEV_RGB = [128.0, 128.0, 128.0]
36 | 
37 | 
38 | def efficientnet_lite_params(model_name):
39 |   """Gets efficientnet-lite params based on model name."""
40 |   params_dict = {
41 |       # (width_coefficient, depth_coefficient, resolution, dropout_rate)
42 |       'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
43 |       'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
44 |       'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
45 |       'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
46 |       'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
47 |   }
48 |   return params_dict[model_name]
49 | 
50 | 
51 | _DEFAULT_BLOCKS_ARGS = [
52 |     'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
53 |     'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
54 |     'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
55 |     'r1_k3_s11_e6_i192_o320_se0.25',
56 | ]
57 | 
58 | 
59 | def efficientnet_lite(width_coefficient=None,
60 |                       depth_coefficient=None,
61 |                       dropout_rate=0.2,
62 |                       survival_prob=0.8):
63 |   """Creates an efficientnet-lite model."""
64 |   global_params = efficientnet_model.GlobalParams(
65 |       blocks_args=_DEFAULT_BLOCKS_ARGS,
66 |       batch_norm_momentum=0.99,
67 |       batch_norm_epsilon=1e-3,
68 |       dropout_rate=dropout_rate,
69 |       survival_prob=survival_prob,
70 |       data_format='channels_last',
71 |       num_classes=1000,
72 |       width_coefficient=width_coefficient,
73 |       depth_coefficient=depth_coefficient,
74 |       depth_divisor=8,
75 |       min_depth=None,
76 |       relu_fn=tf.nn.relu6,  # Relu6 is for easier quantization.
77 |       # The default is TPU-specific batch norm.
78 |       # The alternative is tf.layers.BatchNormalization.
79 |       batch_norm=utils.TpuBatchNormalization,  # TPU-specific requirement.
80 |       clip_projection_output=False,
81 |       fix_head_stem=True,  # Don't scale stem and head.
82 |       local_pooling=True,  # Special-case pooling for TFLite compatibility.
83 |       use_se=False)  # SE is not well supported on many lite devices.
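  # Note: blocks_args above is still the raw string spec; get_model_params()
  # below decodes it into BlockArgs via efficientnet_builder.BlockDecoder.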
84 | return global_params 85 | 86 | 87 | def get_model_params(model_name, override_params): 88 | """Get the block args and global params for a given model.""" 89 | if model_name.startswith('efficientnet-lite'): 90 | width_coefficient, depth_coefficient, _, dropout_rate = ( 91 | efficientnet_lite_params(model_name)) 92 | global_params = efficientnet_lite( 93 | width_coefficient, depth_coefficient, dropout_rate) 94 | else: 95 | raise NotImplementedError('model name is not pre-defined: %s' % model_name) 96 | 97 | if override_params: 98 | # ValueError will be raised here if override_params has fields not included 99 | # in global_params. 100 | global_params = global_params._replace(**override_params) 101 | 102 | decoder = efficientnet_builder.BlockDecoder() 103 | blocks_args = decoder.decode(global_params.blocks_args) 104 | 105 | logging.info('global_params= %s', global_params) 106 | return blocks_args, global_params 107 | 108 | 109 | def build_model(images, 110 | model_name, 111 | training, 112 | override_params=None, 113 | model_dir=None, 114 | fine_tuning=False, 115 | features_only=False, 116 | pooled_features_only=False): 117 | """A helper function to create a model and return predicted logits. 118 | 119 | Args: 120 | images: input images tensor. 121 | model_name: string, the predefined model name. 122 | training: boolean, whether the model is constructed for training. 123 | override_params: A dictionary of params for overriding. Fields must exist in 124 | efficientnet_model.GlobalParams. 125 | model_dir: string, optional model dir for saving configs. 126 | fine_tuning: boolean, whether the model is used for finetuning. 127 | features_only: build the base feature network only (excluding final 128 | 1x1 conv layer, global pooling, dropout and fc head). 129 | pooled_features_only: build the base network for features extraction (after 130 | 1x1 conv layer and global pooling, but before dropout and fc head). 131 | 132 | Returns: 133 | logits: the logits tensor of classes. 134 | endpoints: the endpoints for each layer. 135 | 136 | Raises: 137 | When model_name specified an undefined model, raises NotImplementedError. 138 | When override_params has invalid fields, raises ValueError. 139 | """ 140 | assert isinstance(images, tf.Tensor) 141 | assert not (features_only and pooled_features_only) 142 | 143 | # For backward compatibility. 
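  # ('drop_connect_rate' is the legacy name for 1 - survival_prob; translate it
  # here so that older override dicts keep working.)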
144 | if override_params and override_params.get('drop_connect_rate', None): 145 | override_params['survival_prob'] = 1 - override_params['drop_connect_rate'] 146 | 147 | if not training or fine_tuning: 148 | if not override_params: 149 | override_params = {} 150 | override_params['batch_norm'] = utils.BatchNormalization 151 | blocks_args, global_params = get_model_params(model_name, override_params) 152 | 153 | if model_dir: 154 | param_file = os.path.join(model_dir, 'model_params.txt') 155 | if not tf.gfile.Exists(param_file): 156 | if not tf.gfile.Exists(model_dir): 157 | tf.gfile.MakeDirs(model_dir) 158 | with tf.gfile.GFile(param_file, 'w') as f: 159 | logging.info('writing to %s', param_file) 160 | f.write('model_name= %s\n\n' % model_name) 161 | f.write('global_params= %s\n\n' % str(global_params)) 162 | f.write('blocks_args= %s\n\n' % str(blocks_args)) 163 | 164 | with tf.variable_scope(model_name): 165 | model = efficientnet_model.Model(blocks_args, global_params) 166 | outputs = model( 167 | images, 168 | training=training, 169 | features_only=features_only, 170 | pooled_features_only=pooled_features_only) 171 | if features_only: 172 | outputs = tf.identity(outputs, 'features') 173 | elif pooled_features_only: 174 | outputs = tf.identity(outputs, 'pooled_features') 175 | else: 176 | outputs = tf.identity(outputs, 'logits') 177 | return outputs, model.endpoints 178 | 179 | 180 | def build_model_base(images, model_name, training, override_params=None): 181 | """Create a base feature network and return the features before pooling. 182 | 183 | Args: 184 | images: input images tensor. 185 | model_name: string, the predefined model name. 186 | training: boolean, whether the model is constructed for training. 187 | override_params: A dictionary of params for overriding. Fields must exist in 188 | efficientnet_model.GlobalParams. 189 | 190 | Returns: 191 | features: base features before pooling. 192 | endpoints: the endpoints for each layer. 193 | 194 | Raises: 195 | When model_name specified an undefined model, raises NotImplementedError. 196 | When override_params has invalid fields, raises ValueError. 197 | """ 198 | assert isinstance(images, tf.Tensor) 199 | # For backward compatibility. 200 | if override_params and override_params.get('drop_connect_rate', None): 201 | override_params['survival_prob'] = 1 - override_params['drop_connect_rate'] 202 | 203 | blocks_args, global_params = get_model_params(model_name, override_params) 204 | 205 | with tf.variable_scope(model_name): 206 | model = efficientnet_model.Model(blocks_args, global_params) 207 | features = model(images, training=training, features_only=True) 208 | 209 | features = tf.identity(features, 'features') 210 | return features, model.endpoints 211 | -------------------------------------------------------------------------------- /efficientdet/dataset/create_coco_tfrecord_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Test for create_coco_tfrecord.py.""" 16 | 17 | import io 18 | import json 19 | import os 20 | 21 | import numpy as np 22 | import PIL.Image 23 | import six 24 | import tensorflow.compat.v1 as tf 25 | 26 | from dataset import create_coco_tfrecord 27 | 28 | 29 | class CreateCocoTFRecordTest(tf.test.TestCase): 30 | 31 | def _assertProtoEqual(self, proto_field, expectation): 32 | """Helper function to assert if a proto field equals some value. 33 | 34 | Args: 35 | proto_field: The protobuf field to compare. 36 | expectation: The expected value of the protobuf field. 37 | """ 38 | proto_list = [p for p in proto_field] 39 | self.assertListEqual(proto_list, expectation) 40 | 41 | def test_create_tf_example(self): 42 | image_file_name = 'tmp_image.jpg' 43 | image_data = np.random.rand(256, 256, 3) 44 | tmp_dir = self.get_temp_dir() 45 | save_path = os.path.join(tmp_dir, image_file_name) 46 | image = PIL.Image.fromarray(image_data, 'RGB') 47 | image.save(save_path) 48 | 49 | image = { 50 | 'file_name': image_file_name, 51 | 'height': 256, 52 | 'width': 256, 53 | 'id': 11, 54 | } 55 | 56 | annotations_list = [{ 57 | 'area': .5, 58 | 'iscrowd': False, 59 | 'image_id': 11, 60 | 'bbox': [64, 64, 128, 128], 61 | 'category_id': 2, 62 | 'id': 1000, 63 | }] 64 | 65 | image_dir = tmp_dir 66 | category_index = { 67 | 1: { 68 | 'name': 'dog', 69 | 'id': 1 70 | }, 71 | 2: { 72 | 'name': 'cat', 73 | 'id': 2 74 | }, 75 | 3: { 76 | 'name': 'human', 77 | 'id': 3 78 | } 79 | } 80 | 81 | (_, example, 82 | num_annotations_skipped) = create_coco_tfrecord.create_tf_example( 83 | image, image_dir, annotations_list, category_index) 84 | 85 | self.assertEqual(num_annotations_skipped, 0) 86 | self._assertProtoEqual( 87 | example.features.feature['image/height'].int64_list.value, [256]) 88 | self._assertProtoEqual( 89 | example.features.feature['image/width'].int64_list.value, [256]) 90 | self._assertProtoEqual( 91 | example.features.feature['image/filename'].bytes_list.value, 92 | [six.b(image_file_name)]) 93 | self._assertProtoEqual( 94 | example.features.feature['image/source_id'].bytes_list.value, 95 | [six.b(str(image['id']))]) 96 | self._assertProtoEqual( 97 | example.features.feature['image/format'].bytes_list.value, 98 | [six.b('jpeg')]) 99 | self._assertProtoEqual( 100 | example.features.feature['image/object/bbox/xmin'].float_list.value, 101 | [0.25]) 102 | self._assertProtoEqual( 103 | example.features.feature['image/object/bbox/ymin'].float_list.value, 104 | [0.25]) 105 | self._assertProtoEqual( 106 | example.features.feature['image/object/bbox/xmax'].float_list.value, 107 | [0.75]) 108 | self._assertProtoEqual( 109 | example.features.feature['image/object/bbox/ymax'].float_list.value, 110 | [0.75]) 111 | self._assertProtoEqual( 112 | example.features.feature['image/object/class/text'].bytes_list.value, 113 | [six.b('cat')]) 114 | 115 | def test_create_tf_example_with_instance_masks(self): 116 | image_file_name = 'tmp_image.jpg' 117 | image_data = np.random.rand(8, 8, 3) 118 | tmp_dir = self.get_temp_dir() 119 | save_path = os.path.join(tmp_dir, image_file_name) 120 | image = PIL.Image.fromarray(image_data, 'RGB') 121 | image.save(save_path) 122 | 123 | image = { 124 | 'file_name': image_file_name, 125 | 'height': 8, 126 | 'width': 8, 127 | 'id': 11, 128 | } 129 | 130 | annotations_list = [{ 131 | 'area': .5, 
132 | 'iscrowd': False, 133 | 'image_id': 11, 134 | 'bbox': [0, 0, 8, 8], 135 | 'segmentation': [[4, 0, 0, 0, 0, 4], [8, 4, 4, 8, 8, 8]], 136 | 'category_id': 1, 137 | 'id': 1000, 138 | }] 139 | 140 | image_dir = tmp_dir 141 | category_index = { 142 | 1: { 143 | 'name': 'dog', 144 | 'id': 1 145 | }, 146 | } 147 | 148 | (_, example, 149 | num_annotations_skipped) = create_coco_tfrecord.create_tf_example( 150 | image, image_dir, annotations_list, category_index, include_masks=True) 151 | 152 | self.assertEqual(num_annotations_skipped, 0) 153 | self._assertProtoEqual( 154 | example.features.feature['image/height'].int64_list.value, [8]) 155 | self._assertProtoEqual( 156 | example.features.feature['image/width'].int64_list.value, [8]) 157 | self._assertProtoEqual( 158 | example.features.feature['image/filename'].bytes_list.value, 159 | [six.b(image_file_name)]) 160 | self._assertProtoEqual( 161 | example.features.feature['image/source_id'].bytes_list.value, 162 | [six.b(str(image['id']))]) 163 | self._assertProtoEqual( 164 | example.features.feature['image/format'].bytes_list.value, 165 | [six.b('jpeg')]) 166 | self._assertProtoEqual( 167 | example.features.feature['image/object/bbox/xmin'].float_list.value, 168 | [0]) 169 | self._assertProtoEqual( 170 | example.features.feature['image/object/bbox/ymin'].float_list.value, 171 | [0]) 172 | self._assertProtoEqual( 173 | example.features.feature['image/object/bbox/xmax'].float_list.value, 174 | [1]) 175 | self._assertProtoEqual( 176 | example.features.feature['image/object/bbox/ymax'].float_list.value, 177 | [1]) 178 | self._assertProtoEqual( 179 | example.features.feature['image/object/class/text'].bytes_list.value, 180 | [six.b('dog')]) 181 | encoded_mask_pngs = [ 182 | io.BytesIO(encoded_masks) for encoded_masks in example.features.feature[ 183 | 'image/object/mask'].bytes_list.value 184 | ] 185 | pil_masks = [ 186 | np.array(PIL.Image.open(encoded_mask_png)) 187 | for encoded_mask_png in encoded_mask_pngs 188 | ] 189 | self.assertEqual(len(pil_masks), 1) 190 | self.assertAllEqual(pil_masks[0], 191 | [[1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0], 192 | [1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], 193 | [0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 1, 1], 194 | [0, 0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1, 1, 1]]) 195 | 196 | def test_create_sharded_tf_record(self): 197 | tmp_dir = self.get_temp_dir() 198 | image_paths = ['tmp1_image.jpg', 'tmp2_image.jpg'] 199 | for image_path in image_paths: 200 | image_data = np.random.rand(256, 256, 3) 201 | save_path = os.path.join(tmp_dir, image_path) 202 | image = PIL.Image.fromarray(image_data, 'RGB') 203 | image.save(save_path) 204 | 205 | images = [{ 206 | 'file_name': image_paths[0], 207 | 'height': 256, 208 | 'width': 256, 209 | 'id': 11, 210 | }, { 211 | 'file_name': image_paths[1], 212 | 'height': 256, 213 | 'width': 256, 214 | 'id': 12, 215 | }] 216 | 217 | annotations = [{ 218 | 'area': .5, 219 | 'iscrowd': False, 220 | 'image_id': 11, 221 | 'bbox': [64, 64, 128, 128], 222 | 'category_id': 2, 223 | 'id': 1000, 224 | }] 225 | 226 | category_index = [{ 227 | 'name': 'dog', 228 | 'id': 1 229 | }, { 230 | 'name': 'cat', 231 | 'id': 2 232 | }, { 233 | 'name': 'human', 234 | 'id': 3 235 | }] 236 | groundtruth_data = {'images': images, 'annotations': annotations, 237 | 'categories': category_index} 238 | annotation_file = os.path.join(tmp_dir, 'annotation.json') 239 | with open(annotation_file, 'w') as annotation_fid: 240 | json.dump(groundtruth_data, annotation_fid) 241 | 242 | output_path 
= os.path.join(tmp_dir, 'out') 243 | create_coco_tfrecord._create_tf_record_from_coco_annotations( 244 | annotation_file, 245 | tmp_dir, 246 | output_path, 247 | num_shards=2, 248 | include_masks=False) 249 | self.assertTrue(os.path.exists(output_path + '-00000-of-00002.tfrecord')) 250 | self.assertTrue(os.path.exists(output_path + '-00001-of-00002.tfrecord')) 251 | 252 | 253 | if __name__ == '__main__': 254 | tf.test.main() 255 | -------------------------------------------------------------------------------- /efficientdet/coco_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """COCO-style evaluation metrics. 16 | 17 | Implements the interface of COCO API and metric_fn in tf.TPUEstimator. 18 | 19 | COCO API: github.com/cocodataset/cocoapi/ 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | 27 | import json 28 | import os 29 | from absl import flags 30 | from absl import logging 31 | 32 | import numpy as np 33 | from pycocotools.coco import COCO 34 | from pycocotools.cocoeval import COCOeval 35 | 36 | import tensorflow.compat.v1 as tf 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | 41 | class EvaluationMetric(object): 42 | """COCO evaluation metric class.""" 43 | 44 | def __init__(self, filename=None, testdev_dir=None): 45 | """Constructs COCO evaluation class. 46 | 47 | The class provides the interface to metrics_fn in TPUEstimator. The 48 | _update_op() takes detections from each image and push them to 49 | self.detections. The _evaluate() loads a JSON file in COCO annotation format 50 | as the groundtruth and runs COCO evaluation. 51 | 52 | Args: 53 | filename: Ground truth JSON file name. If filename is None, use 54 | groundtruth data passed from the dataloader for evaluation. filename is 55 | ignored if testdev_dir is not None. 56 | testdev_dir: folder name for testdev data. If None, run eval without 57 | groundtruth, and filename will be ignored. 58 | """ 59 | if filename: 60 | self.coco_gt = COCO(filename) 61 | self.filename = filename 62 | self.testdev_dir = testdev_dir 63 | self.metric_names = ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'ARmax1', 64 | 'ARmax10', 'ARmax100', 'ARs', 'ARm', 'ARl'] 65 | self._reset() 66 | 67 | def _reset(self): 68 | """Reset COCO API object.""" 69 | if self.filename is None: 70 | self.coco_gt = COCO() 71 | self.detections = [] 72 | self.dataset = { 73 | 'images': [], 74 | 'annotations': [], 75 | 'categories': [] 76 | } 77 | self.image_id = 1 78 | self.annotation_id = 1 79 | self.category_ids = [] 80 | 81 | def estimator_metric_fn(self, detections, groundtruth_data): 82 | """Constructs the metric function for tf.TPUEstimator. 
83 | 
84 |     For each metric, we return the evaluation op and an update op; the update op
85 |     is shared across all metrics and simply appends the set of detections to the
86 |     `self.detections` list. The metric op is invoked after all examples have
87 |     been seen and computes the aggregate COCO metrics. See the API details
88 |     at: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/MetricSpec
89 |     Args:
90 |       detections: Detection results in a tensor with each row representing
91 |         [image_id, x, y, width, height, score, class]
92 |       groundtruth_data: Groundtruth annotations in a tensor with each row
93 |         representing [y1, x1, y2, x2, is_crowd, area, class].
94 |     Returns:
95 |       metrics_dict: A dictionary mapping from evaluation name to a tuple of
96 |         operations (`metric_op`, `update_op`). `update_op` appends the
97 |         detections for the metric to the `self.detections` list.
98 |     """
99 | 
100 |     def _evaluate():
101 |       """Evaluates detections from all images with the COCO API.
102 | 
103 |       Returns:
104 |         coco_metric: float numpy array with shape [12] representing the
105 |           coco-style evaluation metrics.
106 |       """
107 |       if self.filename is None:
108 |         self.coco_gt.dataset = self.dataset
109 |         self.coco_gt.createIndex()
110 | 
111 |       if self.testdev_dir:
112 |         # Run on test-dev dataset.
113 |         box_result_list = []
114 |         for det in self.detections:
115 |           box_result_list.append({
116 |               'image_id': int(det[0]),
117 |               'category_id': int(det[6]),
118 |               'bbox': np.around(
119 |                   det[1:5].astype(np.float64), decimals=2).tolist(),
120 |               'score': float(np.around(det[5], decimals=3)),
121 |           })
122 |         json.encoder.FLOAT_REPR = lambda o: format(o, '.3f')
123 |         output_path = os.path.join(self.testdev_dir,
124 |                                    'detections_test-dev2017_test_results.json')
125 |         logging.info('Writing output json file to: %s', output_path)
126 |         with tf.io.gfile.GFile(output_path, 'w') as fid:
127 |           json.dump(box_result_list, fid)
128 | 
129 |         self._reset()
130 |         return np.array([0.], dtype=np.float32)
131 |       else:
132 |         # Run on validation dataset.
133 |         detections = np.array(self.detections)
134 |         image_ids = list(set(detections[:, 0]))
135 |         coco_dt = self.coco_gt.loadRes(detections)
136 |         coco_eval = COCOeval(self.coco_gt, coco_dt, iouType='bbox')
137 |         coco_eval.params.imgIds = image_ids
138 |         coco_eval.evaluate()
139 |         coco_eval.accumulate()
140 |         coco_eval.summarize()
141 |         coco_metrics = coco_eval.stats
142 |         # Clear self.detections after evaluation is done.
143 |         # This makes sure the next evaluation will start with an empty list of
144 |         # self.detections.
145 |         self._reset()
146 |         return np.array(coco_metrics, dtype=np.float32)
147 | 
148 |     def _update_op(detections, groundtruth_data):
149 |       """Updates detection results and groundtruth data.
150 | 
151 |       Appends detection results to self.detections to aggregate results across
152 |       the whole validation set. The groundtruth_data is parsed and added into a
153 |       dictionary with the same format as the COCO dataset, which can be used for
154 |       evaluation.
155 | 
156 |       Args:
157 |         detections: Detection results in a tensor with each row representing
158 |           [image_id, x, y, width, height, score, class].
159 |         groundtruth_data: Groundtruth annotations in a tensor with each row
160 |           representing [y1, x1, y2, x2, is_crowd, area, class].
161 |       """
162 |       for i in range(len(detections)):
163 |         # Filter out detections with predicted class label = -1.
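        # (A label of -1 likely marks rows padded to a fixed-size detection
        # tensor; such rows carry no real predictions and are dropped here.)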
164 | indices = np.where(detections[i, :, -1] > -1)[0] 165 | detections[i] = detections[i, indices] 166 | if detections[i].shape[0] == 0: 167 | continue 168 | # Append groundtruth annotations to create COCO dataset object. 169 | # Add images. 170 | image_id = detections[i][0, 0] 171 | if image_id == -1: 172 | image_id = self.image_id 173 | detections[i][:, 0] = image_id 174 | self.detections.extend(detections[i]) 175 | 176 | if self.testdev_dir: 177 | # Skip annotation for test-dev case. 178 | self.image_id += 1 179 | continue 180 | 181 | self.dataset['images'].append({ 182 | 'id': int(image_id), 183 | }) 184 | 185 | # Add annotations. 186 | indices = np.where(groundtruth_data[i, :, -1] > -1)[0] 187 | for data in groundtruth_data[i, indices]: 188 | box = data[0:4] 189 | is_crowd = data[4] 190 | area = data[5] 191 | category_id = data[6] 192 | if category_id < 0: 193 | break 194 | if area == -1: 195 | area = (box[3] - box[1]) * (box[2] - box[0]) 196 | self.dataset['annotations'].append({ 197 | 'id': int(self.annotation_id), 198 | 'image_id': int(image_id), 199 | 'category_id': int(category_id), 200 | 'bbox': [box[1], box[0], box[3] - box[1], box[2] - box[0]], 201 | 'area': area, 202 | 'iscrowd': int(is_crowd) 203 | }) 204 | self.annotation_id += 1 205 | self.category_ids.append(category_id) 206 | self.image_id += 1 207 | self.category_ids = list(set(self.category_ids)) 208 | self.dataset['categories'] = [ 209 | {'id': int(category_id)} for category_id in self.category_ids 210 | ] 211 | 212 | with tf.name_scope('coco_metric'): 213 | if self.testdev_dir: 214 | update_op = tf.py_func(_update_op, [detections, groundtruth_data], []) 215 | metrics = tf.py_func(_evaluate, [], tf.float32) 216 | metrics_dict = {'AP': (metrics, update_op)} 217 | return metrics_dict 218 | else: 219 | update_op = tf.py_func(_update_op, [detections, groundtruth_data], []) 220 | metrics = tf.py_func(_evaluate, [], tf.float32) 221 | metrics_dict = {} 222 | for i, name in enumerate(self.metric_names): 223 | metrics_dict[name] = (metrics[i], update_op) 224 | return metrics_dict 225 | -------------------------------------------------------------------------------- /efficientdet/backbone/efficientnet_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for efficientnet_model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | 23 | import utils 24 | from backbone import efficientnet_model 25 | 26 | 27 | class ModelTest(tf.test.TestCase): 28 | 29 | def test_bottleneck_block(self): 30 | """Test for creating a model with bottleneck block arguments.""" 31 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 32 | global_params = efficientnet_model.GlobalParams( 33 | 1.0, 34 | 1.0, 35 | 0, 36 | 'channels_last', 37 | num_classes=10, 38 | batch_norm=utils.TpuBatchNormalization) 39 | blocks_args = [ 40 | efficientnet_model.BlockArgs( 41 | kernel_size=3, 42 | num_repeat=3, 43 | input_filters=3, 44 | output_filters=6, 45 | expand_ratio=6, 46 | id_skip=True, 47 | strides=[2, 2], 48 | conv_type=0, 49 | fused_conv=0, 50 | super_pixel=0) 51 | ] 52 | model = efficientnet_model.Model(blocks_args, global_params) 53 | outputs = model(images, training=True) 54 | self.assertEqual((10, 10), outputs.shape) 55 | 56 | def test_fused_bottleneck_block(self): 57 | """Test for creating a model with fused bottleneck block arguments.""" 58 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 59 | global_params = efficientnet_model.GlobalParams( 60 | 1.0, 61 | 1.0, 62 | 0, 63 | 'channels_last', 64 | num_classes=10, 65 | batch_norm=utils.TpuBatchNormalization) 66 | blocks_args = [ 67 | efficientnet_model.BlockArgs( 68 | kernel_size=3, 69 | num_repeat=3, 70 | input_filters=3, 71 | output_filters=6, 72 | expand_ratio=6, 73 | id_skip=True, 74 | strides=[2, 2], 75 | conv_type=0, 76 | fused_conv=1, 77 | super_pixel=0) 78 | ] 79 | model = efficientnet_model.Model(blocks_args, global_params) 80 | outputs = model(images, training=True) 81 | self.assertEqual((10, 10), outputs.shape) 82 | 83 | def test_bottleneck_block_with_superpixel_layer(self): 84 | """Test for creating a model with fused bottleneck block arguments.""" 85 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 86 | global_params = efficientnet_model.GlobalParams( 87 | 1.0, 88 | 1.0, 89 | 0, 90 | 'channels_last', 91 | num_classes=10, 92 | batch_norm=utils.TpuBatchNormalization) 93 | blocks_args = [ 94 | efficientnet_model.BlockArgs( 95 | kernel_size=3, 96 | num_repeat=3, 97 | input_filters=3, 98 | output_filters=6, 99 | expand_ratio=6, 100 | id_skip=True, 101 | strides=[2, 2], 102 | conv_type=0, 103 | fused_conv=0, 104 | super_pixel=1) 105 | ] 106 | model = efficientnet_model.Model(blocks_args, global_params) 107 | outputs = model(images, training=True) 108 | self.assertEqual((10, 10), outputs.shape) 109 | 110 | def test_bottleneck_block_with_superpixel_tranformation(self): 111 | """Test for creating a model with fused bottleneck block arguments.""" 112 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 113 | global_params = efficientnet_model.GlobalParams( 114 | 1.0, 115 | 1.0, 116 | 0, 117 | 'channels_last', 118 | num_classes=10, 119 | batch_norm=utils.TpuBatchNormalization) 120 | blocks_args = [ 121 | efficientnet_model.BlockArgs( 122 | kernel_size=3, 123 | num_repeat=3, 124 | input_filters=3, 125 | output_filters=6, 126 | expand_ratio=6, 127 | id_skip=True, 128 | strides=[2, 2], 129 | conv_type=0, 130 | fused_conv=0, 131 | super_pixel=2) 132 | ] 133 | model = efficientnet_model.Model(blocks_args, global_params) 134 | outputs = model(images, training=True) 135 | 
self.assertEqual((10, 10), outputs.shape) 136 | 137 | def test_se_block(self): 138 | """Test for creating a model with SE block arguments.""" 139 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 140 | global_params = efficientnet_model.GlobalParams( 141 | 1.0, 142 | 1.0, 143 | 0, 144 | 'channels_last', 145 | num_classes=10, 146 | batch_norm=utils.TpuBatchNormalization) 147 | blocks_args = [ 148 | efficientnet_model.BlockArgs( 149 | kernel_size=3, 150 | num_repeat=3, 151 | input_filters=3, 152 | output_filters=6, 153 | expand_ratio=6, 154 | id_skip=False, 155 | strides=[2, 2], 156 | se_ratio=0.8, 157 | conv_type=0, 158 | fused_conv=0, 159 | super_pixel=0) 160 | ] 161 | model = efficientnet_model.Model(blocks_args, global_params) 162 | outputs = model(images, training=True) 163 | self.assertEqual((10, 10), outputs.shape) 164 | 165 | def test_variables(self): 166 | """Test for variables in blocks to be included in `model.variables`.""" 167 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 168 | global_params = efficientnet_model.GlobalParams( 169 | 1.0, 170 | 1.0, 171 | 0, 172 | 'channels_last', 173 | num_classes=10, 174 | batch_norm=utils.TpuBatchNormalization) 175 | blocks_args = [ 176 | efficientnet_model.BlockArgs( 177 | kernel_size=3, 178 | num_repeat=3, 179 | input_filters=3, 180 | output_filters=6, 181 | expand_ratio=6, 182 | id_skip=False, 183 | strides=[2, 2], 184 | se_ratio=0.8, 185 | conv_type=0, 186 | fused_conv=0, 187 | super_pixel=0) 188 | ] 189 | model = efficientnet_model.Model(blocks_args, global_params) 190 | _ = model(images, training=True) 191 | var_names = {var.name for var in model.variables} 192 | self.assertIn('blocks_0/conv2d/kernel:0', var_names) 193 | 194 | def test_reduction_endpoint_with_single_block_with_sp(self): 195 | """Test reduction point with single block/layer.""" 196 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 197 | global_params = efficientnet_model.GlobalParams( 198 | 1.0, 199 | 1.0, 200 | 0, 201 | 'channels_last', 202 | num_classes=10, 203 | batch_norm=utils.TpuBatchNormalization) 204 | blocks_args = [ 205 | efficientnet_model.BlockArgs( 206 | kernel_size=3, 207 | num_repeat=1, 208 | input_filters=3, 209 | output_filters=6, 210 | expand_ratio=6, 211 | id_skip=False, 212 | strides=[2, 2], 213 | se_ratio=0.8, 214 | conv_type=0, 215 | fused_conv=0, 216 | super_pixel=1) 217 | ] 218 | model = efficientnet_model.Model(blocks_args, global_params) 219 | _ = model(images, training=True) 220 | self.assertIn('reduction_1', model.endpoints) 221 | # single block should have one and only one reduction endpoint 222 | self.assertNotIn('reduction_2', model.endpoints) 223 | 224 | def test_reduction_endpoint_with_single_block_without_sp(self): 225 | """Test reduction point with single block/layer.""" 226 | images = tf.zeros((10, 128, 128, 3), dtype=tf.float32) 227 | global_params = efficientnet_model.GlobalParams( 228 | 1.0, 229 | 1.0, 230 | 0, 231 | 'channels_last', 232 | num_classes=10, 233 | batch_norm=utils.TpuBatchNormalization) 234 | blocks_args = [ 235 | efficientnet_model.BlockArgs( 236 | kernel_size=3, 237 | num_repeat=1, 238 | input_filters=3, 239 | output_filters=6, 240 | expand_ratio=6, 241 | id_skip=False, 242 | strides=[2, 2], 243 | se_ratio=0.8, 244 | conv_type=0, 245 | fused_conv=0, 246 | super_pixel=0) 247 | ] 248 | model = efficientnet_model.Model(blocks_args, global_params) 249 | _ = model(images, training=True) 250 | self.assertIn('reduction_1', model.endpoints) 251 | # single block should have one and only one reduction 
endpoint
252 |     self.assertNotIn('reduction_2', model.endpoints)
253 | 
254 | if __name__ == '__main__':
255 |   tf.disable_v2_behavior()
256 |   tf.test.main()
257 | 
--------------------------------------------------------------------------------
/efficientdet/object_detection/argmax_matcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Argmax matcher implementation.
16 | 
17 | This class takes a similarity matrix and matches columns to rows based on the
18 | maximum value per column. One can specify matched_threshold to keep columns
19 | whose best similarity is too low from matching to any row (generally resulting
20 | in a negative training example), and unmatched_threshold to instead ignore
21 | such a match (generally resulting in neither a positive nor a negative
22 | training example).
23 | 
24 | This matcher is used in Fast(er)-RCNN.
25 | 
26 | Note: matchers are used in TargetAssigners. There is a create_target_assigner
27 | factory function for popular implementations.
28 | """
29 | import tensorflow.compat.v1 as tf
30 | 
31 | from object_detection import matcher
32 | from object_detection import shape_utils
33 | 
34 | 
35 | class ArgMaxMatcher(matcher.Matcher):
36 |   """Matcher based on highest value.
37 | 
38 |   This class computes matches from a similarity matrix. Each column is matched
39 |   to a single row.
40 | 
41 |   To support object detection target assignment this class enables setting both
42 |   matched_threshold (upper threshold) and unmatched_threshold (lower threshold),
43 |   defining three categories of similarity which determine whether examples are
44 |   positive, negative, or ignored:
45 |   (1) similarity >= matched_threshold: Highest similarity. Matched/Positive!
46 |   (2) matched_threshold > similarity >= unmatched_threshold: Medium similarity.
47 |       Depending on negatives_lower_than_unmatched, this is either
48 |       Unmatched/Negative OR Ignore.
49 |   (3) unmatched_threshold > similarity: Lowest similarity. Depending on flag
50 |       negatives_lower_than_unmatched, either Unmatched/Negative OR Ignore.
51 |   For ignored matches this class sets the values in the Match object to -2.
52 |   """
53 | 
54 |   def __init__(self,
55 |                matched_threshold,
56 |                unmatched_threshold=None,
57 |                negatives_lower_than_unmatched=True,
58 |                force_match_for_each_row=False):
59 |     """Constructs an ArgMaxMatcher.
60 | 
61 |     Args:
62 |       matched_threshold: Threshold for positive matches. Positive if
63 |         sim >= matched_threshold, where sim is the maximum value of the
64 |         similarity matrix for a given column. Set to None for no threshold.
65 |       unmatched_threshold: Threshold for negative matches. Negative if
66 |         sim < unmatched_threshold. Defaults to matched_threshold
67 |         when set to None.
68 |       negatives_lower_than_unmatched: Boolean which defaults to True. If True
69 |         then negative matches are the ones below the unmatched_threshold,
70 |         whereas ignored matches are in between the matched and unmatched
71 |         threshold. If False, then negative matches are in between the matched
72 |         and unmatched threshold, and everything lower than unmatched is ignored.
73 |       force_match_for_each_row: If True, ensures that each row is matched to
74 |         at least one column (which is not guaranteed otherwise if the
75 |         matched_threshold is high). Defaults to False. See
76 |         argmax_matcher_test.testMatcherForceMatch() for an example.
77 | 
78 |     Raises:
79 |       ValueError: if unmatched_threshold is set but matched_threshold is not
80 |         set, or if unmatched_threshold > matched_threshold.
81 |     """
82 |     if (matched_threshold is None) and (unmatched_threshold is not None):
83 |       raise ValueError('Need to also define matched_threshold when '
84 |                        'unmatched_threshold is defined')
85 |     self._matched_threshold = matched_threshold
86 |     if unmatched_threshold is None:
87 |       self._unmatched_threshold = matched_threshold
88 |     else:
89 |       if unmatched_threshold > matched_threshold:
90 |         raise ValueError('unmatched_threshold needs to be smaller or equal '
91 |                          'to matched_threshold')
92 |       self._unmatched_threshold = unmatched_threshold
93 |     if not negatives_lower_than_unmatched:
94 |       if self._unmatched_threshold == self._matched_threshold:
95 |         raise ValueError('When negatives are in between matched and '
96 |                          'unmatched thresholds, these cannot be of equal '
97 |                          'value. matched: %s, unmatched: %s' %
98 |                          (self._matched_threshold, self._unmatched_threshold))
99 |     self._force_match_for_each_row = force_match_for_each_row
100 |     self._negatives_lower_than_unmatched = negatives_lower_than_unmatched
101 | 
102 |   def _match(self, similarity_matrix):
103 |     """Tries to match each column of the similarity matrix to a row.
104 | 
105 |     Args:
106 |       similarity_matrix: tensor of shape [N, M] representing any similarity
107 |         metric.
108 | 
109 |     Returns:
110 |       Match object with corresponding matches for each of M columns.
111 |     """
112 | 
113 |     def _match_when_rows_are_empty():
114 |       """Performs matching when the rows of similarity matrix are empty.
115 | 
116 |       When the rows are empty, all detections are false positives. So we return
117 |       a tensor of -1's to indicate that the columns do not match to any rows.
118 | 
119 |       Returns:
120 |         matches: int32 tensor indicating the row each column matches to.
121 |       """
122 |       similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
123 |           similarity_matrix)
124 |       return -1 * tf.ones([similarity_matrix_shape[1]], dtype=tf.int32)
125 | 
126 |     def _match_when_rows_are_non_empty():
127 |       """Performs matching when the rows of similarity matrix are non-empty.
128 | 
129 |       Returns:
130 |         matches: int32 tensor indicating the row each column matches to.
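
        For example, with matched_threshold=0.5, unmatched_threshold=0.4 and
        the default negatives_lower_than_unmatched=True, per-column maxima of
        [0.9, 0.45, 0.1] yield a positive match, an ignored column (-2), and
        a negative column (-1), respectively.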
130 | """ 131 | # Matches for each column 132 | matches = tf.argmax(similarity_matrix, 0, output_type=tf.int32) 133 | 134 | # Deal with matched and unmatched threshold 135 | if self._matched_threshold is not None: 136 | # Get logical indices of ignored and unmatched columns as tf.int64 137 | matched_vals = tf.reduce_max(similarity_matrix, 0) 138 | below_unmatched_threshold = tf.greater(self._unmatched_threshold, 139 | matched_vals) 140 | between_thresholds = tf.logical_and( 141 | tf.greater_equal(matched_vals, self._unmatched_threshold), 142 | tf.greater(self._matched_threshold, matched_vals)) 143 | 144 | if self._negatives_lower_than_unmatched: 145 | matches = self._set_values_using_indicator(matches, 146 | below_unmatched_threshold, 147 | -1) 148 | matches = self._set_values_using_indicator(matches, 149 | between_thresholds, 150 | -2) 151 | else: 152 | matches = self._set_values_using_indicator(matches, 153 | below_unmatched_threshold, 154 | -2) 155 | matches = self._set_values_using_indicator(matches, 156 | between_thresholds, 157 | -1) 158 | 159 | if self._force_match_for_each_row: 160 | similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape( 161 | similarity_matrix) 162 | force_match_column_ids = tf.argmax(similarity_matrix, 1, 163 | output_type=tf.int32) 164 | force_match_column_indicators = tf.one_hot( 165 | force_match_column_ids, depth=similarity_matrix_shape[1]) 166 | force_match_row_ids = tf.argmax(force_match_column_indicators, 0, 167 | output_type=tf.int32) 168 | force_match_column_mask = tf.cast( 169 | tf.reduce_max(force_match_column_indicators, 0), tf.bool) 170 | final_matches = tf.where(force_match_column_mask, 171 | force_match_row_ids, matches) 172 | return final_matches 173 | else: 174 | return matches 175 | 176 | if similarity_matrix.shape.is_fully_defined(): 177 | if similarity_matrix.shape[0].value == 0: 178 | return _match_when_rows_are_empty() 179 | else: 180 | return _match_when_rows_are_non_empty() 181 | else: 182 | return tf.cond( 183 | tf.greater(tf.shape(similarity_matrix)[0], 0), 184 | _match_when_rows_are_non_empty, _match_when_rows_are_empty) 185 | 186 | def _set_values_using_indicator(self, x, indicator, val): 187 | """Set the indicated fields of x to val. 188 | 189 | Args: 190 | x: tensor. 191 | indicator: boolean with same shape as x. 192 | val: scalar with value to set. 193 | 194 | Returns: 195 | modified tensor. 196 | """ 197 | indicator = tf.cast(indicator, x.dtype) 198 | return tf.add(tf.multiply(x, 1 - indicator), val * indicator) 199 | -------------------------------------------------------------------------------- /efficientdet/object_detection/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Matcher interface and Match class. 
16 | 
17 | This module defines the Matcher interface and the Match object. The job of the
18 | matcher is to match row and column indices based on the similarity matrix and
19 | other optional parameters. Each column is matched to at most one row. There
20 | are three possibilities for the matching:
21 | 
22 | 1) match: A column matches a row.
23 | 2) no_match: A column does not match any row.
24 | 3) ignore: A column that is neither 'match' nor 'no_match'.
25 | 
26 | The ignore case is regularly encountered in object detection: when an anchor
27 | has a relatively small overlap with a ground-truth box, one wants to consider
28 | this box neither a positive example (match) nor a negative example (no match).
29 | 
30 | The Match class is used to store the match results and it provides simple APIs
31 | to query the results.
32 | """
33 | from abc import ABCMeta
34 | from abc import abstractmethod
35 | 
36 | import tensorflow.compat.v1 as tf
37 | 
38 | 
39 | class Match(object):
40 |   """Class to store results from the matcher.
41 | 
42 |   This class is used to store the results from the matcher. It provides
43 |   convenient methods to query the matching results.
44 |   """
45 | 
46 |   def __init__(self, match_results):
47 |     """Constructs a Match object.
48 | 
49 |     Args:
50 |       match_results: Integer tensor of shape [N] with (1) match_results[i]>=0,
51 |         meaning that column i is matched with row match_results[i].
52 |         (2) match_results[i]=-1, meaning that column i is not matched.
53 |         (3) match_results[i]=-2, meaning that column i is ignored.
54 | 
55 |     Raises:
56 |       ValueError: if match_results does not have rank 1 or is not an int32
57 |         tensor.
58 |     """
59 |     if match_results.shape.ndims != 1:
60 |       raise ValueError('match_results should have rank 1')
61 |     if match_results.dtype != tf.int32:
62 |       raise ValueError('match_results should be an int32 '
63 |                        'tensor')
64 |     self._match_results = match_results
65 | 
66 |   @property
67 |   def match_results(self):
68 |     """The accessor for match results.
69 | 
70 |     Returns:
71 |       the tensor which encodes the match results.
72 |     """
73 |     return self._match_results
74 | 
75 |   def matched_column_indices(self):
76 |     """Returns column indices that match to some row.
77 | 
78 |     The indices returned by this op are always sorted in increasing order.
79 | 
80 |     Returns:
81 |       column_indices: int32 tensor of shape [K] with column indices.
82 |     """
83 |     return self._reshape_and_cast(tf.where(tf.greater(self._match_results, -1)))
84 | 
85 |   def matched_column_indicator(self):
86 |     """Returns a boolean indicator of which columns are matched.
87 | 
88 |     Returns:
89 |       column_indicator: boolean vector, True for matched column indices.
90 |     """
91 |     return tf.greater_equal(self._match_results, 0)
92 | 
93 |   def num_matched_columns(self):
94 |     """Returns number (int32 scalar tensor) of matched columns."""
95 |     return tf.size(self.matched_column_indices())
96 | 
97 |   def unmatched_column_indices(self):
98 |     """Returns column indices that do not match any row.
99 | 
100 |     The indices returned by this op are always sorted in increasing order.
101 | 
102 |     Returns:
103 |       column_indices: int32 tensor of shape [K] with column indices.
104 |     """
105 |     return self._reshape_and_cast(tf.where(tf.equal(self._match_results, -1)))
106 | 
107 |   def unmatched_column_indicator(self):
108 |     """Returns a boolean indicator of which columns are unmatched.
109 | 
110 |     Returns:
111 |       column_indicator: boolean vector, True for unmatched column indices.
112 |     """
113 |     return tf.equal(self._match_results, -1)
114 | 
115 |   def num_unmatched_columns(self):
116 |     """Returns number (int32 scalar tensor) of unmatched columns."""
117 |     return tf.size(self.unmatched_column_indices())
118 | 
119 |   def ignored_column_indices(self):
120 |     """Returns column indices that are ignored (neither Matched nor Unmatched).
121 | 
122 |     The indices returned by this op are always sorted in increasing order.
123 | 
124 |     Returns:
125 |       column_indices: int32 tensor of shape [K] with column indices.
126 |     """
127 |     return self._reshape_and_cast(tf.where(self.ignored_column_indicator()))
128 | 
129 |   def ignored_column_indicator(self):
130 |     """Returns boolean column indicator where True means the column is ignored.
131 | 
132 |     Returns:
133 |       column_indicator: boolean vector which is True for all ignored column
134 |         indices.
135 |     """
136 |     return tf.equal(self._match_results, -2)
137 | 
138 |   def num_ignored_columns(self):
139 |     """Returns number (int32 scalar tensor) of ignored columns."""
140 |     return tf.size(self.ignored_column_indices())
141 | 
142 |   def unmatched_or_ignored_column_indices(self):
143 |     """Returns column indices that are unmatched or ignored.
144 | 
145 |     The indices returned by this op are always sorted in increasing order.
146 | 
147 |     Returns:
148 |       column_indices: int32 tensor of shape [K] with column indices.
149 |     """
150 |     return self._reshape_and_cast(tf.where(tf.greater(0, self._match_results)))
151 | 
152 |   def matched_row_indices(self):
153 |     """Returns row indices that match some column.
154 | 
155 |     The indices returned by this op are ordered so as to be in correspondence
156 |     with the output of matched_column_indices(). For example if
157 |     self.matched_column_indices() is [0, 2], and self.matched_row_indices() is
158 |     [7, 3], then we know that column 0 was matched to row 7 and column 2 was
159 |     matched to row 3.
160 | 
161 |     Returns:
162 |       row_indices: int32 tensor of shape [K] with row indices.
163 |     """
164 |     return self._reshape_and_cast(
165 |         tf.gather(self._match_results, self.matched_column_indices()))
166 | 
167 |   def _reshape_and_cast(self, t):
168 |     return tf.cast(tf.reshape(t, [-1]), tf.int32)
169 | 
170 |   def gather_based_on_match(self, input_tensor, unmatched_value,
171 |                             ignored_value):
172 |     """Gathers elements from `input_tensor` based on match results.
173 | 
174 |     For columns that are matched to a row, gathered_tensor[col] is set to
175 |     input_tensor[match_results[col]]. For columns that are unmatched,
176 |     gathered_tensor[col] is set to unmatched_value. Finally, for columns that
177 |     are ignored gathered_tensor[col] is set to ignored_value.
178 | 
179 |     Note that input_tensor.shape[1:] must match unmatched_value.shape
180 |     and ignored_value.shape.
181 | 
182 |     Args:
183 |       input_tensor: Tensor to gather values from.
184 |       unmatched_value: Constant tensor value for unmatched columns.
185 |       ignored_value: Constant tensor value for ignored columns.
186 | 
187 |     Returns:
188 |       gathered_tensor: A tensor containing values gathered from input_tensor.
189 |         The shape of the gathered tensor is [match_results.shape[0]] +
190 |         input_tensor.shape[1:].
191 |     """
192 |     input_tensor = tf.concat([tf.stack([ignored_value, unmatched_value]),
193 |                               input_tensor], axis=0)
194 |     gather_indices = tf.maximum(self.match_results + 2, 0)
195 |     gathered_tensor = tf.gather(input_tensor, gather_indices)
196 |     return gathered_tensor
197 | 
198 | 
199 | class Matcher(object):
200 |   """Abstract base class for matcher.
201 | """ 202 | __metaclass__ = ABCMeta 203 | 204 | def match(self, similarity_matrix, scope=None, **params): 205 | """Computes matches among row and column indices and returns the result. 206 | 207 | Computes matches among the row and column indices based on the similarity 208 | matrix and optional arguments. 209 | 210 | Args: 211 | similarity_matrix: Float tensor of shape [N, M] with pairwise similarity 212 | where higher value means more similar. 213 | scope: Op scope name. Defaults to 'Match' if None. 214 | **params: Additional keyword arguments for specific implementations of 215 | the Matcher. 216 | 217 | Returns: 218 | A Match object with the results of matching. 219 | """ 220 | with tf.name_scope(scope, 'Match', [similarity_matrix, params]) as scope: 221 | return Match(self._match(similarity_matrix, **params)) 222 | 223 | @abstractmethod 224 | def _match(self, similarity_matrix, **params): 225 | """Method to be overridden by implementations. 226 | 227 | Args: 228 | similarity_matrix: Float tensor of shape [N, M] with pairwise similarity 229 | where higher value means more similar. 230 | **params: Additional keyword arguments for specific implementations of 231 | the Matcher. 232 | 233 | Returns: 234 | match_results: Integer tensor of shape [M]: match_results[i]>=0 means 235 | that column i is matched to row match_results[i], match_results[i]=-1 236 | means that the column is not matched. match_results[i]=-2 means that 237 | the column is ignored (usually this happens when there is a very weak 238 | match which one neither wants as positive nor negative example). 239 | """ 240 | pass 241 | -------------------------------------------------------------------------------- /efficientdet/hparams_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Hparams for model architecture and trainer.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | # gtype import 20 | from __future__ import print_function 21 | 22 | import ast 23 | import copy 24 | import json 25 | import six 26 | 27 | 28 | def eval_str_fn(val): 29 | if val in {'true', 'false'}: 30 | return val == 'true' 31 | try: 32 | return ast.literal_eval(val) 33 | except ValueError: 34 | return val 35 | 36 | 37 | # pylint: disable=protected-access 38 | class Config(object): 39 | """A config utility class.""" 40 | 41 | def __init__(self, config_dict=None): 42 | self.update(config_dict) 43 | 44 | def __setattr__(self, k, v): 45 | self.__dict__[k] = Config(v) if isinstance(v, dict) else copy.deepcopy(v) 46 | 47 | def __getattr__(self, k): 48 | return self.__dict__[k] 49 | 50 | def __repr__(self): 51 | return repr(self.as_dict()) 52 | 53 | def __str__(self): 54 | try: 55 | return json.dumps(self.as_dict(), indent=4) 56 | except TypeError: 57 | return str(self.as_dict()) 58 | 59 | def _update(self, config_dict, allow_new_keys=True): 60 | """Recursively update internal members.""" 61 | if not config_dict: 62 | return 63 | 64 | for k, v in six.iteritems(config_dict): 65 | if k not in self.__dict__.keys(): 66 | if allow_new_keys: 67 | self.__setattr__(k, v) 68 | else: 69 | raise KeyError('Key `{}` does not exist for overriding. '.format(k)) 70 | else: 71 | if isinstance(v, dict): 72 | self.__dict__[k]._update(v, allow_new_keys) 73 | else: 74 | self.__dict__[k] = copy.deepcopy(v) 75 | 76 | def get(self, k, default_value=None): 77 | return self.__dict__.get(k, default_value) 78 | 79 | def update(self, config_dict): 80 | """Update members while allowing new keys.""" 81 | self._update(config_dict, allow_new_keys=True) 82 | 83 | def override(self, config_dict_or_str): 84 | """Update members while disallowing new keys.""" 85 | if isinstance(config_dict_or_str, str): 86 | config_dict = self.parse_from_str(config_dict_or_str) 87 | elif isinstance(config_dict_or_str, dict): 88 | config_dict = config_dict_or_str 89 | else: 90 | raise ValueError('Unknown value type: {}'.format(config_dict_or_str)) 91 | 92 | self._update(config_dict, allow_new_keys=False) 93 | 94 | def parse_from_str(self, config_str): 95 | """parse from a string in format 'x=a,y=2' and return the dict.""" 96 | if not config_str: 97 | return {} 98 | config_dict = {} 99 | try: 100 | for kv_pair in config_str.split(','): 101 | if not kv_pair: # skip empty string 102 | continue 103 | k, v = kv_pair.split('=') 104 | config_dict[k.strip()] = eval_str_fn(v.strip()) 105 | return config_dict 106 | except ValueError: 107 | raise ValueError('Invalid config_str: {}'.format(config_str)) 108 | 109 | def as_dict(self): 110 | """Returns a dict representation.""" 111 | config_dict = {} 112 | for k, v in six.iteritems(self.__dict__): 113 | if isinstance(v, Config): 114 | config_dict[k] = v.as_dict() 115 | else: 116 | config_dict[k] = copy.deepcopy(v) 117 | return config_dict 118 | 119 | 120 | # pylint: enable=protected-access 121 | 122 | 123 | def default_detection_configs(): 124 | """Returns a default detection configs.""" 125 | h = Config() 126 | 127 | # model name. 
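To make the Config semantics above concrete, here is a small usage sketch (assumed usage with invented values, not repo code): update() accepts new keys, override() rejects them, and string overrides follow the 'key=value,key=value' format handled by parse_from_str and eval_str_fn.

```
from hparams_config import Config  # assuming efficientdet/ is on PYTHONPATH

c = Config({'learning_rate': 0.08, 'optimizer': {'momentum': 0.9}})
c.update({'new_key': 1})           # fine: update() allows new keys
c.override('learning_rate=0.1')    # fine: existing key, parsed to float 0.1
try:
  c.override('unknown_key=1')      # KeyError: override() disallows new keys
except KeyError:
  pass
print(c.optimizer.momentum)        # nested dicts become nested Configs: 0.9
```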
128 | h.name = 'efficientdet-d1' 129 | 130 | # input preprocessing parameters 131 | h.image_size = 640 132 | h.input_rand_hflip = True 133 | h.train_scale_min = 0.1 134 | h.train_scale_max = 2.0 135 | h.autoaugment_policy = None 136 | 137 | # dataset specific parameters 138 | h.num_classes = 90 139 | h.skip_crowd_during_training = True 140 | 141 | # model architecture 142 | h.min_level = 3 143 | h.max_level = 7 144 | h.num_scales = 3 145 | h.aspect_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)] 146 | h.anchor_scale = 4.0 147 | # is batchnorm training mode 148 | h.is_training_bn = True 149 | # optimization 150 | h.momentum = 0.9 151 | h.learning_rate = 0.08 152 | h.lr_warmup_init = 0.008 153 | h.lr_warmup_epoch = 1.0 154 | h.first_lr_drop_epoch = 200.0 155 | h.second_lr_drop_epoch = 250.0 156 | h.clip_gradients_norm = 10.0 157 | h.num_epochs = 300 158 | 159 | # classification loss 160 | h.alpha = 0.25 161 | h.gamma = 1.5 162 | # localization loss 163 | h.delta = 0.1 164 | h.box_loss_weight = 50.0 165 | # regularization l2 loss. 166 | h.weight_decay = 4e-5 167 | # enable bfloat 168 | h.use_bfloat16 = True 169 | 170 | # For detection. 171 | h.box_class_repeats = 3 172 | h.fpn_cell_repeats = 3 173 | h.fpn_num_filters = 88 174 | h.separable_conv = True 175 | h.apply_bn_for_resampling = True 176 | h.conv_after_downsample = False 177 | h.conv_bn_relu_pattern = False 178 | h.use_native_resize_op = False 179 | h.pooling_type = None 180 | 181 | # version. 182 | h.fpn_name = None 183 | h.fpn_config = None 184 | 185 | # No stochastic depth in default. 186 | h.survival_prob = None 187 | 188 | h.lr_decay_method = 'cosine' 189 | h.moving_average_decay = 0.9998 190 | h.ckpt_var_scope = None # ckpt variable scope. 191 | # exclude vars when loading pretrained ckpts. 192 | h.var_exclude_expr = '.*/class-predict/.*' # exclude class weights in default 193 | 194 | h.backbone_name = 'efficientnet-b1' 195 | h.backbone_config = None 196 | 197 | # RetinaNet. 
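As a back-of-the-envelope check of what the architecture defaults above imply, the sketch below counts anchors for a 640x640 input, assuming the usual RetinaNet-style layout of one anchor per feature cell per scale/ratio combination (an assumption for illustration; the exact count is determined by the anchor code, not by this config file).

```
image_size, min_level, max_level = 640, 3, 7
num_scales, num_ratios = 3, 3

total = 0
for level in range(min_level, max_level + 1):
  feat = image_size // (2 ** level)  # feature map stride is 2**level
  total += feat * feat * num_scales * num_ratios
print(total)  # 76725 anchors for a 640x640 input
```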
198 | h.resnet_depth = 50 199 | return h 200 | 201 | 202 | efficientdet_model_param_dict = { 203 | 'efficientdet-d0': 204 | dict( 205 | name='efficientdet-d0', 206 | backbone_name='efficientnet-b0', 207 | image_size=512, 208 | fpn_num_filters=64, 209 | fpn_cell_repeats=3, 210 | box_class_repeats=3, 211 | ), 212 | 'efficientdet-d1': 213 | dict( 214 | name='efficientdet-d1', 215 | backbone_name='efficientnet-b1', 216 | image_size=640, 217 | fpn_num_filters=88, 218 | fpn_cell_repeats=4, 219 | box_class_repeats=3, 220 | ), 221 | 'efficientdet-d2': 222 | dict( 223 | name='efficientdet-d2', 224 | backbone_name='efficientnet-b2', 225 | image_size=768, 226 | fpn_num_filters=112, 227 | fpn_cell_repeats=5, 228 | box_class_repeats=3, 229 | ), 230 | 'efficientdet-d3': 231 | dict( 232 | name='efficientdet-d3', 233 | backbone_name='efficientnet-b3', 234 | image_size=896, 235 | fpn_num_filters=160, 236 | fpn_cell_repeats=6, 237 | box_class_repeats=4, 238 | ), 239 | 'efficientdet-d4': 240 | dict( 241 | name='efficientdet-d4', 242 | backbone_name='efficientnet-b4', 243 | image_size=1024, 244 | fpn_num_filters=224, 245 | fpn_cell_repeats=7, 246 | box_class_repeats=4, 247 | ), 248 | 'efficientdet-d5': 249 | dict( 250 | name='efficientdet-d5', 251 | backbone_name='efficientnet-b5', 252 | image_size=1280, 253 | fpn_num_filters=288, 254 | fpn_cell_repeats=7, 255 | box_class_repeats=4, 256 | ), 257 | 'efficientdet-d6': 258 | dict( 259 | name='efficientdet-d6', 260 | backbone_name='efficientnet-b6', 261 | image_size=1280, 262 | fpn_num_filters=384, 263 | fpn_cell_repeats=8, 264 | box_class_repeats=5, 265 | fpn_name='bifpn_sum', # Use unweighted sum for training stability. 266 | ), 267 | 'efficientdet-d7': 268 | dict( 269 | name='efficientdet-d7', 270 | backbone_name='efficientnet-b6', 271 | image_size=1536, 272 | fpn_num_filters=384, 273 | fpn_cell_repeats=8, 274 | box_class_repeats=5, 275 | anchor_scale=5.0, 276 | fpn_name='bifpn_sum', # Use unweighted sum for training stability. 277 | ), 278 | } 279 | 280 | 281 | def get_efficientdet_config(model_name='efficientdet-d1'): 282 | """Get the default config for EfficientDet based on model name.""" 283 | h = default_detection_configs() 284 | h.override(efficientdet_model_param_dict[model_name]) 285 | return h 286 | 287 | 288 | retinanet_model_param_dict = { 289 | 'retinanet-50': 290 | dict(name='retinanet-50', backbone_name='resnet50', resnet_depth=50), 291 | 'retinanet-101': 292 | dict(name='retinanet-101', backbone_name='resnet101', resnet_depth=101), 293 | } 294 | 295 | 296 | def get_retinanet_config(model_name='retinanet-50'): 297 | """Get the default config for EfficientDet based on model name.""" 298 | h = default_detection_configs() 299 | h.override( 300 | dict( 301 | retinanet_model_param_dict[model_name], 302 | ckpt_var_scope='', 303 | )) 304 | # cosine + ema often cause NaN for RetinaNet, so we use the default 305 | # stepwise without ema used in the original RetinaNet implementation. 
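A typical way to consume these dictionaries is sketched below (illustrative usage; 'efficientdet-d0' is one of the keys defined above, and string overrides reuse Config.parse_from_str).

```
import hparams_config

config = hparams_config.get_detection_config('efficientdet-d0')
print(config.image_size)     # 512, from the d0 entry
print(config.backbone_name)  # 'efficientnet-b0'

config.override('num_classes=20,use_bfloat16=false')
print(config.num_classes, config.use_bfloat16)  # 20 False
```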
306 | h.lr_decay_method = 'stepwise' 307 | h.moving_average_decay = 0 308 | 309 | return h 310 | 311 | 312 | def get_detection_config(model_name) -> Config: 313 | if model_name.startswith('efficientdet'): 314 | return get_efficientdet_config(model_name) 315 | elif model_name.startswith('retinanet'): 316 | return get_retinanet_config(model_name) 317 | else: 318 | raise ValueError('model name must start with efficientdet or retinanet.') 319 | -------------------------------------------------------------------------------- /efficientdet/normalization_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The V2 implementation of Normalization layers. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensorflow.python.framework import dtypes 22 | from tensorflow.python.framework import ops 23 | from tensorflow.python.ops import array_ops 24 | from tensorflow.python.ops import math_ops 25 | 26 | import normalization 27 | from horovod_estimator import hvd, hvd_info 28 | 29 | 30 | # pylint: disable=g-classes-have-attributes 31 | class SyncBatchNormalization(normalization.BatchNormalizationBase): 32 | r"""Normalize and scale inputs or activations synchronously across replicas. 33 | 34 | Applies batch normalization to activations of the previous layer at each batch 35 | by synchronizing the global batch statistics across all devices that are 36 | training the model. For specific details about batch normalization please 37 | refer to the `tf.keras.layers.BatchNormalization` layer docs. 38 | 39 | If this layer is used when using tf.distribute strategy to train models 40 | across devices/workers, there will be an allreduce call to aggregate batch 41 | statistics across all replicas at every training step. Without tf.distribute 42 | strategy, this layer behaves as a regular `tf.keras.layers.BatchNormalization` 43 | layer. 44 | 45 | Example usage: 46 | ``` 47 | strategy = tf.distribute.MirroredStrategy() 48 | 49 | with strategy.scope(): 50 | model = tf.keras.Sequential() 51 | model.add(tf.keras.layers.Dense(16)) 52 | model.add(tf.keras.layers.experimental.SyncBatchNormalization()) 53 | ``` 54 | 55 | Arguments: 56 | axis: Integer, the axis that should be normalized 57 | (typically the features axis). 58 | For instance, after a `Conv2D` layer with 59 | `data_format="channels_first"`, 60 | set `axis=1` in `BatchNormalization`. 61 | momentum: Momentum for the moving average. 62 | epsilon: Small float added to variance to avoid dividing by zero. 63 | center: If True, add offset of `beta` to normalized tensor. 64 | If False, `beta` is ignored. 65 | scale: If True, multiply by `gamma`. 66 | If False, `gamma` is not used. 
67 | When the next layer is linear (also e.g. `nn.relu`), 68 | this can be disabled since the scaling 69 | will be done by the next layer. 70 | beta_initializer: Initializer for the beta weight. 71 | gamma_initializer: Initializer for the gamma weight. 72 | moving_mean_initializer: Initializer for the moving mean. 73 | moving_variance_initializer: Initializer for the moving variance. 74 | beta_regularizer: Optional regularizer for the beta weight. 75 | gamma_regularizer: Optional regularizer for the gamma weight. 76 | beta_constraint: Optional constraint for the beta weight. 77 | gamma_constraint: Optional constraint for the gamma weight. 78 | renorm: Whether to use Batch Renormalization 79 | (https://arxiv.org/abs/1702.03275). This adds extra variables during 80 | training. The inference is the same for either value of this parameter. 81 | renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to 82 | scalar `Tensors` used to clip the renorm correction. The correction 83 | `(r, d)` is used as `corrected_value = normalized_value * r + d`, with 84 | `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, 85 | dmax are set to inf, 0, inf, respectively. 86 | renorm_momentum: Momentum used to update the moving means and standard 87 | deviations with renorm. Unlike `momentum`, this affects training 88 | and should be neither too small (which would add noise) nor too large 89 | (which would give stale estimates). Note that `momentum` is still applied 90 | to get the means and variances for inference. 91 | trainable: Boolean, if `True` the variables will be marked as trainable. 92 | 93 | Call arguments: 94 | inputs: Input tensor (of any rank). 95 | training: Python boolean indicating whether the layer should behave in 96 | training mode or in inference mode. 97 | - `training=True`: The layer will normalize its inputs using the 98 | mean and variance of the current batch of inputs. 99 | - `training=False`: The layer will normalize its inputs using the 100 | mean and variance of its moving statistics, learned during training. 101 | 102 | Input shape: 103 | Arbitrary. Use the keyword argument `input_shape` 104 | (tuple of integers, does not include the samples axis) 105 | when using this layer as the first layer in a model. 106 | 107 | Output shape: 108 | Same shape as input. 109 | 110 | """ 111 | 112 | def __init__(self, 113 | axis=-1, 114 | momentum=0.99, 115 | epsilon=1e-3, 116 | center=True, 117 | scale=True, 118 | beta_initializer='zeros', 119 | gamma_initializer='ones', 120 | moving_mean_initializer='zeros', 121 | moving_variance_initializer='ones', 122 | beta_regularizer=None, 123 | gamma_regularizer=None, 124 | beta_constraint=None, 125 | gamma_constraint=None, 126 | renorm=False, 127 | renorm_clipping=None, 128 | renorm_momentum=0.99, 129 | trainable=True, 130 | adjustment=None, 131 | **kwargs): 132 | 133 | # Currently we only support aggregating over the global batch size. 
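    # Added note: the superclass is constructed with fused=False below because
    # the fused batch-norm kernel computes its statistics internally and would
    # bypass the cross-replica aggregation in _calculate_mean_and_var.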
134 | 135 | # if name is None: 136 | # name = 'tpu_batch_normalization' 137 | 138 | if not kwargs.get('name', None): 139 | kwargs['name'] = 'tpu_batch_normalization' 140 | 141 | super(SyncBatchNormalization, self).__init__( 142 | axis=axis, 143 | momentum=momentum, 144 | epsilon=epsilon, 145 | center=center, 146 | scale=scale, 147 | beta_initializer=beta_initializer, 148 | gamma_initializer=gamma_initializer, 149 | moving_mean_initializer=moving_mean_initializer, 150 | moving_variance_initializer=moving_variance_initializer, 151 | beta_regularizer=beta_regularizer, 152 | gamma_regularizer=gamma_regularizer, 153 | beta_constraint=beta_constraint, 154 | gamma_constraint=gamma_constraint, 155 | renorm=renorm, 156 | renorm_clipping=renorm_clipping, 157 | renorm_momentum=renorm_momentum, 158 | fused=False, 159 | trainable=trainable, 160 | virtual_batch_size=None, 161 | **kwargs) 162 | 163 | def _calculate_mean_and_var(self, x, axes, keep_dims): 164 | 165 | with ops.name_scope('moments', values=[x, axes]): 166 | # The dynamic range of fp16 is too limited to support the collection of 167 | # sufficient statistics. As a workaround we simply perform the operations 168 | # on 32-bit floats before converting the mean and variance back to fp16 169 | y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x 170 | 171 | if hvd is not None: 172 | num_shards = hvd.size() 173 | else: 174 | num_shards = 1 175 | 176 | if num_shards > 1: 177 | local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True) 178 | local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes, keepdims=True) 179 | batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32) 180 | # y_sum, y_squared_sum, global_batch_size = ( 181 | # replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [ 182 | # local_sum, local_squared_sum, batch_size])) 183 | 184 | # hvd_info(f'local_sum {local_sum.shape}, local_squared_sum {local_squared_sum.shape}') 185 | 186 | y_sum = hvd.allreduce(local_sum, average=False) 187 | y_squared_sum = hvd.allreduce(local_squared_sum, average=False) 188 | 189 | global_batch_size = batch_size * num_shards 190 | axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))] 191 | multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals), dtypes.float32) 192 | multiplier = multiplier * global_batch_size 193 | 194 | mean = y_sum / multiplier 195 | y_squared_mean = y_squared_sum / multiplier 196 | # var = E(x^2) - E(x)^2 197 | variance = y_squared_mean - math_ops.square(mean) 198 | else: 199 | # Compute true mean while keeping the dims for proper broadcasting. 
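The else-branch below computes plain per-batch moments on a single shard, while the branch above reconstructs the same statistics from allreduced sums. A small numpy sanity check of that identity (illustrative only, not repo code): summing per-shard sums and squared sums reproduces the global-batch mean and the biased variance E[x^2] - E[x]^2.

```
import numpy as np

shards = [np.random.randn(4, 8) for _ in range(2)]  # two replicas' batches

y_sum = sum(s.sum(axis=0) for s in shards)            # like allreduce(local_sum)
y_sq_sum = sum((s ** 2).sum(axis=0) for s in shards)  # like allreduce(local_squared_sum)
count = sum(s.shape[0] for s in shards)               # global batch size

mean = y_sum / count
var = y_sq_sum / count - mean ** 2                    # var = E[x^2] - E[x]^2

full = np.concatenate(shards, axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var, full.var(axis=0))
```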
200 |         mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
201 |         # sample variance, not unbiased variance
202 |         # Note: stop_gradient does not change the gradient that gets
203 |         # backpropagated to the mean from the variance calculation,
204 |         # because that gradient is zero.
205 |         variance = math_ops.reduce_mean(
206 |             math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
207 |             axes,
208 |             keepdims=True,
209 |             name='variance')
210 |       if not keep_dims:
211 |         mean = array_ops.squeeze(mean, axes)
212 |         variance = array_ops.squeeze(variance, axes)
213 |       if x.dtype == dtypes.float16:
214 |         return (math_ops.cast(mean, dtypes.float16),
215 |                 math_ops.cast(variance, dtypes.float16))
216 |       else:
217 |         return (mean, variance)
218 | 
--------------------------------------------------------------------------------
/efficientdet/dataset/create_pascal_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google Research. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Convert PASCAL dataset to TFRecord.
16 | 
17 | Example usage:
18 |     python create_pascal_tfrecord.py --data_dir=/tmp/VOCdevkit \
19 |         --year=VOC2012 --output_path=/tmp/pascal
20 | """
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 | 
25 | import hashlib
26 | import io
27 | import json
28 | import logging
29 | import os
30 | 
31 | from lxml import etree
32 | import PIL.Image
33 | import tensorflow.compat.v1 as tf
34 | 
35 | from dataset import tfrecord_util
36 | 
37 | 
38 | flags = tf.app.flags
39 | flags.DEFINE_string('data_dir', '',
40 |                     'Root directory of the raw PASCAL VOC dataset.')
41 | flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
42 |                     'merged set.')
43 | flags.DEFINE_string('annotations_dir', 'Annotations',
44 |                     '(Relative) path to annotations directory.')
45 | flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
46 | flags.DEFINE_string('output_path', '', 'Path to output TFRecord and json.')
47 | flags.DEFINE_string('label_map_json_path', None,
48 |                     'Path to label map json file with a dictionary.')
49 | flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
50 |                      'difficult instances.')
51 | flags.DEFINE_integer('num_shards', 100, 'Number of shards for output file.')
52 | flags.DEFINE_integer('num_images', None, 'Max number of images to process.')
53 | FLAGS = flags.FLAGS
54 | 
55 | SETS = ['train', 'val', 'trainval', 'test']
56 | YEARS = ['VOC2007', 'VOC2012', 'merged']
57 | 
58 | pascal_label_map_dict = {
59 |     'background': 0, 'aeroplane': 1, 'bicycle': 2, 'bird': 3, 'boat': 4,
60 |     'bottle': 5, 'bus': 6, 'car': 7, 'cat': 8, 'chair': 9, 'cow': 10,
61 |     'diningtable': 11, 'dog': 12, 'horse': 13, 'motorbike': 14, 'person': 15,
62 |     'pottedplant': 16, 'sheep': 17, 'sofa': 18, 'train': 19, 'tvmonitor': 20,
63 | }
64 | 
65 | 
66 | GLOBAL_ID = 0
67 | 
68 | 
69 | def get_image_id(filename):
70 |   """Converts a string to an integer."""
71 |   # Warning: this function is highly specific to PASCAL filenames!
72 |   # Given a filename like '2008_000002', we cannot use id 2008000002 because
73 |   # our code internally will convert the int value to float32 and back to int,
74 |   # which would cause a value mismatch:
75 |   # int(float32(2008000002)) != int(2008000002). COCO needs int values, so
76 |   # here we just use an incremental global_id, but users should customize
77 |   # their own way to generate filename ids.
78 |   del filename
79 |   global GLOBAL_ID
80 |   GLOBAL_ID += 1
81 |   return GLOBAL_ID
82 | 
83 | 
84 | def dict_to_tf_example(data,
85 |                        dataset_directory,
86 |                        label_map_dict,
87 |                        ignore_difficult_instances=False,
88 |                        image_subdirectory='JPEGImages',
89 |                        ann_json_dict=None):
90 |   """Convert XML derived dict to tf.Example proto.
91 | 
92 |   Notice that this function normalizes the bounding box coordinates provided
93 |   by the raw data.
94 | 
95 |   Args:
96 |     data: dict holding PASCAL XML fields for a single image (obtained by
97 |       running tfrecord_util.recursive_parse_xml_to_dict)
98 |     dataset_directory: Path to root directory holding PASCAL dataset
99 |     label_map_dict: A map from string label names to integer ids.
100 |     ignore_difficult_instances: Whether to skip difficult instances in the
101 |       dataset (default: False).
102 |     image_subdirectory: String specifying subdirectory within the
103 |       PASCAL dataset directory holding the actual image data.
104 |     ann_json_dict: annotation json dictionary.
105 | 
106 |   Returns:
107 |     example: The converted tf.Example.
108 | 
109 |   Raises:
110 |     ValueError: if the image pointed to by data['filename'] is not a valid JPEG
111 |   """
112 |   img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
113 |   full_path = os.path.join(dataset_directory, img_path)
114 |   with tf.gfile.GFile(full_path, 'rb') as fid:
115 |     encoded_jpg = fid.read()
116 |   encoded_jpg_io = io.BytesIO(encoded_jpg)
117 |   image = PIL.Image.open(encoded_jpg_io)
118 |   if image.format != 'JPEG':
119 |     raise ValueError('Image format not JPEG')
120 |   key = hashlib.sha256(encoded_jpg).hexdigest()
121 | 
122 |   width = int(data['size']['width'])
123 |   height = int(data['size']['height'])
124 |   image_id = get_image_id(data['filename'])
125 |   if ann_json_dict:
126 |     image = {
127 |         'file_name': data['filename'],
128 |         'height': height,
129 |         'width': width,
130 |         'id': image_id,
131 |     }
132 |     ann_json_dict['images'].append(image)
133 |   ann_box_id = 1
134 | 
135 |   xmin = []
136 |   ymin = []
137 |   xmax = []
138 |   ymax = []
139 |   classes = []
140 |   classes_text = []
141 |   truncated = []
142 |   poses = []
143 |   difficult_obj = []
144 |   if 'object' in data:
145 |     for obj in data['object']:
146 |       difficult = bool(int(obj['difficult']))
147 |       if ignore_difficult_instances and difficult:
148 |         continue
149 | 
150 |       difficult_obj.append(int(difficult))
151 | 
152 |       xmin.append(float(obj['bndbox']['xmin']) / width)
153 |       ymin.append(float(obj['bndbox']['ymin']) / height)
154 |       xmax.append(float(obj['bndbox']['xmax']) / width)
155 |       ymax.append(float(obj['bndbox']['ymax']) / height)
156 |       classes_text.append(obj['name'].encode('utf8'))
157 |       classes.append(label_map_dict[obj['name']])
158 |       truncated.append(int(obj['truncated']))
159 |       poses.append(obj['pose'].encode('utf8'))
160 | 
161 |       if ann_json_dict:
162 |         abs_xmin = int(obj['bndbox']['xmin'])
163 |         abs_ymin = int(obj['bndbox']['ymin'])
164 |         abs_xmax
= int(obj['bndbox']['xmax']) 163 | abs_ymax = int(obj['bndbox']['ymax']) 164 | abs_width = abs_xmax - abs_xmin 165 | abs_height = abs_ymax - abs_ymin 166 | ann = { 167 | 'area': abs_width * abs_height, 168 | 'iscrowd': 0, 169 | 'image_id': image_id, 170 | 'bbox': [abs_xmin, abs_ymin, abs_width, abs_height], 171 | 'category_id': label_map_dict[obj['name']], 172 | 'id': ann_box_id, 173 | 'ignore': 0, 174 | 'segmentation': [], 175 | } 176 | ann_json_dict['annotations'].append(ann) 177 | ann_box_id += 1 178 | 179 | example = tf.train.Example(features=tf.train.Features(feature={ 180 | 'image/height': tfrecord_util.int64_feature(height), 181 | 'image/width': tfrecord_util.int64_feature(width), 182 | 'image/filename': tfrecord_util.bytes_feature( 183 | data['filename'].encode('utf8')), 184 | 'image/source_id': tfrecord_util.bytes_feature( 185 | str(image_id).encode('utf8')), 186 | 'image/key/sha256': tfrecord_util.bytes_feature(key.encode('utf8')), 187 | 'image/encoded': tfrecord_util.bytes_feature(encoded_jpg), 188 | 'image/format': tfrecord_util.bytes_feature('jpeg'.encode('utf8')), 189 | 'image/object/bbox/xmin': tfrecord_util.float_list_feature(xmin), 190 | 'image/object/bbox/xmax': tfrecord_util.float_list_feature(xmax), 191 | 'image/object/bbox/ymin': tfrecord_util.float_list_feature(ymin), 192 | 'image/object/bbox/ymax': tfrecord_util.float_list_feature(ymax), 193 | 'image/object/class/text': tfrecord_util.bytes_list_feature(classes_text), 194 | 'image/object/class/label': tfrecord_util.int64_list_feature(classes), 195 | 'image/object/difficult': tfrecord_util.int64_list_feature(difficult_obj), 196 | 'image/object/truncated': tfrecord_util.int64_list_feature(truncated), 197 | 'image/object/view': tfrecord_util.bytes_list_feature(poses), 198 | })) 199 | return example 200 | 201 | 202 | def main(_): 203 | if FLAGS.set not in SETS: 204 | raise ValueError('set must be in : {}'.format(SETS)) 205 | if FLAGS.year not in YEARS: 206 | raise ValueError('year must be in : {}'.format(YEARS)) 207 | if not FLAGS.output_path: 208 | raise ValueError('output_path cannot be empty.') 209 | 210 | data_dir = FLAGS.data_dir 211 | years = ['VOC2007', 'VOC2012'] 212 | if FLAGS.year != 'merged': 213 | years = [FLAGS.year] 214 | 215 | logging.info('writing to output path: %s', FLAGS.output_path) 216 | writers = [ 217 | tf.python_io.TFRecordWriter( 218 | FLAGS.output_path + '-%05d-of-%05d.tfrecord' % (i, FLAGS.num_shards)) 219 | for i in range(FLAGS.num_shards) 220 | ] 221 | 222 | if FLAGS.label_map_json_path: 223 | with tf.io.gfile.GFile(FLAGS.label_map_json_path, 'rb') as f: 224 | label_map_dict = json.load(f) 225 | else: 226 | label_map_dict = pascal_label_map_dict 227 | 228 | for year in years: 229 | ann_json_dict = { 230 | 'images': [], 231 | 'type': 'instances', 232 | 'annotations': [], 233 | 'categories': [] 234 | } 235 | for class_name, class_id in label_map_dict.items(): 236 | cls = {'supercategory': 'none', 'id': class_id, 'name': class_name} 237 | ann_json_dict['categories'].append(cls) 238 | 239 | logging.info('Reading from PASCAL %s dataset.', year) 240 | examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', 241 | 'aeroplane_' + FLAGS.set + '.txt') 242 | annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir) 243 | examples_list = tfrecord_util.read_examples_list(examples_path) 244 | for idx, example in enumerate(examples_list): 245 | if FLAGS.num_images and idx >= FLAGS.num_images: 246 | break 247 | if idx % 100 == 0: 248 | logging.info('On image %d of %d', idx, 
len(examples_list)) 249 | path = os.path.join(annotations_dir, example + '.xml') 250 | with tf.gfile.GFile(path, 'r') as fid: 251 | xml_str = fid.read() 252 | xml = etree.fromstring(xml_str) 253 | data = tfrecord_util.recursive_parse_xml_to_dict(xml)['annotation'] 254 | 255 | tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict, 256 | FLAGS.ignore_difficult_instances, 257 | ann_json_dict=ann_json_dict) 258 | writers[idx % FLAGS.num_shards].write(tf_example.SerializeToString()) 259 | 260 | for writer in writers: 261 | writer.close() 262 | 263 | json_file_path = os.path.join( 264 | os.path.dirname(FLAGS.output_path), 265 | 'json_' + os.path.basename(FLAGS.output_path) + '.json') 266 | with tf.io.gfile.GFile(json_file_path, 'w') as f: 267 | json.dump(ann_json_dict, f) 268 | 269 | 270 | if __name__ == '__main__': 271 | tf.app.run() 272 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Google Research. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 
49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2020, Google Research. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. -------------------------------------------------------------------------------- /efficientdet/backbone/efficientnet_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Model Builder for EfficientNet. 16 | 17 | efficientnet-bx (x=0,1,2,3,4,5,6,7) checkpoints are located at: 18 | https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-bx.tar.gz 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import functools 26 | import os 27 | import re 28 | from absl import logging 29 | import numpy as np 30 | import six 31 | import tensorflow.compat.v1 as tf 32 | 33 | import utils 34 | from backbone import efficientnet_model 35 | 36 | 37 | def efficientnet_params(model_name): 38 | """Get efficientnet params based on model name.""" 39 | params_dict = { 40 | # (width_coefficient, depth_coefficient, resolution, dropout_rate) 41 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 42 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 43 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 44 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 45 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 46 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 47 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 48 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 49 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5), 50 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5), 51 | } 52 | return params_dict[model_name] 53 | 54 | 55 | class BlockDecoder(object): 56 | """Block Decoder for readability.""" 57 | 58 | def _decode_block_string(self, block_string): 59 | """Decodes a block string into a BlockArgs namedtuple.""" 60 | if six.PY2: 61 | assert isinstance(block_string, (str, unicode)) 62 | else: 63 | assert isinstance(block_string, str) 64 | ops = block_string.split('_') 65 | options = {} 66 | for op in ops: 67 | splits = re.split(r'(\d.*)', op) 68 | if len(splits) >= 2: 69 | key, value = splits[:2] 70 | options[key] = value 71 | 72 | if 's' not in options or len(options['s']) != 2: 73 | raise ValueError('Strides options should be a pair of integers.') 74 | 75 | return efficientnet_model.BlockArgs( 76 | kernel_size=int(options['k']), 77 | num_repeat=int(options['r']), 78 | input_filters=int(options['i']), 79 | output_filters=int(options['o']), 80 | expand_ratio=int(options['e']), 81 | id_skip=('noskip' not in block_string), 82 | se_ratio=float(options['se']) if 'se' in options else None, 83 | strides=[int(options['s'][0]), 84 | int(options['s'][1])], 85 | conv_type=int(options['c']) if 'c' in options else 0, 86 | fused_conv=int(options['f']) if 'f' in options else 0, 87 | super_pixel=int(options['p']) if 'p' in options else 0, 88 | condconv=('cc' in block_string)) 89 | 90 | def _encode_block_string(self, block): 91 | """Encodes a block to a string.""" 92 | args = [ 93 | 'r%d' % block.num_repeat, 94 | 'k%d' % block.kernel_size, 95 | 's%d%d' % (block.strides[0], block.strides[1]), 96 | 'e%s' % block.expand_ratio, 97 | 'i%d' % block.input_filters, 98 | 'o%d' % block.output_filters, 99 | 'c%d' % block.conv_type, 100 | 'f%d' % block.fused_conv, 101 | 'p%d' % block.super_pixel, 102 | ] 103 | if block.se_ratio > 0 and block.se_ratio <= 1:
104 | args.append('se%s' % block.se_ratio) 105 | if block.id_skip is False: # pylint: disable=g-bool-id-comparison 106 | args.append('noskip') 107 | if block.condconv: 108 | args.append('cc') 109 | return '_'.join(args) 110 | 111 | def decode(self, string_list): 112 | """Decodes a list of string notations to specify blocks inside the network. 113 | 114 | Args: 115 | string_list: a list of strings, each of which encodes one block. 116 | 117 | Returns: 118 | A list of namedtuples to represent blocks arguments. 119 | """ 120 | assert isinstance(string_list, list) 121 | blocks_args = [] 122 | for block_string in string_list: 123 | blocks_args.append(self._decode_block_string(block_string)) 124 | return blocks_args 125 | 126 | def encode(self, blocks_args): 127 | """Encodes a list of Blocks to a list of strings. 128 | 129 | Args: 130 | blocks_args: A list of namedtuples to represent blocks arguments. 131 | Returns: 132 | a list of strings, each of which encodes one block. 133 | """ 134 | block_strings = [] 135 | for block in blocks_args: 136 | block_strings.append(self._encode_block_string(block)) 137 | return block_strings 138 | 139 | 140 | def swish(features, use_native=True, use_hard=False): 141 | """Computes the Swish activation function. 142 | 143 | We provide three alternatives: 144 | - Native tf.nn.swish, which uses less memory during training than the composable swish. 145 | - A quantization-friendly hard swish. 146 | - A composable swish, equivalent to tf.nn.swish, but more general for 147 | finetuning and TF-Hub. 148 | 149 | Args: 150 | features: A `Tensor` representing preactivation values. 151 | use_native: Whether to use the native swish from tf.nn that uses a custom 152 | gradient to reduce memory usage, or to use a customized swish that uses 153 | default TensorFlow gradient computation. 154 | use_hard: Whether to use quantization-friendly hard swish. 155 | 156 | Returns: 157 | The activation value. 158 | """ 159 | if use_native and use_hard: 160 | raise ValueError('Cannot specify both use_native and use_hard.') 161 | 162 | if use_native: 163 | return tf.nn.swish(features) 164 | 165 | if use_hard: 166 | return features * tf.nn.relu6(features + np.float32(3)) * (1. / 6.) 167 | 168 | features = tf.convert_to_tensor(features, name='features') 169 | return features * tf.nn.sigmoid(features) 170 | 171 | 172 | _DEFAULT_BLOCKS_ARGS = [ 173 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 174 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 175 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 176 | 'r1_k3_s11_e6_i192_o320_se0.25', 177 | ] 178 | 179 | 180 | def efficientnet(width_coefficient=None, 181 | depth_coefficient=None, 182 | dropout_rate=0.2, 183 | survival_prob=0.8): 184 | """Creates an efficientnet model.""" 185 | global_params = efficientnet_model.GlobalParams( 186 | blocks_args=_DEFAULT_BLOCKS_ARGS, 187 | batch_norm_momentum=0.99, 188 | batch_norm_epsilon=1e-3, 189 | dropout_rate=dropout_rate, 190 | survival_prob=survival_prob, 191 | data_format='channels_last', 192 | num_classes=1000, 193 | width_coefficient=width_coefficient, 194 | depth_coefficient=depth_coefficient, 195 | depth_divisor=8, 196 | min_depth=None, 197 | relu_fn=tf.nn.swish, 198 | # The default is TPU-specific batch norm. 199 | # The alternative is tf.layers.BatchNormalization. 200 | batch_norm=utils.batch_norm_class, # TPU-specific requirement.
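# (use_se enables the squeeze-and-excitation blocks requested by the 'se0.25' entries in _DEFAULT_BLOCKS_ARGS above.)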
201 | use_se=True, 202 | clip_projection_output=False) 203 | return global_params 204 | 205 | 206 | def get_model_params(model_name, override_params): 207 | """Get the block args and global params for a given model.""" 208 | if model_name.startswith('efficientnet'): 209 | width_coefficient, depth_coefficient, _, dropout_rate = ( 210 | efficientnet_params(model_name)) 211 | global_params = efficientnet( 212 | width_coefficient, depth_coefficient, dropout_rate) 213 | else: 214 | raise NotImplementedError('model name is not pre-defined: %s' % model_name) 215 | 216 | if override_params: 217 | # ValueError will be raised here if override_params has fields not included 218 | # in global_params. 219 | global_params = global_params._replace(**override_params) 220 | 221 | decoder = BlockDecoder() 222 | blocks_args = decoder.decode(global_params.blocks_args) 223 | 224 | logging.info('global_params= %s', global_params) 225 | return blocks_args, global_params 226 | 227 | 228 | def build_model(images, 229 | model_name, 230 | training, 231 | override_params=None, 232 | model_dir=None, 233 | fine_tuning=False, 234 | features_only=False, 235 | pooled_features_only=False): 236 | """A helper function to create a model and return predicted logits. 237 | 238 | Args: 239 | images: input images tensor. 240 | model_name: string, the predefined model name. 241 | training: boolean, whether the model is constructed for training. 242 | override_params: A dictionary of params for overriding. Fields must exist in 243 | efficientnet_model.GlobalParams. 244 | model_dir: string, optional model dir for saving configs. 245 | fine_tuning: boolean, whether the model is used for finetuning. 246 | features_only: build the base feature network only (excluding final 247 | 1x1 conv layer, global pooling, dropout and fc head). 248 | pooled_features_only: build the base network for feature extraction (after 249 | 1x1 conv layer and global pooling, but before dropout and fc head). 250 | 251 | Returns: 252 | logits: the logits tensor of classes. 253 | endpoints: the endpoints for each layer. 254 | 255 | Raises: 256 | When model_name specifies an undefined model, raises NotImplementedError. 257 | When override_params has invalid fields, raises ValueError. 258 | """ 259 | assert isinstance(images, tf.Tensor) 260 | assert not (features_only and pooled_features_only) 261 | 262 | # For backward compatibility.
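# ('drop_connect_rate' is the legacy name for this setting; survival_prob = 1 - drop_connect_rate.)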
263 | if override_params and override_params.get('drop_connect_rate', None): 264 | override_params['survival_prob'] = 1 - override_params['drop_connect_rate'] 265 | 266 | if not training or fine_tuning: 267 | if not override_params: 268 | override_params = {} 269 | override_params['batch_norm'] = utils.BatchNormalization 270 | if fine_tuning: 271 | override_params['relu_fn'] = functools.partial(swish, use_native=False) 272 | blocks_args, global_params = get_model_params(model_name, override_params) 273 | 274 | if model_dir: 275 | param_file = os.path.join(model_dir, 'model_params.txt') 276 | if not tf.gfile.Exists(param_file): 277 | if not tf.gfile.Exists(model_dir): 278 | tf.gfile.MakeDirs(model_dir) 279 | with tf.gfile.GFile(param_file, 'w') as f: 280 | logging.info('writing to %s', param_file) 281 | f.write('model_name= %s\n\n' % model_name) 282 | f.write('global_params= %s\n\n' % str(global_params)) 283 | f.write('blocks_args= %s\n\n' % str(blocks_args)) 284 | 285 | with tf.variable_scope(model_name): 286 | model = efficientnet_model.Model(blocks_args, global_params) 287 | outputs = model( 288 | images, 289 | training=training, 290 | features_only=features_only, 291 | pooled_features_only=pooled_features_only) 292 | if features_only: 293 | outputs = tf.identity(outputs, 'features') 294 | elif pooled_features_only: 295 | outputs = tf.identity(outputs, 'pooled_features') 296 | else: 297 | outputs = tf.identity(outputs, 'logits') 298 | return outputs, model.endpoints 299 | 300 | 301 | def build_model_base(images, model_name, training, override_params=None): 302 | """Create a base feature network and return the features before pooling. 303 | 304 | Args: 305 | images: input images tensor. 306 | model_name: string, the predefined model name. 307 | training: boolean, whether the model is constructed for training. 308 | override_params: A dictionary of params for overriding. Fields must exist in 309 | efficientnet_model.GlobalParams. 310 | 311 | Returns: 312 | features: base features before pooling. 313 | endpoints: the endpoints for each layer. 314 | 315 | Raises: 316 | When model_name specifies an undefined model, raises NotImplementedError. 317 | When override_params has invalid fields, raises ValueError. 318 | """ 319 | assert isinstance(images, tf.Tensor) 320 | # For backward compatibility. 321 | if override_params and override_params.get('drop_connect_rate', None): 322 | override_params['survival_prob'] = 1 - override_params['drop_connect_rate'] 323 | 324 | blocks_args, global_params = get_model_params(model_name, override_params) 325 | 326 | with tf.variable_scope(model_name): 327 | model = efficientnet_model.Model(blocks_args, global_params) 328 | features = model(images, training=training, features_only=True) 329 | 330 | features = tf.identity(features, 'features') 331 | return features, model.endpoints 332 | -------------------------------------------------------------------------------- /efficientdet/visualize/standard_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains classes specifying naming conventions used for object detection. 17 | 18 | 19 | Specifies: 20 | InputDataFields: standard fields used by reader/preprocessor/batcher. 21 | DetectionResultFields: standard fields returned by object detector. 22 | BoxListFields: standard fields used by BoxList. 23 | TfExampleFields: standard fields for tf-example data format (go/tf-example). 24 | """ 25 | 26 | 27 | class InputDataFields(object): 28 | """Names for the input tensors. 29 | 30 | Holds the standard data field names used to identify input tensors. These 31 | should be used by the decoder to identify keys for the returned tensor_dict 32 | containing input tensors, and by the model to identify the 33 | tensors it needs. 34 | 35 | Attributes: 36 | image: image. 37 | image_additional_channels: additional channels. 38 | original_image: image in the original input size. 39 | original_image_spatial_shape: spatial shape of the image in the original input size. 40 | key: unique key corresponding to image. 41 | source_id: source of the original image. 42 | filename: original filename of the dataset (without common path). 43 | groundtruth_image_classes: image-level class labels. 44 | groundtruth_image_confidences: image-level class confidences. 45 | groundtruth_boxes: coordinates of the ground truth boxes in the image. 46 | groundtruth_classes: box-level class labels. 47 | groundtruth_confidences: box-level class confidences. The shape should be 48 | the same as the shape of groundtruth_classes. 49 | groundtruth_label_types: box-level label types (e.g. explicit negative). 50 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] 51 | is the groundtruth a single object or a crowd. 52 | groundtruth_area: area of a groundtruth segment. 53 | groundtruth_difficult: is a `difficult` object. 54 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the 55 | same class, forming a connected group, where instances are heavily 56 | occluding each other. 57 | proposal_boxes: coordinates of object proposal boxes. 58 | proposal_objectness: objectness score of each proposal. 59 | groundtruth_instance_masks: ground truth instance masks. 60 | groundtruth_instance_boundaries: ground truth instance boundaries. 61 | groundtruth_instance_classes: instance mask-level class labels. 62 | groundtruth_keypoints: ground truth keypoints. 63 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities. 64 | groundtruth_keypoint_weights: groundtruth weight factor for keypoints. 65 | groundtruth_label_weights: groundtruth label weights. 66 | groundtruth_weights: groundtruth weight factor for bounding boxes. 67 | num_groundtruth_boxes: number of groundtruth boxes. 68 | is_annotated: whether an image has been labeled or not. 69 | true_image_shape: true shapes of images in the resized images, as resized 70 | images can be padded with zeros. 71 | multiclass_scores: the label score per class for each box.
72 | context_features: a flattened list of contextual features. 73 | context_feature_length: the fixed length of each feature in 74 | context_features, used for reshaping. 75 | valid_context_size: the valid context size, used in filtering the padded 76 | context features. 77 | """ 78 | image = 'image' 79 | image_additional_channels = 'image_additional_channels' 80 | original_image = 'original_image' 81 | original_image_spatial_shape = 'original_image_spatial_shape' 82 | key = 'key' 83 | source_id = 'source_id' 84 | filename = 'filename' 85 | groundtruth_image_classes = 'groundtruth_image_classes' 86 | groundtruth_image_confidences = 'groundtruth_image_confidences' 87 | groundtruth_boxes = 'groundtruth_boxes' 88 | groundtruth_classes = 'groundtruth_classes' 89 | groundtruth_confidences = 'groundtruth_confidences' 90 | groundtruth_label_types = 'groundtruth_label_types' 91 | groundtruth_is_crowd = 'groundtruth_is_crowd' 92 | groundtruth_area = 'groundtruth_area' 93 | groundtruth_difficult = 'groundtruth_difficult' 94 | groundtruth_group_of = 'groundtruth_group_of' 95 | proposal_boxes = 'proposal_boxes' 96 | proposal_objectness = 'proposal_objectness' 97 | groundtruth_instance_masks = 'groundtruth_instance_masks' 98 | groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' 99 | groundtruth_instance_classes = 'groundtruth_instance_classes' 100 | groundtruth_keypoints = 'groundtruth_keypoints' 101 | groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities' 102 | groundtruth_keypoint_weights = 'groundtruth_keypoint_weights' 103 | groundtruth_label_weights = 'groundtruth_label_weights' 104 | groundtruth_weights = 'groundtruth_weights' 105 | num_groundtruth_boxes = 'num_groundtruth_boxes' 106 | is_annotated = 'is_annotated' 107 | true_image_shape = 'true_image_shape' 108 | multiclass_scores = 'multiclass_scores' 109 | context_features = 'context_features' 110 | context_feature_length = 'context_feature_length' 111 | valid_context_size = 'valid_context_size' 112 | 113 | 114 | class DetectionResultFields(object): 115 | """Naming conventions for storing the output of the detector. 116 | 117 | Attributes: 118 | source_id: source of the original image. 119 | key: unique key corresponding to image. 120 | detection_boxes: coordinates of the detection boxes in the image. 121 | detection_scores: detection scores for the detection boxes in the image. 122 | detection_multiclass_scores: class score distribution (including background) 123 | for each detection box in the image. 124 | detection_classes: detection-level class labels. 125 | detection_masks: contains a segmentation mask for each detection box. 126 | detection_boundaries: contains an object boundary for each detection box. 127 | detection_keypoints: contains detection keypoints for each detection box. 128 | detection_keypoint_scores: contains detection keypoint scores. 129 | num_detections: number of detections in the batch. 130 | raw_detection_boxes: contains decoded detection boxes without non-max 131 | suppression. 132 | raw_detection_scores: contains class score logits for raw detection boxes. 133 | detection_anchor_indices: The anchor indices of the detections after NMS. 134 | detection_features: contains extracted features for each detected box 135 | after NMS.
136 | """ 137 | 138 | source_id = 'source_id' 139 | key = 'key' 140 | detection_boxes = 'detection_boxes' 141 | detection_scores = 'detection_scores' 142 | detection_multiclass_scores = 'detection_multiclass_scores' 143 | detection_features = 'detection_features' 144 | detection_classes = 'detection_classes' 145 | detection_masks = 'detection_masks' 146 | detection_boundaries = 'detection_boundaries' 147 | detection_keypoints = 'detection_keypoints' 148 | detection_keypoint_scores = 'detection_keypoint_scores' 149 | num_detections = 'num_detections' 150 | raw_detection_boxes = 'raw_detection_boxes' 151 | raw_detection_scores = 'raw_detection_scores' 152 | detection_anchor_indices = 'detection_anchor_indices' 153 | 154 | 155 | class BoxListFields(object): 156 | """Naming conventions for BoxLists. 157 | 158 | Attributes: 159 | boxes: bounding box coordinates. 160 | classes: classes per bounding box. 161 | scores: scores per bounding box. 162 | weights: sample weights per bounding box. 163 | objectness: objectness score per bounding box. 164 | masks: masks per bounding box. 165 | boundaries: boundaries per bounding box. 166 | keypoints: keypoints per bounding box. 167 | keypoint_heatmaps: keypoint heatmaps per bounding box. 168 | is_crowd: is_crowd annotation per bounding box. 169 | """ 170 | boxes = 'boxes' 171 | classes = 'classes' 172 | scores = 'scores' 173 | weights = 'weights' 174 | confidences = 'confidences' 175 | objectness = 'objectness' 176 | masks = 'masks' 177 | boundaries = 'boundaries' 178 | keypoints = 'keypoints' 179 | keypoint_heatmaps = 'keypoint_heatmaps' 180 | is_crowd = 'is_crowd' 181 | 182 | 183 | class PredictionFields(object): 184 | """Naming conventions for standardized prediction outputs. 185 | 186 | Attributes: 187 | feature_maps: List of feature maps for prediction. 188 | anchors: Generated anchors. 189 | raw_detection_boxes: Decoded detection boxes without NMS. 190 | raw_detection_feature_map_indices: Feature map indices from which each raw 191 | detection box was produced. 192 | """ 193 | feature_maps = 'feature_maps' 194 | anchors = 'anchors' 195 | raw_detection_boxes = 'raw_detection_boxes' 196 | raw_detection_feature_map_indices = 'raw_detection_feature_map_indices' 197 | 198 | 199 | class TfExampleFields(object): 200 | """TF-example proto feature names for object detection. 201 | 202 | Holds the standard feature names to load from an Example proto for object 203 | detection. 204 | 205 | Attributes: 206 | image_encoded: JPEG encoded string 207 | image_format: image format, e.g. "JPEG" 208 | filename: filename 209 | channels: number of channels of image 210 | colorspace: colorspace, e.g. "RGB" 211 | height: height of image in pixels, e.g. 462 212 | width: width of image in pixels, e.g. 581 213 | source_id: original source of the image 214 | image_class_text: image-level label in text format 215 | image_class_label: image-level label in numerical format 216 | object_class_text: labels in text format, e.g. ["person", "cat"] 217 | object_class_label: labels in numbers, e.g. [16, 8] 218 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 219 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 220 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 221 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 222 | object_view: viewpoint of object, e.g. ["frontal", "left"] 223 | object_truncated: is object truncated, e.g. [true, false] 224 | object_occluded: is object occluded, e.g. 
[true, false] 225 | object_difficult: is object difficult, e.g. [true, false] 226 | object_group_of: is object a single object or a group of objects 227 | object_depiction: is object a depiction 228 | object_is_crowd: [DEPRECATED, use object_group_of instead] 229 | is the object a single object or a crowd 230 | object_segment_area: the area of the segment. 231 | object_weight: a weight factor for the object's bounding box. 232 | instance_masks: instance segmentation masks. 233 | instance_boundaries: instance boundaries. 234 | instance_classes: Classes for each instance segmentation mask. 235 | detection_class_label: class label in numbers. 236 | detection_bbox_ymin: ymin coordinates of a detection box. 237 | detection_bbox_xmin: xmin coordinates of a detection box. 238 | detection_bbox_ymax: ymax coordinates of a detection box. 239 | detection_bbox_xmax: xmax coordinates of a detection box. 240 | detection_score: detection score for the class label and box. 241 | """ 242 | image_encoded = 'image/encoded' 243 | image_format = 'image/format' # format is reserved keyword 244 | filename = 'image/filename' 245 | channels = 'image/channels' 246 | colorspace = 'image/colorspace' 247 | height = 'image/height' 248 | width = 'image/width' 249 | source_id = 'image/source_id' 250 | image_class_text = 'image/class/text' 251 | image_class_label = 'image/class/label' 252 | object_class_text = 'image/object/class/text' 253 | object_class_label = 'image/object/class/label' 254 | object_bbox_ymin = 'image/object/bbox/ymin' 255 | object_bbox_xmin = 'image/object/bbox/xmin' 256 | object_bbox_ymax = 'image/object/bbox/ymax' 257 | object_bbox_xmax = 'image/object/bbox/xmax' 258 | object_view = 'image/object/view' 259 | object_truncated = 'image/object/truncated' 260 | object_occluded = 'image/object/occluded' 261 | object_difficult = 'image/object/difficult' 262 | object_group_of = 'image/object/group_of' 263 | object_depiction = 'image/object/depiction' 264 | object_is_crowd = 'image/object/is_crowd' 265 | object_segment_area = 'image/object/segment/area' 266 | object_weight = 'image/object/weight' 267 | instance_masks = 'image/segmentation/object' 268 | instance_boundaries = 'image/boundaries/object' 269 | instance_classes = 'image/segmentation/object/class' 270 | detection_class_label = 'image/detection/label' 271 | detection_bbox_ymin = 'image/detection/bbox/ymin' 272 | detection_bbox_xmin = 'image/detection/bbox/xmin' 273 | detection_bbox_ymax = 'image/detection/bbox/ymax' 274 | detection_bbox_xmax = 'image/detection/bbox/xmax' 275 | detection_score = 'image/detection/score' 276 | -------------------------------------------------------------------------------- /efficientdet/object_detection/target_assigner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Base target assigner module. 16 | 17 | The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and 18 | groundtruth detections (bounding boxes), to assign classification and regression 19 | targets to each anchor as well as weights to each anchor (specifying, e.g., 20 | which anchors should not contribute to training loss). 21 | 22 | It assigns classification/regression targets by performing the following steps: 23 | 1) Computing pairwise similarity between anchors and groundtruth boxes using a 24 | provided RegionSimilarityCalculator 25 | 2) Computing a matching based on the similarity matrix using a provided Matcher 26 | 3) Assigning regression targets based on the matching and a provided BoxCoder 27 | 4) Assigning classification targets based on the matching and groundtruth labels 28 | 29 | Note that TargetAssigners only operate on detections from a single 30 | image at a time, so any logic for applying a TargetAssigner to multiple 31 | images must be handled externally. 32 | """ 33 | import tensorflow.compat.v1 as tf 34 | 35 | from object_detection import box_list 36 | from object_detection import shape_utils 37 | 38 | 39 | KEYPOINTS_FIELD_NAME = 'keypoints' 40 | 41 | 42 | class TargetAssigner(object): 43 | """Target assigner to compute classification and regression targets.""" 44 | 45 | def __init__(self, similarity_calc, matcher, box_coder, 46 | negative_class_weight=1.0, unmatched_cls_target=None): 47 | """Construct Object Detection Target Assigner. 48 | 49 | Args: 50 | similarity_calc: a RegionSimilarityCalculator 51 | matcher: Matcher used to match groundtruth to anchors. 52 | box_coder: BoxCoder used to encode matching groundtruth boxes with 53 | respect to anchors. 54 | negative_class_weight: classification weight to be associated with negative 55 | anchors (default: 1.0). The weight must be in [0., 1.]. 56 | unmatched_cls_target: a float32 tensor with shape [d_1, d_2, ..., d_k] 57 | which is consistent with the classification target for each 58 | anchor (and can be empty for scalar targets). This shape must thus be 59 | compatible with the groundtruth labels that are passed to the "assign" 60 | function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]). 61 | If set to None, unmatched_cls_target is set to be [0] for each anchor. 62 | 63 | Raises: 64 | ValueError: if similarity_calc is not a RegionSimilarityCalculator or 65 | if matcher is not a Matcher or if box_coder is not a BoxCoder 66 | """ 67 | self._similarity_calc = similarity_calc 68 | self._matcher = matcher 69 | self._box_coder = box_coder 70 | self._negative_class_weight = negative_class_weight 71 | if unmatched_cls_target is None: 72 | self._unmatched_cls_target = tf.constant([0], tf.float32) 73 | else: 74 | self._unmatched_cls_target = unmatched_cls_target 75 | 76 | @property 77 | def box_coder(self): 78 | return self._box_coder 79 | 80 | def assign(self, anchors, groundtruth_boxes, groundtruth_labels=None, 81 | groundtruth_weights=None, **params): 82 | """Assign classification and regression targets to each anchor. 83 | 84 | For a given set of anchors and groundtruth detections, match anchors 85 | to groundtruth_boxes and assign classification and regression targets to 86 | each anchor as well as weights based on the resulting match (specifying, 87 | e.g., which anchors should not contribute to training loss).
88 | 89 | Anchors that are not matched to anything are given a classification target 90 | of self._unmatched_cls_target which can be specified via the constructor. 91 | 92 | Args: 93 | anchors: a BoxList representing N anchors 94 | groundtruth_boxes: a BoxList representing M groundtruth boxes 95 | groundtruth_labels: a tensor of shape [M, d_1, ... d_k] 96 | with labels for each of the ground_truth boxes. The subshape 97 | [d_1, ... d_k] can be empty (corresponding to scalar inputs). When set 98 | to None, groundtruth_labels assumes a binary problem where all 99 | ground_truth boxes get a positive label (of 1). 100 | groundtruth_weights: a float tensor of shape [M] indicating the weight to 101 | assign to all anchors matched to a particular groundtruth box. The weights 102 | must be in [0., 1.]. If None, all weights are set to 1. 103 | **params: Additional keyword arguments for specific implementations of 104 | the Matcher. 105 | 106 | Returns: 107 | cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], 108 | where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels 109 | which has shape [num_gt_boxes, d_1, d_2, ... d_k]. 110 | cls_weights: a float32 tensor with shape [num_anchors] 111 | reg_targets: a float32 tensor with shape [num_anchors, box_code_dimension] 112 | reg_weights: a float32 tensor with shape [num_anchors] 113 | match: a matcher.Match object encoding the match between anchors and 114 | groundtruth boxes, with rows corresponding to groundtruth boxes 115 | and columns corresponding to anchors. 116 | 117 | Raises: 118 | ValueError: if anchors or groundtruth_boxes are not of type 119 | box_list.BoxList 120 | """ 121 | if not isinstance(anchors, box_list.BoxList): 122 | raise ValueError('anchors must be a BoxList') 123 | if not isinstance(groundtruth_boxes, box_list.BoxList): 124 | raise ValueError('groundtruth_boxes must be a BoxList') 125 | 126 | if groundtruth_labels is None: 127 | groundtruth_labels = tf.ones(tf.expand_dims(groundtruth_boxes.num_boxes(), 128 | 0)) 129 | groundtruth_labels = tf.expand_dims(groundtruth_labels, -1) 130 | unmatched_shape_assert = shape_utils.assert_shape_equal( 131 | shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[1:], 132 | shape_utils.combined_static_and_dynamic_shape( 133 | self._unmatched_cls_target)) 134 | labels_and_box_shapes_assert = shape_utils.assert_shape_equal( 135 | shape_utils.combined_static_and_dynamic_shape( 136 | groundtruth_labels)[:1], 137 | shape_utils.combined_static_and_dynamic_shape( 138 | groundtruth_boxes.get())[:1]) 139 | 140 | if groundtruth_weights is None: 141 | num_gt_boxes = groundtruth_boxes.num_boxes_static() 142 | if not num_gt_boxes: 143 | num_gt_boxes = groundtruth_boxes.num_boxes() 144 | groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32) 145 | with tf.control_dependencies( 146 | [unmatched_shape_assert, labels_and_box_shapes_assert]): 147 | match_quality_matrix = self._similarity_calc.compare(groundtruth_boxes, 148 | anchors) 149 | match = self._matcher.match(match_quality_matrix, **params) 150 | reg_targets = self._create_regression_targets(anchors, 151 | groundtruth_boxes, 152 | match) 153 | cls_targets = self._create_classification_targets(groundtruth_labels, 154 | match) 155 | reg_weights = self._create_regression_weights(match, groundtruth_weights) 156 | cls_weights = self._create_classification_weights(match, 157 | groundtruth_weights) 158 | 159 | num_anchors = anchors.num_boxes_static() 160 | if num_anchors is not None: 161 | reg_targets = self._reset_target_shape(reg_targets, num_anchors)
162 | cls_targets = self._reset_target_shape(cls_targets, num_anchors) 163 | reg_weights = self._reset_target_shape(reg_weights, num_anchors) 164 | cls_weights = self._reset_target_shape(cls_weights, num_anchors) 165 | 166 | return cls_targets, cls_weights, reg_targets, reg_weights, match 167 | 168 | def _reset_target_shape(self, target, num_anchors): 169 | """Sets the static shape of the target. 170 | 171 | Args: 172 | target: the target tensor. Its first dimension will be overwritten. 173 | num_anchors: the number of anchors, which is used to override the target's 174 | first dimension. 175 | 176 | Returns: 177 | A tensor with the shape info filled in. 178 | """ 179 | target_shape = target.get_shape().as_list() 180 | target_shape[0] = num_anchors 181 | target.set_shape(target_shape) 182 | return target 183 | 184 | def _create_regression_targets(self, anchors, groundtruth_boxes, match): 185 | """Returns a regression target for each anchor. 186 | 187 | Args: 188 | anchors: a BoxList representing N anchors 189 | groundtruth_boxes: a BoxList representing M groundtruth_boxes 190 | match: a matcher.Match object 191 | 192 | Returns: 193 | reg_targets: a float32 tensor with shape [N, box_code_dimension] 194 | """ 195 | matched_gt_boxes = match.gather_based_on_match( 196 | groundtruth_boxes.get(), 197 | unmatched_value=tf.zeros(4), 198 | ignored_value=tf.zeros(4)) 199 | matched_gt_boxlist = box_list.BoxList(matched_gt_boxes) 200 | if groundtruth_boxes.has_field(KEYPOINTS_FIELD_NAME): 201 | groundtruth_keypoints = groundtruth_boxes.get_field(KEYPOINTS_FIELD_NAME) 202 | matched_keypoints = match.gather_based_on_match( 203 | groundtruth_keypoints, 204 | unmatched_value=tf.zeros(groundtruth_keypoints.get_shape()[1:]), 205 | ignored_value=tf.zeros(groundtruth_keypoints.get_shape()[1:])) 206 | matched_gt_boxlist.add_field(KEYPOINTS_FIELD_NAME, matched_keypoints) 207 | matched_reg_targets = self._box_coder.encode(matched_gt_boxlist, anchors) 208 | match_results_shape = shape_utils.combined_static_and_dynamic_shape( 209 | match.match_results) 210 | 211 | # Zero out the unmatched and ignored regression targets. 212 | unmatched_ignored_reg_targets = tf.tile( 213 | self._default_regression_target(), [match_results_shape[0], 1]) 214 | matched_anchors_mask = match.matched_column_indicator() 215 | reg_targets = tf.where(matched_anchors_mask, 216 | matched_reg_targets, 217 | unmatched_ignored_reg_targets) 218 | return reg_targets 219 | 220 | def _default_regression_target(self): 221 | """Returns the default target for anchors to regress to. 222 | 223 | Default regression targets are set to zero (though in 224 | this implementation what these targets are set to should 225 | not matter as the regression weight of any box set to 226 | regress to the default target is zero). 227 | 228 | Returns: 229 | default_target: a float32 tensor with shape [1, box_code_dimension] 230 | """ 231 | return tf.constant([self._box_coder.code_size*[0]], tf.float32) 232 | 233 | def _create_classification_targets(self, groundtruth_labels, match): 234 | """Create classification targets for each anchor. 235 | 236 | Assigns each anchor the classification target of its matching 237 | groundtruth label, as provided by match. Anchors that are not matched 238 | to anything are given the target self._unmatched_cls_target. 239 | 240 | Args: 241 | groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k] 242 | with labels for each of the ground_truth boxes. The subshape
243 | [d_1, ... d_k] can be empty (corresponding to scalar labels). 244 | match: a matcher.Match object that provides a matching between anchors 245 | and groundtruth boxes. 246 | 247 | Returns: 248 | a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the 249 | subshape [d_1, ..., d_k] is compatible with groundtruth_labels which has 250 | shape [num_gt_boxes, d_1, d_2, ... d_k]. 251 | """ 252 | return match.gather_based_on_match( 253 | groundtruth_labels, 254 | unmatched_value=self._unmatched_cls_target, 255 | ignored_value=self._unmatched_cls_target) 256 | 257 | def _create_regression_weights(self, match, groundtruth_weights): 258 | """Set regression weight for each anchor. 259 | 260 | Only positive anchors contribute to the regression loss, so this 261 | method returns each positive anchor's groundtruth weight (1 by default) 262 | and 0 for every negative anchor. 263 | 264 | Args: 265 | match: a matcher.Match object that provides a matching between anchors 266 | and groundtruth boxes. 267 | groundtruth_weights: a float tensor of shape [M] indicating the weight to 268 | assign to all anchors matched to a particular groundtruth box. 269 | 270 | Returns: 271 | a float32 tensor with shape [num_anchors] representing regression weights. 272 | """ 273 | return match.gather_based_on_match( 274 | groundtruth_weights, ignored_value=0., unmatched_value=0.) 275 | 276 | def _create_classification_weights(self, 277 | match, 278 | groundtruth_weights): 279 | """Create classification weights for each anchor. 280 | 281 | Positive (matched) anchors are associated with the weight of 282 | their matched groundtruth box and negative (unmatched) anchors are 283 | associated with a weight of negative_class_weight. When anchors are 284 | ignored, weights are set to zero. By default, negative weights are set 285 | to 1.0, but they can be adjusted to handle class imbalance (which is 286 | almost always the case in object detection). 287 | 288 | Args: 289 | match: a matcher.Match object that provides a matching between anchors 290 | and groundtruth boxes. 291 | groundtruth_weights: a float tensor of shape [M] indicating the weight to 292 | assign to all anchors matched to a particular groundtruth box. 293 | 294 | Returns: 295 | a float32 tensor with shape [num_anchors] representing classification 296 | weights. 297 | """ 298 | return match.gather_based_on_match( 299 | groundtruth_weights, 300 | ignored_value=0., 301 | unmatched_value=self._negative_class_weight) 302 | 303 | def get_box_coder(self): 304 | """Get BoxCoder of this TargetAssigner. 305 | 306 | Returns: 307 | BoxCoder object. 308 | """ 309 | return self._box_coder 310 | --------------------------------------------------------------------------------
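A minimal usage sketch for the TargetAssigner above. It assumes the sibling modules listed in the tree (region_similarity_calculator, argmax_matcher, faster_rcnn_box_coder) expose IouSimilarity, ArgMaxMatcher, and FasterRcnnBoxCoder, as in the TensorFlow object detection API these files are vendored from:

import tensorflow.compat.v1 as tf

from object_detection import argmax_matcher
from object_detection import box_list
from object_detection import faster_rcnn_box_coder
from object_detection import region_similarity_calculator
from object_detection import target_assigner

# Steps 1-2: IoU similarity plus greedy argmax matching (assumed constructors).
similarity_calc = region_similarity_calculator.IouSimilarity()
matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
                                       unmatched_threshold=0.5)
# Step 3: encode box regression targets relative to anchors.
box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
assigner = target_assigner.TargetAssigner(similarity_calc, matcher, box_coder)

# One anchor and one groundtruth box in normalized [ymin, xmin, ymax, xmax].
anchors = box_list.BoxList(tf.constant([[0.0, 0.0, 1.0, 1.0]]))
gt_boxes = box_list.BoxList(tf.constant([[0.1, 0.1, 0.9, 0.9]]))

# Step 4 happens inside assign(): classification targets follow the match.
cls_targets, cls_weights, reg_targets, reg_weights, match = assigner.assign(
    anchors, gt_boxes)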