├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 1.png ├── 2.png └── 3.png ├── configuration.py ├── core ├── __init__.py ├── centernet.py ├── loss.py └── models │ ├── __init__.py │ ├── dla.py │ ├── efficientdet.py │ ├── group_convolution.py │ └── resnet.py ├── data ├── __init__.py ├── dataloader.py ├── datasets │ └── README.md └── voc.py ├── saved_model └── README.md ├── test.py ├── test_pictures └── README.md ├── train.py ├── utils ├── __init__.py ├── gaussian.py └── visualize.py └── write_to_txt.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | /data/datasets/* 4 | !/data/datasets/README.md 5 | data.txt 6 | /saved_model/* 7 | !/saved_model/README.md 8 | /test_pictures/* 9 | !/test_pictures/README.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 calmisential 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CenterNet_TensorFlow2 2 | A TensorFlow 2.x implementation of CenterNet. 3 | 4 | ## Requirements: 5 | + Python >= 3.7 6 | + TensorFlow >= 2.2.0rc3 7 | + numpy 8 | + opencv-python 9 | 10 | ## Results 11 | The following are detection results on some pictures from the PASCAL VOC 2012 dataset.<br>
12 | ![img_1](https://github.com/calmisential/CenterNet_TensorFlow2/blob/master/assets/1.png)
13 | ![img_2](https://github.com/calmisential/CenterNet_TensorFlow2/blob/master/assets/2.png)
14 | ![img_3](https://github.com/calmisential/CenterNet_TensorFlow2/blob/master/assets/3.png) 15 | 16 | ## Usage 17 | ### Train on PASCAL VOC 2012 18 | 1. Download the [PASCAL VOC 2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/). 19 | 2. Unzip the file and place it in the 'data/datasets' folder; make sure the directory looks like this: 20 | ``` 21 | |——data 22 | |——datasets 23 | |——VOCdevkit 24 | |——VOC2012 25 | |——Annotations 26 | |——ImageSets 27 | |——JPEGImages 28 | |——SegmentationClass 29 | |——SegmentationObject 30 | ``` 31 | 3. Run **write_to_txt.py** to generate **data.txt**. 32 | 4. Run **train.py** to start training. Before that, you can change the parameter values in **configuration.py**. 33 | 34 | ### Test on a single picture 35 | 1. Set *test_single_image_dir* in **configuration.py**. 36 | 2. Run **test.py** to run detection on a single picture. 37 | 38 | ## Acknowledgments 39 | 1. Official PyTorch implementation of CenterNet: https://github.com/xingyizhou/CenterNet 40 | 2. A TensorFlow implementation of CenterNet: https://github.com/MioChiu/TF_CenterNet 41 | 42 | 43 | ## References 44 | 1. [Objects as Points](https://arxiv.org/abs/1904.07850) -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calmiLovesAI/CenterNet_TensorFlow2/a30c9b4243c7c5f45a0cf47df655c171dcfef256/assets/1.png -------------------------------------------------------------------------------- /assets/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calmiLovesAI/CenterNet_TensorFlow2/a30c9b4243c7c5f45a0cf47df655c171dcfef256/assets/2.png -------------------------------------------------------------------------------- /assets/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calmiLovesAI/CenterNet_TensorFlow2/a30c9b4243c7c5f45a0cf47df655c171dcfef256/assets/3.png -------------------------------------------------------------------------------- /configuration.py: -------------------------------------------------------------------------------- 1 | 2 | class Config: 3 | epochs = 50 4 | batch_size = 8 5 | learning_rate_decay_epochs = 10 6 | 7 | # save model 8 | save_frequency = 5 9 | save_model_dir = "saved_model/" 10 | load_weights_before_training = False 11 | load_weights_from_epoch = 0 12 | 13 | # test image 14 | test_single_image_dir = "" 15 | test_images_during_training = False 16 | training_results_save_dir = "./test_pictures/" 17 | test_images_dir_list = ["", ""] 18 | 19 | image_size = {"resnet_18": (384, 384), "resnet_34": (384, 384), "resnet_50": (384, 384), 20 | "resnet_101": (384, 384), "resnet_152": (384, 384), 21 | "D0": (512, 512), "D1": (640, 640), "D2": (768, 768), 22 | "D3": (896, 896), "D4": (1024, 1024), "D5": (1280, 1280), 23 | "D6": (1408, 1408), "D7": (1536, 1536)} 24 | image_channels = 3 25 | 26 | # dataset 27 | num_classes = 20 28 | pascal_voc_root = "./data/datasets/VOCdevkit/VOC2012/" 29 | pascal_voc_images = pascal_voc_root + "JPEGImages" 30 | pascal_voc_labels = pascal_voc_root + "Annotations" 31 | 32 | pascal_voc_classes = {"person": 0, "bird": 1, "cat": 2, "cow": 3, "dog": 4, 33 | "horse": 5, "sheep": 6, "aeroplane": 7, "bicycle": 8, 34 | "boat": 9, "bus": 10, "car": 11, "motorbike": 12, 35 | "train": 13, "bottle": 14, "chair": 15, "diningtable": 16, 36 | "pottedplant": 17, "sofa": 18, 
"tvmonitor": 19} 37 | 38 | # txt file 39 | txt_file_dir = "data.txt" 40 | 41 | max_boxes_per_image = 50 42 | 43 | # network architecture 44 | 45 | backbone_name = "D0" 46 | # can be selected from: resnet_18, resnet_34, resnet_50, resnet_101, resnet_152, D0~D7 47 | 48 | downsampling_ratio = 8 # efficientdet: 8, others: 4 49 | 50 | # efficientdet 51 | width_coefficient = {"D0": 1.0, "D1": 1.0, "D2": 1.1, "D3": 1.2, "D4": 1.4, "D5": 1.6, "D6": 1.8, "D7": 1.8} 52 | depth_coefficient = {"D0": 1.0, "D1": 1.1, "D2": 1.2, "D3": 1.4, "D4": 1.8, "D5": 2.2, "D6": 2.6, "D7": 2.6} 53 | dropout_rate = {"D0": 0.2, "D1": 0.2, "D2": 0.3, "D3": 0.3, "D4": 0.4, "D5": 0.4, "D6": 0.5, "D7": 0.5} 54 | # bifpn channels 55 | w_bifpn = {"D0": 64, "D1": 88, "D2": 112, "D3": 160, "D4": 224, "D5": 288, "D6": 384, "D7": 384} 56 | # bifpn layers 57 | d_bifpn = {"D0": 2, "D1": 3, "D2": 4, "D3": 5, "D4": 6, "D5": 7, "D6": 8, "D7": 8} 58 | 59 | heads = {"heatmap": num_classes, "wh": 2, "reg": 2} 60 | head_conv = {"no_conv_layer": 0, "resnets": 64, "dla": 256, 61 | "D0": w_bifpn["D0"], "D1": w_bifpn["D1"], "D2": w_bifpn["D2"], "D3": w_bifpn["D3"], 62 | "D4": w_bifpn["D4"], "D5": w_bifpn["D5"], "D6": w_bifpn["D6"], "D7": w_bifpn["D7"]} 63 | 64 | 65 | # loss 66 | hm_weight = 1.0 67 | wh_weight = 0.1 68 | off_weight = 1.0 69 | 70 | score_threshold = 0.3 71 | 72 | @classmethod 73 | def get_image_size(cls): 74 | return cls.image_size[cls.backbone_name] 75 | 76 | @classmethod 77 | def get_width_coefficient(cls, backbone_name): 78 | return cls.width_coefficient[backbone_name] 79 | 80 | @classmethod 81 | def get_depth_coefficient(cls, backbone_name): 82 | return cls.depth_coefficient[backbone_name] 83 | 84 | @classmethod 85 | def get_dropout_rate(cls, backbone_name): 86 | return cls.dropout_rate[backbone_name] 87 | 88 | @classmethod 89 | def get_w_bifpn(cls, backbone_name): 90 | return cls.w_bifpn[backbone_name] 91 | 92 | @classmethod 93 | def get_d_bifpn(cls, backbone_name): 94 | return cls.d_bifpn[backbone_name] 95 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calmiLovesAI/CenterNet_TensorFlow2/a30c9b4243c7c5f45a0cf47df655c171dcfef256/core/__init__.py -------------------------------------------------------------------------------- /core/centernet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from configuration import Config 5 | from core.models.resnet import resnet_18, resnet_34, resnet_50, resnet_101, resnet_152 6 | from core.models.dla import dla_34, dla_60, dla_102, dla_169 7 | from core.models.efficientdet import d0, d1, d2, d3, d4, d5, d6, d7 8 | from data.dataloader import GT 9 | from core.loss import CombinedLoss, RegL1Loss 10 | 11 | backbone_zoo = {"resnet_18": resnet_18(), 12 | "resnet_34": resnet_34(), 13 | "resnet_50": resnet_50(), 14 | "resnet_101": resnet_101(), 15 | "resnet_152": resnet_152(), 16 | "dla_34": dla_34(), 17 | "dla_60": dla_60(), 18 | "dla_102": dla_102(), 19 | "dla_169": dla_169(), 20 | "D0": d0(), "D1": d1(), "D2": d2(), "D3": d3(), "D4": d4(), "D5": d5(), "D6": d6(), "D7": d7()} 21 | 22 | 23 | class CenterNet(tf.keras.Model): 24 | def __init__(self): 25 | super(CenterNet, self).__init__() 26 | self.backbone = backbone_zoo[Config.backbone_name] 27 | 28 | def call(self, inputs, training=None, mask=None): 29 | x = self.backbone(inputs, 
training=training) 30 | x = tf.concat(values=x, axis=-1) 31 | return x 32 | 33 | 34 | class PostProcessing: 35 | @staticmethod 36 | def training_procedure(batch_labels, pred): 37 | gt = GT(batch_labels) 38 | gt_heatmap, gt_reg, gt_wh, gt_reg_mask, gt_indices = gt.get_gt_values() 39 | loss_object = CombinedLoss() 40 | loss = loss_object(y_pred=pred, heatmap_true=gt_heatmap, reg_true=gt_reg, wh_true=gt_wh, reg_mask=gt_reg_mask, indices=gt_indices) 41 | return loss 42 | 43 | @staticmethod 44 | def testing_procedure(pred, original_image_size): 45 | decoder = Decoder(original_image_size) 46 | detections = decoder(pred) 47 | bboxes = detections[:, 0:4] 48 | scores = detections[:, 4] 49 | clses = detections[:, 5] 50 | return bboxes, scores, clses 51 | 52 | 53 | class Decoder: 54 | def __init__(self, original_image_size): 55 | self.K = Config.max_boxes_per_image 56 | self.original_image_size = np.array(original_image_size, dtype=np.float32) 57 | self.input_image_size = np.array(Config.get_image_size(), dtype=np.float32) 58 | self.downsampling_ratio = Config.downsampling_ratio 59 | self.score_threshold = Config.score_threshold 60 | 61 | def __call__(self, pred, *args, **kwargs): 62 | heatmap, reg, wh = tf.split(value=pred, num_or_size_splits=[Config.num_classes, 2, 2], axis=-1) 63 | heatmap = tf.math.sigmoid(heatmap) 64 | batch_size = heatmap.shape[0] 65 | heatmap = Decoder.__nms(heatmap) 66 | scores, inds, clses, ys, xs = Decoder.__topK(scores=heatmap, K=self.K) 67 | if reg is not None: 68 | reg = RegL1Loss.gather_feat(feat=reg, idx=inds) 69 | xs = tf.reshape(xs, shape=(batch_size, self.K, 1)) + reg[:, :, 0:1] 70 | ys = tf.reshape(ys, shape=(batch_size, self.K, 1)) + reg[:, :, 1:2] 71 | else: 72 | xs = tf.reshape(xs, shape=(batch_size, self.K, 1)) + 0.5 73 | ys = tf.reshape(ys, shape=(batch_size, self.K, 1)) + 0.5 74 | wh = RegL1Loss.gather_feat(feat=wh, idx=inds) 75 | clses = tf.cast(tf.reshape(clses, (batch_size, self.K, 1)), dtype=tf.float32) 76 | scores = tf.reshape(scores, (batch_size, self.K, 1)) 77 | bboxes = tf.concat(values=[xs - wh[..., 0:1] / 2, 78 | ys - wh[..., 1:2] / 2, 79 | xs + wh[..., 0:1] / 2, 80 | ys + wh[..., 1:2] / 2], axis=2) 81 | detections = tf.concat(values=[bboxes, scores, clses], axis=2) 82 | return self.__map_to_original(detections) 83 | 84 | def __map_to_original(self, detections): 85 | bboxes, scores, clses = tf.split(value=detections, num_or_size_splits=[4, 1, 1], axis=2) 86 | bboxes, scores, clses = bboxes.numpy()[0], scores.numpy()[0], clses.numpy()[0] 87 | resize_ratio = self.original_image_size / self.input_image_size 88 | bboxes[:, 0::2] = bboxes[:, 0::2] * self.downsampling_ratio * resize_ratio[1] 89 | bboxes[:, 1::2] = bboxes[:, 1::2] * self.downsampling_ratio * resize_ratio[0] 90 | bboxes[:, 0::2] = np.clip(a=bboxes[:, 0::2], a_min=0, a_max=self.original_image_size[1]) 91 | bboxes[:, 1::2] = np.clip(a=bboxes[:, 1::2], a_min=0, a_max=self.original_image_size[0]) 92 | score_mask = scores >= self.score_threshold 93 | bboxes, scores, clses = Decoder.__numpy_mask(bboxes, np.tile(score_mask, (1, 4))), Decoder.__numpy_mask(scores, score_mask), Decoder.__numpy_mask(clses, score_mask) 94 | detections = np.concatenate([bboxes, scores, clses], axis=-1) 95 | return detections 96 | 97 | @staticmethod 98 | def __numpy_mask(a, mask): 99 | return a[mask].reshape(-1, a.shape[-1]) 100 | 101 | @staticmethod 102 | def __nms(heatmap, pool_size=3): 103 | hmax = tf.keras.layers.MaxPool2D(pool_size=pool_size, strides=1, padding="same")(heatmap) 104 | keep = tf.cast(tf.equal(heatmap, 
hmax), tf.float32) 105 | return hmax * keep 106 | 107 | @staticmethod 108 | def __topK(scores, K): 109 | B, H, W, C = scores.shape 110 | scores = tf.reshape(scores, shape=(B, -1)) 111 | topk_scores, topk_inds = tf.math.top_k(input=scores, k=K, sorted=True) 112 | topk_clses = topk_inds % C 113 | topk_xs = tf.cast(topk_inds // C % W, tf.float32) 114 | topk_ys = tf.cast(topk_inds // C // W, tf.float32) 115 | topk_inds = tf.cast(topk_ys * tf.cast(W, tf.float32) + topk_xs, tf.int32) 116 | return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs 117 | -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from configuration import Config 4 | 5 | 6 | class FocalLoss: 7 | def __call__(self, y_true, y_pred, *args, **kwargs): 8 | return FocalLoss.__neg_loss(y_pred, y_true) 9 | 10 | @staticmethod 11 | def __neg_loss(pred, gt): 12 | pos_idx = tf.cast(tf.math.equal(gt, 1), dtype=tf.float32) 13 | neg_idx = tf.cast(tf.math.less(gt, 1), dtype=tf.float32) 14 | neg_weights = tf.math.pow(1 - gt, 4) 15 | 16 | loss = 0 17 | num_pos = tf.math.reduce_sum(pos_idx) 18 | pos_loss = tf.math.log(pred) * tf.math.pow(1 - pred, 2) * pos_idx 19 | pos_loss = tf.math.reduce_sum(pos_loss) 20 | neg_loss = tf.math.log(1 - pred) * tf.math.pow(pred, 2) * neg_weights * neg_idx 21 | neg_loss = tf.math.reduce_sum(neg_loss) 22 | 23 | if num_pos == 0: 24 | loss = loss - neg_loss 25 | else: 26 | loss = loss - (pos_loss + neg_loss) / num_pos 27 | return loss 28 | 29 | 30 | class RegL1Loss: 31 | def __call__(self, y_true, y_pred, mask, index, *args, **kwargs): 32 | y_pred = RegL1Loss.gather_feat(y_pred, index) 33 | mask = tf.tile(tf.expand_dims(mask, axis=-1), tf.constant([1, 1, 2], dtype=tf.int32)) 34 | loss = tf.math.reduce_sum(tf.abs(y_true * mask - y_pred * mask)) 35 | reg_loss = loss / (tf.math.reduce_sum(mask) + 1e-4) 36 | return reg_loss 37 | 38 | @staticmethod 39 | def gather_feat(feat, idx): 40 | feat = tf.reshape(feat, shape=(feat.shape[0], -1, feat.shape[-1])) 41 | idx = tf.cast(idx, dtype=tf.int32) 42 | feat = tf.gather(params=feat, indices=idx, batch_dims=1) 43 | return feat 44 | 45 | 46 | class CombinedLoss: 47 | def __init__(self): 48 | self.heatmap_loss_object = FocalLoss() 49 | self.reg_loss_object = RegL1Loss() 50 | self.wh_loss_object = RegL1Loss() 51 | 52 | def __call__(self, y_pred, heatmap_true, reg_true, wh_true, reg_mask, indices, *args, **kwargs): 53 | heatmap, reg, wh = tf.split(value=y_pred, num_or_size_splits=[Config.num_classes, 2, 2], axis=-1) 54 | heatmap = tf.clip_by_value(t=tf.math.sigmoid(heatmap), clip_value_min=1e-4, clip_value_max=1.0 - 1e-4) 55 | heatmap_loss = self.heatmap_loss_object(y_true=heatmap_true, y_pred=heatmap) 56 | off_loss = self.reg_loss_object(y_true=reg_true, y_pred=reg, mask=reg_mask, index=indices) 57 | wh_loss = self.wh_loss_object(y_true=wh_true, y_pred=wh, mask=reg_mask, index=indices) 58 | return Config.hm_weight * heatmap_loss + Config.off_weight * off_loss + Config.wh_weight * wh_loss 59 | -------------------------------------------------------------------------------- /core/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calmiLovesAI/CenterNet_TensorFlow2/a30c9b4243c7c5f45a0cf47df655c171dcfef256/core/models/__init__.py -------------------------------------------------------------------------------- /core/models/dla.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from core.models.group_convolution import GroupConv2D, GroupConv2DTranspose 4 | from configuration import Config 5 | 6 | 7 | class BasicBlock(tf.keras.layers.Layer): 8 | def __init__(self, in_channels, out_channels, stride=1): 9 | super(BasicBlock, self).__init__() 10 | self.conv1 = tf.keras.layers.Conv2D(filters=out_channels, kernel_size=(3, 3), strides=stride, 11 | padding="same", 12 | use_bias=False) 13 | self.bn1 = tf.keras.layers.BatchNormalization() 14 | self.conv2 = tf.keras.layers.Conv2D(filters=out_channels, kernel_size=(3, 3), strides=1, 15 | padding="same", 16 | use_bias=False) 17 | self.bn2 = tf.keras.layers.BatchNormalization() 18 | 19 | def call(self, inputs, training=None, residual=None, **kwargs): 20 | if residual is None: 21 | residual = inputs 22 | x = self.conv1(inputs) 23 | x = self.bn1(x, training=training) 24 | x = tf.nn.relu(x) 25 | x = self.conv2(x) 26 | x = self.bn2(x, training=training) 27 | outputs = tf.nn.relu(tf.keras.layers.add([residual, x])) 28 | return outputs 29 | 30 | 31 | class BottleNeck(tf.keras.layers.Layer): 32 | expansion = 2 33 | def __init__(self, in_channels, out_channels, stride=1): 34 | super(BottleNeck, self).__init__() 35 | temp_channels = out_channels // BottleNeck.expansion 36 | self.conv1 = tf.keras.layers.Conv2D(filters=temp_channels, kernel_size=(1, 1), strides=1, 37 | padding="same", 38 | use_bias=False) 39 | self.bn1 = tf.keras.layers.BatchNormalization() 40 | self.conv2 = tf.keras.layers.Conv2D(filters=temp_channels, kernel_size=(3, 3), strides=stride, 41 | padding="same", 42 | use_bias=False) 43 | self.bn2 = tf.keras.layers.BatchNormalization() 44 | self.conv3 = tf.keras.layers.Conv2D(filters=out_channels, kernel_size=(1, 1), strides=1, 45 | padding="same", 46 | use_bias=False) 47 | self.bn3 = tf.keras.layers.BatchNormalization() 48 | 49 | def call(self, inputs, training=None, residual=None, **kwargs): 50 | if residual is None: 51 | residual = inputs 52 | x = self.conv1(inputs) 53 | x = self.bn1(x, training=training) 54 | x = tf.nn.relu(x) 55 | x = self.conv2(x) 56 | x = self.bn2(x, training=training) 57 | x = tf.nn.relu(x) 58 | x = self.conv3(x) 59 | x = self.bn3(x, training=training) 60 | outputs = tf.nn.relu(tf.keras.layers.add([residual, x])) 61 | return outputs 62 | 63 | 64 | class BottleNeckX(tf.keras.layers.Layer): 65 | cardinality = 32 66 | def __init__(self, in_channels, out_channels, stride=1): 67 | super(BottleNeckX, self).__init__() 68 | temp_channels = out_channels * BottleNeckX.cardinality // 32 69 | self.conv1 = GroupConv2D(input_channels=in_channels, output_channels=temp_channels, 70 | kernel_size=(1, 1), strides=1, padding="same", use_bias=False) 71 | self.bn1 = tf.keras.layers.BatchNormalization() 72 | self.conv2 = GroupConv2D(input_channels=temp_channels, output_channels=temp_channels, 73 | kernel_size=(3, 3), strides=stride, padding="same", use_bias=False, 74 | groups=BottleNeckX.cardinality) 75 | self.bn2 = tf.keras.layers.BatchNormalization() 76 | self.conv3 = GroupConv2D(input_channels=temp_channels, output_channels=out_channels, 77 | kernel_size=(1, 1), strides=1, padding="same", use_bias=False) 78 | self.bn3 = tf.keras.layers.BatchNormalization() 79 | 80 | def call(self, inputs, training=None, residual=None, **kwargs): 81 | if residual is None: 82 | residual = inputs 83 | x = self.conv1(inputs) 84 | x = self.bn1(x, training=training) 85 | x = tf.nn.relu(x) 86 | x = 
self.conv2(x) 87 | x = self.bn2(x, training=training) 88 | x = tf.nn.relu(x) 89 | x = self.conv3(x) 90 | x = self.bn3(x, training=training) 91 | outputs = tf.nn.relu(tf.keras.layers.add([residual, x])) 92 | return outputs 93 | 94 | 95 | class Root(tf.keras.layers.Layer): 96 | def __init__(self, out_channels, residual): 97 | super(Root, self).__init__() 98 | self.conv = tf.keras.layers.Conv2D(filters=out_channels, kernel_size=(1, 1), padding="same", 99 | strides=1, use_bias=False) 100 | self.bn = tf.keras.layers.BatchNormalization() 101 | self.residual = residual 102 | 103 | def call(self, inputs, training=None, **kwargs): 104 | x = self.conv(tf.concat(values=inputs, axis=-1)) 105 | x = self.bn(x, training=training) 106 | if self.residual: 107 | x = tf.keras.layers.add([x, inputs[0]]) 108 | x = tf.nn.relu(x) 109 | return x 110 | 111 | 112 | class Tree(tf.keras.layers.Layer): 113 | def __init__(self, levels, block, in_channels, out_channels, stride=1, 114 | level_root=False, root_dim=0, root_kernel_size=1, 115 | root_residual=False): 116 | super(Tree, self).__init__() 117 | if root_dim == 0: 118 | root_dim = 2 * out_channels 119 | if level_root: 120 | root_dim += in_channels 121 | if levels == 1: 122 | self.tree1 = block(in_channels, out_channels, stride) 123 | self.tree2 = block(out_channels, out_channels, 1) 124 | else: 125 | self.tree1 = Tree(levels - 1, block, in_channels, out_channels, 126 | stride, root_dim=0, 127 | root_kernel_size=root_kernel_size, 128 | root_residual=root_residual) 129 | self.tree2 = Tree(levels - 1, block, out_channels, out_channels, 130 | root_dim=root_dim + out_channels, 131 | root_kernel_size=root_kernel_size, 132 | root_residual=root_residual) 133 | if levels == 1: 134 | self.root = Root(out_channels, root_residual) 135 | 136 | self.level_root = level_root 137 | self.root_dim = root_dim 138 | self.downsample = None 139 | self.project = None 140 | self.levels = levels 141 | 142 | if stride > 1: 143 | self.downsample = tf.keras.layers.MaxPool2D(pool_size=stride, strides=stride, padding="same") 144 | if in_channels != out_channels: 145 | self.project = tf.keras.Sequential([ 146 | tf.keras.layers.Conv2D(filters=out_channels, kernel_size=(1, 1), strides=1, padding="same", 147 | use_bias=False), 148 | tf.keras.layers.BatchNormalization() 149 | ]) 150 | 151 | def call(self, inputs, training=None, residual=None, children=None, **kwargs): 152 | children = [] if children is None else children 153 | bottom = self.downsample(inputs) if self.downsample else inputs 154 | residual = self.project(bottom, training=training) if self.project else bottom 155 | if self.level_root: 156 | children.append(bottom) 157 | x1 = self.tree1(inputs, training=training, residual=residual) 158 | if self.levels == 1: 159 | x2 = self.tree2(x1, training=training) 160 | outputs = self.root([x2, x1, *children], training=training) 161 | else: 162 | children.append(x1) 163 | outputs = self.tree2(x1, training=training, children=children) 164 | return outputs 165 | 166 | 167 | class DLA(tf.keras.layers.Layer): 168 | def __init__(self, levels, channels, num_classes=1000, block=BasicBlock, 169 | residual_root=False, return_levels=False, pool_size=7): 170 | super(DLA, self).__init__() 171 | self.channels = channels 172 | self.return_levels = return_levels 173 | self.num_classes = num_classes 174 | 175 | self.base_layer = tf.keras.Sequential([ 176 | tf.keras.layers.Conv2D(filters=channels[0], kernel_size=7, strides=1, padding="same", use_bias=False), 177 | tf.keras.layers.BatchNormalization(), 178 | 
tf.keras.layers.ReLU() 179 | ]) 180 | self.level_0 = DLA.__make_conv_level(out_channels=channels[0], convs=levels[0]) 181 | self.level_1 = DLA.__make_conv_level(out_channels=channels[1], convs=levels[1], stride=2) 182 | self.level_2 = Tree(levels=levels[2], block=block, in_channels=channels[1], 183 | out_channels=channels[2], stride=2, 184 | level_root=False, root_residual=residual_root) 185 | self.level_3 = Tree(levels=levels[3], block=block, in_channels=channels[2], 186 | out_channels=channels[3], stride=2, 187 | level_root=True, root_residual=residual_root) 188 | self.level_4 = Tree(levels=levels[4], block=block, in_channels=channels[3], 189 | out_channels=channels[4], stride=2, 190 | level_root=True, root_residual=residual_root) 191 | self.level_5 = Tree(levels=levels[5], block=block, in_channels=channels[4], 192 | out_channels=channels[5], stride=2, 193 | level_root=True, root_residual=residual_root) 194 | 195 | self.avgpool = tf.keras.layers.AveragePooling2D(pool_size=pool_size) 196 | self.final = tf.keras.layers.Conv2D(filters=num_classes, kernel_size=(1, 1), strides=1, 197 | padding="same", use_bias=True) 198 | 199 | @staticmethod 200 | def __make_conv_level(out_channels, convs, stride=1): 201 | layers = [] 202 | for i in range(convs): 203 | if i == 0: 204 | layers.extend([tf.keras.layers.Conv2D(filters=out_channels, 205 | kernel_size=(3, 3), 206 | strides=stride, 207 | padding="same", 208 | use_bias=False), 209 | tf.keras.layers.BatchNormalization(), 210 | tf.keras.layers.ReLU()]) 211 | else: 212 | layers.extend([tf.keras.layers.Conv2D(filters=out_channels, 213 | kernel_size=(3, 3), 214 | strides=1, 215 | padding="same", 216 | use_bias=False), 217 | tf.keras.layers.BatchNormalization(), 218 | tf.keras.layers.ReLU()]) 219 | return tf.keras.Sequential(layers) 220 | 221 | def call(self, inputs, training=None, **kwargs): 222 | y = [] 223 | x = self.base_layer(inputs, training=training) 224 | 225 | x = self.level_0(x, training=training) 226 | y.append(x) 227 | x = self.level_1(x, training=training) 228 | y.append(x) 229 | x = self.level_2(x, training=training) 230 | y.append(x) 231 | x = self.level_3(x, training=training) 232 | y.append(x) 233 | x = self.level_4(x, training=training) 234 | y.append(x) 235 | x = self.level_5(x, training=training) 236 | y.append(x) 237 | 238 | if self.return_levels: 239 | return y 240 | else: 241 | x = self.avgpool(x) 242 | x = self.final(x) 243 | x = tf.reshape(x, (x.shape[0], -1)) 244 | return x 245 | 246 | 247 | class Identity(tf.keras.layers.Layer): 248 | def __init__(self): 249 | super(Identity, self).__init__() 250 | 251 | def call(self, inputs, **kwargs): 252 | return inputs 253 | 254 | 255 | class IDAUp(tf.keras.layers.Layer): 256 | def __init__(self, node_kernel, out_dim, channels, up_factors): 257 | super(IDAUp, self).__init__() 258 | self.channels = channels 259 | self.out_dim = out_dim 260 | for i, c in enumerate(channels): 261 | if c == out_dim: 262 | proj = Identity() 263 | else: 264 | proj = tf.keras.Sequential([ 265 | tf.keras.layers.Conv2D(filters=out_dim, kernel_size=(1, 1), strides=1, padding="same", 266 | use_bias=False), 267 | tf.keras.layers.BatchNormalization(), 268 | tf.keras.layers.ReLU() 269 | ]) 270 | f = int(up_factors[i]) 271 | if f == 1: 272 | up = Identity() 273 | else: 274 | up = GroupConv2DTranspose(input_channels=out_dim, output_channels=out_dim, kernel_size=f * 2, 275 | strides=f, padding="same", groups=out_dim, use_bias=False) 276 | setattr(self, "proj_" + str(i), proj) 277 | setattr(self, "up_" + str(i), up) 278 | 279 
| for i in range(1, len(channels)): 280 | node = tf.keras.Sequential([ 281 | tf.keras.layers.Conv2D(filters=out_dim, kernel_size=node_kernel, strides=1, 282 | padding="same", use_bias=False), 283 | tf.keras.layers.BatchNormalization(), 284 | tf.keras.layers.ReLU() 285 | ]) 286 | setattr(self, "node_" + str(i), node) 287 | 288 | def call(self, inputs, training=None, **kwargs): 289 | layers = list(inputs) 290 | for i, l in enumerate(layers): 291 | upsample = getattr(self, "up_" + str(i)) 292 | project = getattr(self, "proj_" + str(i)) 293 | layers[i] = upsample(project(l, training=training)) 294 | x = layers[0] 295 | y = [] 296 | for i in range(1, len(layers)): 297 | node = getattr(self, "node_" + str(i)) 298 | x = node(tf.concat([x, layers[i]], -1)) 299 | y.append(x) 300 | return x, y 301 | 302 | 303 | class DLAUp(tf.keras.layers.Layer): 304 | def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): 305 | super(DLAUp, self).__init__() 306 | if in_channels is None: 307 | in_channels = channels 308 | self.channels = channels 309 | channels = list(channels) 310 | scales = np.array(scales, dtype=np.int32) 311 | for i in range(len(channels) - 1): 312 | j = -i - 2 313 | setattr(self, 'ida_{}'.format(i), 314 | IDAUp(3, channels[j], in_channels[j:], 315 | scales[j:] // scales[j])) 316 | scales[j + 1:] = scales[j] 317 | in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] 318 | 319 | def call(self, inputs, training=None, **kwargs): 320 | layers = list(inputs) 321 | assert len(layers) > 1 322 | for i in range(len(layers) - 1): 323 | ida = getattr(self, 'ida_{}'.format(i)) 324 | x, y = ida(layers[-i - 2:], training=training) 325 | layers[-i - 1:] = y 326 | return x 327 | 328 | 329 | class DLASeg(tf.keras.layers.Layer): 330 | def __init__(self, base_name, heads, down_ratio=4, head_conv=256): 331 | super(DLASeg, self).__init__() 332 | self.heads = heads 333 | self.first_level = int(np.log2(down_ratio)) 334 | self.base = DLASeg.__get_base_block(base_name) 335 | channels = self.base.channels 336 | scales = [2 ** i for i in range(len(channels[self.first_level:]))] 337 | self.dla_up = DLAUp(channels[self.first_level:], scales=scales) 338 | for head in self.heads: 339 | classes = self.heads[head] 340 | if head_conv > 0: 341 | fc = tf.keras.Sequential([ 342 | tf.keras.layers.Conv2D(filters=head_conv, kernel_size=(3, 3), strides=1, 343 | padding="same", use_bias=True), 344 | tf.keras.layers.ReLU(), 345 | tf.keras.layers.Conv2D(filters=classes, kernel_size=(1, 1), strides=1, 346 | padding="same", use_bias=True) 347 | ]) 348 | else: 349 | fc = tf.keras.layers.Conv2D(filters=classes, kernel_size=(1, 1), strides=1, 350 | padding="same", use_bias=True) 351 | self.__setattr__(head, fc) 352 | 353 | 354 | @staticmethod 355 | def __get_base_block(base_name): 356 | if base_name == "dla34": 357 | return DLA(levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], block=BasicBlock, 358 | return_levels=True) 359 | elif base_name == "dla60": 360 | return DLA(levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], block=BottleNeck, 361 | return_levels=True) 362 | elif base_name == "dla102": 363 | return DLA(levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], block=BottleNeck, 364 | residual_root=True, return_levels=True) 365 | elif base_name == "dla169": 366 | return DLA(levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024], block=BottleNeck, 367 | residual_root=True, return_levels=True) 368 | else: 369 | raise ValueError("The 'base_name' is invalid.") 
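# call() (below) runs the DLA backbone, feeds the levels from first_level on through DLAUp to recover a feature map at the first_level resolution, then applies each head conv; the outputs follow the key order of Config.heads ("heatmap", "wh", "reg") and are concatenated along the channel axis by CenterNet.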
370 | 371 | def call(self, inputs, training=None, **kwargs): 372 | x = self.base(inputs, training=training) 373 | x = self.dla_up(x[self.first_level:], training=training) 374 | outputs = [] 375 | for head in self.heads: 376 | outputs.append(self.__getattribute__(head)(x, training=training)) 377 | return outputs 378 | 379 | 380 | def dla_34(): 381 | return DLASeg(base_name="dla34", heads=Config.heads, down_ratio=Config.downsampling_ratio, 382 | head_conv=Config.head_conv["dla"]) 383 | 384 | 385 | def dla_60(): 386 | return DLASeg(base_name="dla60", heads=Config.heads, down_ratio=Config.downsampling_ratio, 387 | head_conv=Config.head_conv["dla"]) 388 | 389 | 390 | def dla_102(): 391 | return DLASeg(base_name="dla102", heads=Config.heads, down_ratio=Config.downsampling_ratio, 392 | head_conv=Config.head_conv["dla"]) 393 | 394 | 395 | def dla_169(): 396 | return DLASeg(base_name="dla169", heads=Config.heads, down_ratio=Config.downsampling_ratio, 397 | head_conv=Config.head_conv["dla"]) 398 | -------------------------------------------------------------------------------- /core/models/efficientdet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import math 3 | 4 | from configuration import Config 5 | 6 | 7 | def round_filters(filters, multiplier): 8 | depth_divisor = 8 9 | min_depth = None 10 | min_depth = min_depth or depth_divisor 11 | filters = filters * multiplier 12 | new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor) 13 | if new_filters < 0.9 * filters: 14 | new_filters += depth_divisor 15 | return int(new_filters) 16 | 17 | 18 | def round_repeats(repeats, multiplier): 19 | if not multiplier: 20 | return repeats 21 | return int(math.ceil(multiplier * repeats)) 22 | 23 | 24 | class SEBlock(tf.keras.layers.Layer): 25 | def __init__(self, input_channels, ratio=0.25): 26 | super(SEBlock, self).__init__() 27 | self.num_reduced_filters = max(1, int(input_channels * ratio)) 28 | self.pool = tf.keras.layers.GlobalAveragePooling2D() 29 | self.reduce_conv = tf.keras.layers.Conv2D(filters=self.num_reduced_filters, 30 | kernel_size=(1, 1), 31 | strides=1, 32 | padding="same") 33 | self.expand_conv = tf.keras.layers.Conv2D(filters=input_channels, 34 | kernel_size=(1, 1), 35 | strides=1, 36 | padding="same") 37 | 38 | def call(self, inputs, **kwargs): 39 | branch = self.pool(inputs) 40 | branch = tf.expand_dims(input=branch, axis=1) 41 | branch = tf.expand_dims(input=branch, axis=1) 42 | branch = self.reduce_conv(branch) 43 | branch = tf.nn.swish(branch) 44 | branch = self.expand_conv(branch) 45 | branch = tf.nn.sigmoid(branch) 46 | output = inputs * branch 47 | return output 48 | 49 | 50 | class MBConv(tf.keras.layers.Layer): 51 | def __init__(self, in_channels, out_channels, expansion_factor, stride, k, drop_connect_rate): 52 | super(MBConv, self).__init__() 53 | self.in_channels = in_channels 54 | self.out_channels = out_channels 55 | self.stride = stride 56 | self.drop_connect_rate = drop_connect_rate 57 | self.conv1 = tf.keras.layers.Conv2D(filters=in_channels * expansion_factor, 58 | kernel_size=(1, 1), 59 | strides=1, 60 | padding="same", 61 | use_bias=False) 62 | self.bn1 = tf.keras.layers.BatchNormalization() 63 | self.dwconv = tf.keras.layers.DepthwiseConv2D(kernel_size=(k, k), 64 | strides=stride, 65 | padding="same", 66 | use_bias=False) 67 | self.bn2 = tf.keras.layers.BatchNormalization() 68 | self.se = SEBlock(input_channels=in_channels * expansion_factor) 69 | self.conv2 = 
tf.keras.layers.Conv2D(filters=out_channels, 70 | kernel_size=(1, 1), 71 | strides=1, 72 | padding="same", 73 | use_bias=False) 74 | self.bn3 = tf.keras.layers.BatchNormalization() 75 | self.dropout = tf.keras.layers.Dropout(rate=drop_connect_rate) 76 | 77 | def call(self, inputs, training=None, **kwargs): 78 | x = self.conv1(inputs) 79 | x = self.bn1(x, training=training) 80 | x = tf.nn.swish(x) 81 | x = self.dwconv(x) 82 | x = self.bn2(x, training=training) 83 | x = self.se(x) 84 | x = tf.nn.swish(x) 85 | x = self.conv2(x) 86 | x = self.bn3(x, training=training) 87 | if self.stride == 1 and self.in_channels == self.out_channels: 88 | if self.drop_connect_rate: 89 | x = self.dropout(x, training=training) 90 | x = tf.keras.layers.add([x, inputs]) 91 | return x 92 | 93 | 94 | def build_mbconv_block(in_channels, out_channels, layers, stride, expansion_factor, k, drop_connect_rate): 95 | block = tf.keras.Sequential() 96 | for i in range(layers): 97 | if i == 0: 98 | block.add(MBConv(in_channels=in_channels, 99 | out_channels=out_channels, 100 | expansion_factor=expansion_factor, 101 | stride=stride, 102 | k=k, 103 | drop_connect_rate=drop_connect_rate)) 104 | else: 105 | block.add(MBConv(in_channels=out_channels, 106 | out_channels=out_channels, 107 | expansion_factor=expansion_factor, 108 | stride=1, 109 | k=k, 110 | drop_connect_rate=drop_connect_rate)) 111 | return block 112 | 113 | 114 | class EfficientNet(tf.keras.Model): 115 | def __init__(self, width_coefficient, depth_coefficient, dropout_rate, drop_connect_rate=0.2): 116 | super(EfficientNet, self).__init__() 117 | 118 | self.conv1 = tf.keras.layers.Conv2D(filters=round_filters(32, width_coefficient), 119 | kernel_size=(3, 3), 120 | strides=2, 121 | padding="same", 122 | use_bias=False) 123 | self.bn1 = tf.keras.layers.BatchNormalization() 124 | self.block1 = build_mbconv_block(in_channels=round_filters(32, width_coefficient), 125 | out_channels=round_filters(16, width_coefficient), 126 | layers=round_repeats(1, depth_coefficient), 127 | stride=1, 128 | expansion_factor=1, k=3, drop_connect_rate=drop_connect_rate) 129 | self.block2 = build_mbconv_block(in_channels=round_filters(16, width_coefficient), 130 | out_channels=round_filters(24, width_coefficient), 131 | layers=round_repeats(2, depth_coefficient), 132 | stride=2, 133 | expansion_factor=6, k=3, drop_connect_rate=drop_connect_rate) 134 | self.block3 = build_mbconv_block(in_channels=round_filters(24, width_coefficient), 135 | out_channels=round_filters(40, width_coefficient), 136 | layers=round_repeats(2, depth_coefficient), 137 | stride=2, 138 | expansion_factor=6, k=5, drop_connect_rate=drop_connect_rate) 139 | self.block4 = build_mbconv_block(in_channels=round_filters(40, width_coefficient), 140 | out_channels=round_filters(80, width_coefficient), 141 | layers=round_repeats(3, depth_coefficient), 142 | stride=2, 143 | expansion_factor=6, k=3, drop_connect_rate=drop_connect_rate) 144 | self.block5 = build_mbconv_block(in_channels=round_filters(80, width_coefficient), 145 | out_channels=round_filters(112, width_coefficient), 146 | layers=round_repeats(3, depth_coefficient), 147 | stride=2, 148 | expansion_factor=6, k=5, drop_connect_rate=drop_connect_rate) 149 | self.block6 = build_mbconv_block(in_channels=round_filters(112, width_coefficient), 150 | out_channels=round_filters(192, width_coefficient), 151 | layers=round_repeats(4, depth_coefficient), 152 | stride=2, 153 | expansion_factor=6, k=5, drop_connect_rate=drop_connect_rate) 154 | self.block7 = 
build_mbconv_block(in_channels=round_filters(192, width_coefficient), 155 | out_channels=round_filters(320, width_coefficient), 156 | layers=round_repeats(1, depth_coefficient), 157 | stride=2, 158 | expansion_factor=6, k=3, drop_connect_rate=drop_connect_rate) 159 | 160 | def call(self, inputs, training=None, mask=None): 161 | features = [] 162 | x = self.conv1(inputs) 163 | x = self.bn1(x, training=training) 164 | x = tf.nn.swish(x) 165 | 166 | x = self.block1(x) 167 | x = self.block2(x) 168 | x = self.block3(x) 169 | features.append(x) 170 | x = self.block4(x) 171 | features.append(x) 172 | x = self.block5(x) 173 | features.append(x) 174 | x = self.block6(x) 175 | features.append(x) 176 | x = self.block7(x) 177 | features.append(x) 178 | 179 | return features 180 | 181 | 182 | def get_efficient_net(width_coefficient, depth_coefficient, dropout_rate): 183 | net = EfficientNet(width_coefficient=width_coefficient, 184 | depth_coefficient=depth_coefficient, 185 | dropout_rate=dropout_rate) 186 | 187 | return net 188 | 189 | 190 | class BiFPN(tf.keras.layers.Layer): 191 | def __init__(self, output_channels, layers): 192 | super(BiFPN, self).__init__() 193 | self.levels = 5 194 | self.output_channels = output_channels 195 | self.layers = layers 196 | self.transform_convs = [] 197 | self.bifpn_modules = [] 198 | for _ in range(self.levels): 199 | self.transform_convs.append(ConvNormAct(filters=output_channels, 200 | kernel_size=(1, 1), 201 | strides=1, 202 | padding="same")) 203 | for _ in range(self.layers): 204 | self.bifpn_modules.append(BiFPNModule(self.output_channels)) 205 | 206 | def call(self, inputs, training=None, **kwargs): 207 | """ 208 | :param inputs: list of features 209 | :param training: 210 | :param kwargs: 211 | :return: list of features 212 | """ 213 | assert len(inputs) == self.levels 214 | x = [] 215 | for i in range(len(inputs)): 216 | x.append(self.transform_convs[i](inputs[i], training=training)) 217 | for j in range(self.layers): 218 | x = self.bifpn_modules[j](x, training=training) 219 | return x 220 | 221 | 222 | class BiFPNModule(tf.keras.layers.Layer): 223 | def __init__(self, out_channels): 224 | super(BiFPNModule, self).__init__() 225 | self.w_fusion_list = [] 226 | self.conv_list = [] 227 | for i in range(8): 228 | self.w_fusion_list.append(WeightedFeatureFusion(out_channels)) 229 | self.upsampling_1 = tf.keras.layers.UpSampling2D(size=(2, 2)) 230 | self.upsampling_2 = tf.keras.layers.UpSampling2D(size=(2, 2)) 231 | self.upsampling_3 = tf.keras.layers.UpSampling2D(size=(2, 2)) 232 | self.upsampling_4 = tf.keras.layers.UpSampling2D(size=(2, 2)) 233 | self.maxpool_1 = tf.keras.layers.MaxPool2D(pool_size=(2, 2)) 234 | self.maxpool_2 = tf.keras.layers.MaxPool2D(pool_size=(2, 2)) 235 | self.maxpool_3 = tf.keras.layers.MaxPool2D(pool_size=(2, 2)) 236 | self.maxpool_4 = tf.keras.layers.MaxPool2D(pool_size=(2, 2)) 237 | 238 | def call(self, inputs, training=None, **kwargs): 239 | """ 240 | :param inputs: list of features 241 | :param training: 242 | :param kwargs: 243 | :return: 244 | """ 245 | assert len(inputs) == 5 246 | f3, f4, f5, f6, f7 = inputs 247 | f6_d = self.w_fusion_list[0]([f6, self.upsampling_1(f7)], training=training) 248 | f5_d = self.w_fusion_list[1]([f5, self.upsampling_2(f6_d)], training=training) 249 | f4_d = self.w_fusion_list[2]([f4, self.upsampling_3(f5_d)], training=training) 250 | 251 | f3_u = self.w_fusion_list[3]([f3, self.upsampling_4(f4_d)], training=training) 252 | f4_u = self.w_fusion_list[4]([f4, f4_d, self.maxpool_1(f3_u)], 
training=training) 253 | f5_u = self.w_fusion_list[5]([f5, f5_d, self.maxpool_2(f4_u)], training=training) 254 | f6_u = self.w_fusion_list[6]([f6, f6_d, self.maxpool_3(f5_u)], training=training) 255 | f7_u = self.w_fusion_list[7]([f7, self.maxpool_4(f6_u)], training=training) 256 | 257 | return [f3_u, f4_u, f5_u, f6_u, f7_u] 258 | 259 | 260 | class SeparableConvNormAct(tf.keras.layers.Layer): 261 | def __init__(self, 262 | filters, 263 | kernel_size, 264 | strides, 265 | padding): 266 | super(SeparableConvNormAct, self).__init__() 267 | self.conv = tf.keras.layers.SeparableConv2D(filters=filters, 268 | kernel_size=kernel_size, 269 | strides=strides, 270 | padding=padding) 271 | self.bn = tf.keras.layers.BatchNormalization() 272 | 273 | def call(self, inputs, training=None, **kwargs): 274 | x = self.conv(inputs) 275 | x = self.bn(x, training=training) 276 | x = tf.nn.swish(x) 277 | return x 278 | 279 | 280 | class ConvNormAct(tf.keras.layers.Layer): 281 | def __init__(self, 282 | filters, 283 | kernel_size, 284 | strides, 285 | padding): 286 | super(ConvNormAct, self).__init__() 287 | self.conv = tf.keras.layers.Conv2D(filters=filters, 288 | kernel_size=kernel_size, 289 | strides=strides, 290 | padding=padding) 291 | self.bn = tf.keras.layers.BatchNormalization() 292 | 293 | def call(self, inputs, training=None, **kwargs): 294 | x = self.conv(inputs) 295 | x = self.bn(x, training=training) 296 | x = tf.nn.swish(x) 297 | return x 298 | 299 | 300 | class WeightedFeatureFusion(tf.keras.layers.Layer): 301 | def __init__(self, out_channels): 302 | super(WeightedFeatureFusion, self).__init__() 303 | self.epsilon = 1e-4 304 | self.conv = SeparableConvNormAct(filters=out_channels, kernel_size=(3, 3), strides=1, padding="same") 305 | 306 | def build(self, input_shape): 307 | self.num_features = len(input_shape) 308 | assert self.num_features >= 2 309 | self.fusion_weights = self.add_weight(name="fusion_w", 310 | shape=(self.num_features, ), 311 | dtype=tf.dtypes.float32, 312 | initializer=tf.constant_initializer(value=1.0 / self.num_features), 313 | trainable=True) 314 | 315 | def call(self, inputs, training=None, **kwargs): 316 | """ 317 | :param inputs: list of features 318 | :param kwargs: 319 | :return: 320 | """ 321 | fusion_w = tf.nn.relu(self.fusion_weights) 322 | sum_features = [] 323 | for i in range(self.num_features): 324 | sum_features.append(fusion_w[i] * inputs[i]) 325 | output_feature = tf.reduce_sum(input_tensor=sum_features, axis=0) / (tf.reduce_sum(input_tensor=fusion_w) + self.epsilon) 326 | output_feature = self.conv(output_feature, training=training) 327 | return output_feature 328 | 329 | 330 | class TransposeLayer(tf.keras.layers.Layer): 331 | def __init__(self, out_channels, num_layers=5): 332 | super(TransposeLayer, self).__init__() 333 | self.layers = num_layers 334 | self.transpose_layers = [] 335 | for i in range(self.layers - 1): 336 | self.transpose_layers.append(tf.keras.Sequential([ 337 | tf.keras.layers.Conv2DTranspose(filters=out_channels, kernel_size=(4, 4), strides=2, padding="same"), 338 | tf.keras.layers.BatchNormalization(), 339 | ])) 340 | 341 | def call(self, inputs, training=None, **kwargs): 342 | assert len(inputs) == self.layers 343 | f3, f4, f5, f6, f7 = inputs 344 | f6 += tf.nn.swish(self.transpose_layers[0](f7, training=training)) 345 | f5 += tf.nn.swish(self.transpose_layers[1](f6, training=training)) 346 | f4 += tf.nn.swish(self.transpose_layers[2](f5, training=training)) 347 | f3 += tf.nn.swish(self.transpose_layers[3](f4, training=training)) 348 | 
return f3 349 | 350 | 351 | class EfficientDet(tf.keras.layers.Layer): 352 | def __init__(self, efficient_det): 353 | super(EfficientDet, self).__init__() 354 | self.heads = Config.heads 355 | self.head_conv = Config.head_conv[efficient_det] 356 | self.efficient_net = get_efficient_net(width_coefficient=Config.get_width_coefficient(efficient_det), 357 | depth_coefficient=Config.get_depth_coefficient(efficient_det), 358 | dropout_rate=Config.get_dropout_rate(efficient_det)) 359 | self.bifpn = BiFPN(output_channels=Config.get_w_bifpn(efficient_det), layers=Config.get_d_bifpn(efficient_det)) 360 | self.transpose = TransposeLayer(out_channels=Config.get_w_bifpn(efficient_det)) 361 | for head in self.heads: 362 | classes = self.heads[head] 363 | if self.head_conv > 0: 364 | fc = tf.keras.Sequential([ 365 | tf.keras.layers.Conv2D(filters=self.head_conv, kernel_size=(3, 3), strides=1, 366 | padding="same", use_bias=True), 367 | tf.keras.layers.ReLU(), 368 | tf.keras.layers.Conv2D(filters=classes, kernel_size=(1, 1), strides=1, 369 | padding="same", use_bias=True) 370 | ]) 371 | else: 372 | fc = tf.keras.layers.Conv2D(filters=classes, kernel_size=(1, 1), strides=1, 373 | padding="same", use_bias=True) 374 | self.__setattr__(head, fc) 375 | 376 | def call(self, inputs, training=None, **kwargs): 377 | x = self.efficient_net(inputs, training=training) 378 | x = self.bifpn(x, training=training) 379 | x = self.transpose(x, training=training) 380 | outputs = [] 381 | for head in self.heads: 382 | outputs.append(self.__getattribute__(head)(x, training=training)) 383 | return outputs 384 | 385 | 386 | def d0(): 387 | return EfficientDet("D0") 388 | 389 | 390 | def d1(): 391 | return EfficientDet("D1") 392 | 393 | 394 | def d2(): 395 | return EfficientDet("D2") 396 | 397 | 398 | def d3(): 399 | return EfficientDet("D3") 400 | 401 | 402 | def d4(): 403 | return EfficientDet("D4") 404 | 405 | 406 | def d5(): 407 | return EfficientDet("D5") 408 | 409 | 410 | def d6(): 411 | return EfficientDet("D6") 412 | 413 | 414 | def d7(): 415 | return EfficientDet("D7") -------------------------------------------------------------------------------- /core/models/group_convolution.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import initializers, regularizers, constraints 3 | from tensorflow.keras import activations 4 | 5 | 6 | class GroupConv2D(tf.keras.layers.Layer): 7 | def __init__(self, 8 | input_channels, 9 | output_channels, 10 | kernel_size, 11 | strides=(1, 1), 12 | padding='valid', 13 | data_format=None, 14 | dilation_rate=(1, 1), 15 | activation=None, 16 | groups=1, 17 | use_bias=True, 18 | kernel_initializer='glorot_uniform', 19 | bias_initializer='zeros', 20 | kernel_regularizer=None, 21 | bias_regularizer=None, 22 | activity_regularizer=None, 23 | kernel_constraint=None, 24 | bias_constraint=None, 25 | **kwargs): 26 | super(GroupConv2D, self).__init__() 27 | 28 | if not input_channels % groups == 0: 29 | raise ValueError("The value of input_channels must be divisible by the value of groups.") 30 | if not output_channels % groups == 0: 31 | raise ValueError("The value of output_channels must be divisible by the value of groups.") 32 | 33 | self.input_channels = input_channels 34 | self.output_channels = output_channels 35 | self.kernel_size = kernel_size 36 | self.strides = strides 37 | self.padding = padding 38 | self.data_format = data_format 39 | self.dilation_rate = dilation_rate 40 | self.activation = activation 
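# GroupConv2D emulates grouped convolution: the input channels are split into 'groups' equal slices, an independent Conv2D is applied to each slice, and the per-group outputs are concatenated along the channel axis (newer TF releases expose this directly through the 'groups' argument of tf.keras.layers.Conv2D; this layer targets the older TF 2.2 listed in the README).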
41 | self.groups = groups 42 | self.use_bias = use_bias 43 | self.kernel_initializer = kernel_initializer 44 | self.bias_initializer = bias_initializer 45 | self.kernel_regularizer = kernel_regularizer 46 | self.bias_regularizer = bias_regularizer 47 | self.activity_regularizer = activity_regularizer 48 | self.kernel_constraint = kernel_constraint 49 | self.bias_constraint = bias_constraint 50 | 51 | self.group_in_num = input_channels // groups 52 | self.group_out_num = output_channels // groups 53 | self.conv_list = [] 54 | for i in range(self.groups): 55 | self.conv_list.append(tf.keras.layers.Conv2D(filters=self.group_out_num, 56 | kernel_size=kernel_size, 57 | strides=strides, 58 | padding=padding, 59 | data_format=data_format, 60 | dilation_rate=dilation_rate, 61 | activation=activations.get(activation), 62 | use_bias=use_bias, 63 | kernel_initializer=initializers.get(kernel_initializer), 64 | bias_initializer=initializers.get(bias_initializer), 65 | kernel_regularizer=regularizers.get(kernel_regularizer), 66 | bias_regularizer=regularizers.get(bias_regularizer), 67 | activity_regularizer=regularizers.get(activity_regularizer), 68 | kernel_constraint=constraints.get(kernel_constraint), 69 | bias_constraint=constraints.get(bias_constraint), 70 | **kwargs)) 71 | 72 | def call(self, inputs, **kwargs): 73 | feature_map_list = [] 74 | for i in range(self.groups): 75 | x_i = self.conv_list[i](inputs[:, :, :, i*self.group_in_num: (i + 1) * self.group_in_num]) 76 | feature_map_list.append(x_i) 77 | out = tf.concat(feature_map_list, axis=-1) 78 | return out 79 | 80 | def get_config(self): 81 | config = { 82 | "input_channels": self.input_channels, 83 | "output_channels": self.output_channels, 84 | "kernel_size": self.kernel_size, 85 | "strides": self.strides, 86 | "padding": self.padding, 87 | "data_format": self.data_format, 88 | "dilation_rate": self.dilation_rate, 89 | "activation": activations.serialize(self.activation), 90 | "groups": self.groups, 91 | "use_bias": self.use_bias, 92 | "kernel_initializer": initializers.serialize(self.kernel_initializer), 93 | "bias_initializer": initializers.serialize(self.bias_initializer), 94 | "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), 95 | "bias_regularizer": regularizers.serialize(self.bias_regularizer), 96 | "activity_regularizer": regularizers.serialize(self.activity_regularizer), 97 | "kernel_constraint": constraints.serialize(self.kernel_constraint), 98 | "bias_constraint": constraints.serialize(self.bias_constraint) 99 | } 100 | base_config = super(GroupConv2D, self).get_config() 101 | return {**base_config, **config} 102 | 103 | 104 | class GroupConv2DTranspose(tf.keras.layers.Layer): 105 | def __init__(self, 106 | input_channels, 107 | output_channels, 108 | kernel_size, 109 | strides=(1, 1), 110 | padding='valid', 111 | output_padding=None, 112 | data_format=None, 113 | dilation_rate=(1, 1), 114 | activation=None, 115 | groups=1, 116 | use_bias=True, 117 | kernel_initializer='glorot_uniform', 118 | bias_initializer='zeros', 119 | kernel_regularizer=None, 120 | bias_regularizer=None, 121 | activity_regularizer=None, 122 | kernel_constraint=None, 123 | bias_constraint=None, 124 | **kwargs 125 | ): 126 | super(GroupConv2DTranspose, self).__init__() 127 | 128 | if not input_channels % groups == 0: 129 | raise ValueError("The value of input_channels must be divisible by the value of groups.") 130 | if not output_channels % groups == 0: 131 | raise ValueError("The value of output_channels must be divisible by the value of 
groups.") 132 | 133 | self.input_channels = input_channels 134 | self.output_channels = output_channels 135 | self.kernel_size = kernel_size 136 | self.strides = strides 137 | self.padding = padding 138 | self.output_padding = output_padding 139 | self.data_format = data_format 140 | self.dilation_rate = dilation_rate 141 | self.activation = activation 142 | self.groups = groups 143 | self.use_bias = use_bias 144 | self.kernel_initializer = kernel_initializer 145 | self.bias_initializer = bias_initializer 146 | self.kernel_regularizer = kernel_regularizer 147 | self.bias_regularizer = bias_regularizer 148 | self.activity_regularizer = activity_regularizer 149 | self.kernel_constraint = kernel_constraint 150 | self.bias_constraint = bias_constraint 151 | 152 | self.group_in_num = input_channels // groups 153 | self.group_out_num = output_channels // groups 154 | self.conv_list = [] 155 | for i in range(self.groups): 156 | self.conv_list.append(tf.keras.layers.Conv2DTranspose(filters=self.group_out_num, 157 | kernel_size=kernel_size, 158 | strides=strides, 159 | padding=padding, 160 | output_padding=output_padding, 161 | data_format=data_format, 162 | dilation_rate=dilation_rate, 163 | activation=activations.get(activation), 164 | use_bias=use_bias, 165 | kernel_initializer=initializers.get(kernel_initializer), 166 | bias_initializer=initializers.get(bias_initializer), 167 | kernel_regularizer=regularizers.get(kernel_regularizer), 168 | bias_regularizer=regularizers.get(bias_regularizer), 169 | activity_regularizer=regularizers.get(activity_regularizer), 170 | kernel_constraint=constraints.get(kernel_constraint), 171 | bias_constraint=constraints.get(bias_constraint), 172 | **kwargs)) 173 | 174 | def call(self, inputs, **kwargs): 175 | feature_map_list = [] 176 | for i in range(self.groups): 177 | x_i = self.conv_list[i](inputs[:, :, :, i*self.group_in_num: (i + 1) * self.group_in_num]) 178 | feature_map_list.append(x_i) 179 | out = tf.concat(feature_map_list, axis=-1) 180 | return out 181 | 182 | def get_config(self): 183 | config = { 184 | "input_channels": self.input_channels, 185 | "output_channels": self.output_channels, 186 | "kernel_size": self.kernel_size, 187 | "strides": self.strides, 188 | "padding": self.padding, 189 | "output_padding": self.output_padding, 190 | "data_format": self.data_format, 191 | "dilation_rate": self.dilation_rate, 192 | "activation": activations.serialize(self.activation), 193 | "groups": self.groups, 194 | "use_bias": self.use_bias, 195 | "kernel_initializer": initializers.serialize(self.kernel_initializer), 196 | "bias_initializer": initializers.serialize(self.bias_initializer), 197 | "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), 198 | "bias_regularizer": regularizers.serialize(self.bias_regularizer), 199 | "activity_regularizer": regularizers.serialize(self.activity_regularizer), 200 | "kernel_constraint": constraints.serialize(self.kernel_constraint), 201 | "bias_constraint": constraints.serialize(self.bias_constraint) 202 | } 203 | base_config = super(GroupConv2DTranspose, self).get_config() 204 | return {**base_config, **config} 205 | -------------------------------------------------------------------------------- /core/models/resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from configuration import Config 4 | 5 | 6 | class BasicBlock(tf.keras.layers.Layer): 7 | def __init__(self, filter_num, stride=1): 8 | super(BasicBlock, self).__init__() 9 | 
self.conv1 = tf.keras.layers.Conv2D(filters=filter_num, 10 | kernel_size=(3, 3), 11 | strides=stride, 12 | padding="same", 13 | use_bias=False) 14 | self.bn1 = tf.keras.layers.BatchNormalization() 15 | self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, 16 | kernel_size=(3, 3), 17 | strides=1, 18 | padding="same", 19 | use_bias=False) 20 | self.bn2 = tf.keras.layers.BatchNormalization() 21 | if stride != 1: 22 | self.downsample = tf.keras.Sequential([ 23 | tf.keras.layers.Conv2D(filters=filter_num, 24 | kernel_size=(1, 1), 25 | strides=stride, 26 | use_bias=False), 27 | tf.keras.layers.BatchNormalization() 28 | ]) 29 | else: 30 | self.downsample = tf.keras.layers.Lambda(lambda x: x) 31 | 32 | def call(self, inputs, training=None, **kwargs): 33 | residual = self.downsample(inputs, training=training) 34 | x = self.conv1(inputs) 35 | x = self.bn1(x, training=training) 36 | x = tf.nn.relu(x) 37 | x = self.conv2(x) 38 | x = self.bn2(x, training=training) 39 | output = tf.nn.relu(tf.keras.layers.add([residual, x])) 40 | return output 41 | 42 | 43 | class BottleNeck(tf.keras.layers.Layer): 44 | def __init__(self, filter_num, stride=1): 45 | super(BottleNeck, self).__init__() 46 | self.conv1 = tf.keras.layers.Conv2D(filters=filter_num, 47 | kernel_size=(1, 1), 48 | strides=1, 49 | padding="same", 50 | use_bias=False) 51 | self.bn1 = tf.keras.layers.BatchNormalization() 52 | self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, 53 | kernel_size=(3, 3), 54 | strides=stride, 55 | padding="same", 56 | use_bias=False) 57 | self.bn2 = tf.keras.layers.BatchNormalization() 58 | self.conv3 = tf.keras.layers.Conv2D(filters=filter_num * 4, 59 | kernel_size=(1, 1), 60 | strides=1, 61 | padding="same", 62 | use_bias=False) 63 | self.bn3 = tf.keras.layers.BatchNormalization() 64 | self.downsample = tf.keras.Sequential([ 65 | tf.keras.layers.Conv2D(filters=filter_num * 4, 66 | kernel_size=(1, 1), 67 | strides=stride, 68 | use_bias=False), 69 | tf.keras.layers.BatchNormalization() 70 | ]) 71 | 72 | def call(self, inputs, training=None, **kwargs): 73 | residual = self.downsample(inputs, training=training) 74 | x = self.conv1(inputs) 75 | x = self.bn1(x, training=training) 76 | x = tf.nn.relu(x) 77 | x = self.conv2(x) 78 | x = self.bn2(x, training=training) 79 | x = tf.nn.relu(x) 80 | x = self.conv3(x) 81 | x = self.bn3(x, training=training) 82 | output = tf.nn.relu(tf.keras.layers.add([residual, x])) 83 | return output 84 | 85 | 86 | class ResNetTypeI(tf.keras.layers.Layer): 87 | def __init__(self, layer_params, heads, head_conv): 88 | super(ResNetTypeI, self).__init__() 89 | 90 | self.conv1 = tf.keras.layers.Conv2D(filters=64, 91 | kernel_size=(7, 7), 92 | strides=2, 93 | padding="same", 94 | use_bias=False) 95 | self.bn1 = tf.keras.layers.BatchNormalization() 96 | self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), 97 | strides=2, 98 | padding="same") 99 | 100 | self.layer1 = ResNetTypeI.__make_basic_block_layer(filter_num=64, 101 | blocks=layer_params[0]) 102 | self.layer2 = ResNetTypeI.__make_basic_block_layer(filter_num=128, 103 | blocks=layer_params[1], 104 | stride=2) 105 | self.layer3 = ResNetTypeI.__make_basic_block_layer(filter_num=256, 106 | blocks=layer_params[2], 107 | stride=2) 108 | self.layer4 = ResNetTypeI.__make_basic_block_layer(filter_num=512, 109 | blocks=layer_params[3], 110 | stride=2) 111 | self.transposed_conv_layers = ResNetTypeI.__make_transposed_conv_layer(num_layers=3, 112 | num_filters=[256, 256, 256], 113 | num_kernels=[4, 4, 4]) 114 | 115 | if head_conv > 0: 116 | 
self.heatmap_layer = tf.keras.Sequential([ 117 | tf.keras.layers.Conv2D(filters=head_conv, kernel_size=(3, 3), strides=1, padding="same"), 118 | tf.keras.layers.ReLU(), 119 | tf.keras.layers.Conv2D(filters=heads["heatmap"], kernel_size=(1, 1), strides=1, padding="same") 120 | ]) 121 | self.reg_layer = tf.keras.Sequential([ 122 | tf.keras.layers.Conv2D(filters=head_conv, kernel_size=(3, 3), strides=1, padding="same"), 123 | tf.keras.layers.ReLU(), 124 | tf.keras.layers.Conv2D(filters=heads["reg"], kernel_size=(1, 1), strides=1, padding="same") 125 | ]) 126 | self.wh_layer = tf.keras.Sequential([ 127 | tf.keras.layers.Conv2D(filters=head_conv, kernel_size=(3, 3), strides=1, padding="same"), 128 | tf.keras.layers.ReLU(), 129 | tf.keras.layers.Conv2D(filters=heads["wh"], kernel_size=(1, 1), strides=1, padding="same") 130 | ]) 131 | else: 132 | self.heatmap_layer = tf.keras.layers.Conv2D(filters=heads["heatmap"], kernel_size=(1, 1), strides=1, padding="same") 133 | self.reg_layer = tf.keras.layers.Conv2D(filters=heads["reg"], kernel_size=(1, 1), strides=1, padding="same") 134 | self.wh_layer = tf.keras.layers.Conv2D(filters=heads["wh"], kernel_size=(1, 1), strides=1, padding="same") 135 | 136 | @staticmethod 137 | def __make_basic_block_layer(filter_num, blocks, stride=1): 138 | res_block = tf.keras.Sequential() 139 | res_block.add(BasicBlock(filter_num, stride=stride)) 140 | 141 | for _ in range(1, blocks): 142 | res_block.add(BasicBlock(filter_num, stride=1)) 143 | 144 | return res_block 145 | 146 | @staticmethod 147 | def __make_transposed_conv_layer(num_layers, num_filters, num_kernels): 148 | layers = tf.keras.Sequential() 149 | for i in range(num_layers): 150 | layers.add(tf.keras.layers.Conv2DTranspose(filters=num_filters[i], 151 | kernel_size=num_kernels[i], 152 | strides=2, 153 | padding="same", 154 | use_bias=False)) 155 | layers.add(tf.keras.layers.BatchNormalization()) 156 | layers.add(tf.keras.layers.ReLU()) 157 | return layers 158 | 159 | def call(self, inputs, training=None, **kwargs): 160 | x = self.conv1(inputs) 161 | x = self.bn1(x, training=training) 162 | x = tf.nn.relu(x) 163 | x = self.pool1(x) 164 | x = self.layer1(x, training=training) 165 | x = self.layer2(x, training=training) 166 | x = self.layer3(x, training=training) 167 | x = self.layer4(x, training=training) 168 | 169 | x = self.transposed_conv_layers(x, training=training) 170 | heatmap = self.heatmap_layer(x, training=training) 171 | reg = self.reg_layer(x, training=training) 172 | wh = self.wh_layer(x, training=training) 173 | 174 | return [heatmap, reg, wh] 175 | 176 | 177 | class ResNetTypeII(tf.keras.layers.Layer): 178 | def __init__(self, layer_params, heads, head_conv): 179 | super(ResNetTypeII, self).__init__() 180 | self.conv1 = tf.keras.layers.Conv2D(filters=64, 181 | kernel_size=(7, 7), 182 | strides=2, 183 | padding="same") 184 | self.bn1 = tf.keras.layers.BatchNormalization() 185 | self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), 186 | strides=2, 187 | padding="same") 188 | 189 | self.layer1 = ResNetTypeII.__make_bottleneck_layer(filter_num=64, 190 | blocks=layer_params[0]) 191 | self.layer2 = ResNetTypeII.__make_bottleneck_layer(filter_num=128, 192 | blocks=layer_params[1], 193 | stride=2) 194 | self.layer3 = ResNetTypeII.__make_bottleneck_layer(filter_num=256, 195 | blocks=layer_params[2], 196 | stride=2) 197 | self.layer4 = ResNetTypeII.__make_bottleneck_layer(filter_num=512, 198 | blocks=layer_params[3], 199 | stride=2) 200 | 201 | self.transposed_conv_layers = 
ResNetTypeII.__make_transposed_conv_layer(num_layers=3, 202 | num_filters=[256, 256, 256], 203 | num_kernels=[4, 4, 4]) 204 | 205 | if head_conv > 0: 206 | self.heatmap_layer = tf.keras.Sequential([ 207 | tf.keras.layers.Conv2D(filters=head_conv, kernel_size=(3, 3), strides=1, padding="same"), 208 | tf.keras.layers.ReLU(), 209 | tf.keras.layers.Conv2D(filters=heads["heatmap"], kernel_size=(1, 1), strides=1, padding="same") 210 | ]) 211 | self.reg_layer = tf.keras.Sequential([ 212 | tf.keras.layers.Conv2D(filters=head_conv, kernel_size=(3, 3), strides=1, padding="same"), 213 | tf.keras.layers.ReLU(), 214 | tf.keras.layers.Conv2D(filters=heads["reg"], kernel_size=(1, 1), strides=1, padding="same") 215 | ]) 216 | self.wh_layer = tf.keras.Sequential([ 217 | tf.keras.layers.Conv2D(filters=head_conv, kernel_size=(3, 3), strides=1, padding="same"), 218 | tf.keras.layers.ReLU(), 219 | tf.keras.layers.Conv2D(filters=heads["wh"], kernel_size=(1, 1), strides=1, padding="same") 220 | ]) 221 | else: 222 | self.heatmap_layer = tf.keras.layers.Conv2D(filters=heads["heatmap"], kernel_size=(1, 1), strides=1, padding="same") 223 | self.reg_layer = tf.keras.layers.Conv2D(filters=heads["reg"], kernel_size=(1, 1), strides=1, padding="same") 224 | self.wh_layer = tf.keras.layers.Conv2D(filters=heads["wh"], kernel_size=(1, 1), strides=1, padding="same") 225 | 226 | @staticmethod 227 | def __make_bottleneck_layer(filter_num, blocks, stride=1): 228 | res_block = tf.keras.Sequential() 229 | res_block.add(BottleNeck(filter_num, stride=stride)) 230 | 231 | for _ in range(1, blocks): 232 | res_block.add(BottleNeck(filter_num, stride=1)) 233 | 234 | return res_block 235 | 236 | @staticmethod 237 | def __make_transposed_conv_layer(num_layers, num_filters, num_kernels): 238 | layers = tf.keras.Sequential() 239 | for i in range(num_layers): 240 | layers.add(tf.keras.layers.Conv2DTranspose(filters=num_filters[i], 241 | kernel_size=num_kernels[i], 242 | strides=2, 243 | padding="same", 244 | use_bias=False)) 245 | layers.add(tf.keras.layers.BatchNormalization()) 246 | layers.add(tf.keras.layers.ReLU()) 247 | return layers 248 | 249 | def call(self, inputs, training=None, **kwargs): 250 | x = self.conv1(inputs) 251 | x = self.bn1(x, training=training) 252 | x = tf.nn.relu(x) 253 | x = self.pool1(x) 254 | x = self.layer1(x, training=training) 255 | x = self.layer2(x, training=training) 256 | x = self.layer3(x, training=training) 257 | x = self.layer4(x, training=training) 258 | 259 | x = self.transposed_conv_layers(x, training=training) 260 | heatmap = self.heatmap_layer(x, training=training) 261 | reg = self.reg_layer(x, training=training) 262 | wh = self.wh_layer(x, training=training) 263 | 264 | return [heatmap, reg, wh] 265 | 266 | 267 | def resnet_18(): 268 | return ResNetTypeI(layer_params=[2, 2, 2, 2], heads=Config.heads, head_conv=Config.head_conv["resnets"]) 269 | 270 | 271 | def resnet_34(): 272 | return ResNetTypeI(layer_params=[3, 4, 6, 3], heads=Config.heads, head_conv=Config.head_conv["resnets"]) 273 | 274 | 275 | def resnet_50(): 276 | return ResNetTypeII(layer_params=[3, 4, 6, 3], heads=Config.heads, head_conv=Config.head_conv["resnets"]) 277 | 278 | 279 | def resnet_101(): 280 | return ResNetTypeII(layer_params=[3, 4, 23, 3], heads=Config.heads, head_conv=Config.head_conv["resnets"]) 281 | 282 | 283 | def resnet_152(): 284 | return ResNetTypeII(layer_params=[3, 8, 36, 3], heads=Config.heads, head_conv=Config.head_conv["resnets"]) 285 | 
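286 | 
287 | # A shape sanity check added as a sketch (not used by the training pipeline): the
288 | # backbone downsamples by a factor of 32 and the three stride-2 transposed
289 | # convolutions upsample by 8, so a 384 x 384 input yields 96 x 96 output maps.
290 | if __name__ == "__main__":
291 |     model = resnet_18()
292 |     outputs = model(tf.random.normal(shape=(1, 384, 384, 3)), training=False)
293 |     for name, tensor in zip(["heatmap", "reg", "wh"], outputs):
294 |         print(name, tensor.shape)  # (1, 96, 96, num_classes), (1, 96, 96, 2), (1, 96, 96, 2)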
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calmiLovesAI/CenterNet_TensorFlow2/a30c9b4243c7c5f45a0cf47df655c171dcfef256/data/__init__.py
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | from configuration import Config
5 | from utils.gaussian import gaussian_radius, draw_umich_gaussian
6 | 
7 | 
8 | class DetectionDataset:
9 |     def __init__(self):
10 |         self.txt_file = Config.txt_file_dir
11 |         self.batch_size = Config.batch_size
12 | 
13 |     @staticmethod
14 |     def __get_length_of_dataset(dataset):
15 |         length = 0  # TextLineDataset does not know its own length, so count the records once
16 |         for _ in dataset:
17 |             length += 1
18 |         return length
19 | 
20 |     def generate_dataset(self):
21 |         dataset = tf.data.TextLineDataset(filenames=self.txt_file)
22 |         length_of_dataset = DetectionDataset.__get_length_of_dataset(dataset)
23 |         train_dataset = dataset.batch(batch_size=self.batch_size)
24 |         return train_dataset, length_of_dataset
25 | 
26 | 
27 | class DataLoader:
28 | 
29 |     input_image_height = Config.get_image_size()[0]
30 |     input_image_width = Config.get_image_size()[1]
31 |     input_image_channels = Config.image_channels
32 | 
33 |     def __init__(self):
34 |         self.max_boxes_per_image = Config.max_boxes_per_image
35 | 
36 |     def read_batch_data(self, batch_data):
37 |         batch_size = batch_data.shape[0]
38 |         image_file_list = []
39 |         boxes_list = []
40 |         for n in range(batch_size):
41 |             image_file, boxes = self.__get_image_information(single_line=batch_data[n])
42 |             image_file_list.append(image_file)
43 |             boxes_list.append(boxes)
44 |         boxes = np.stack(boxes_list, axis=0)
45 |         image_tensor_list = []
46 |         for image in image_file_list:
47 |             image_tensor = DataLoader.image_preprocess(is_training=True, image_dir=image)
48 |             image_tensor_list.append(image_tensor)
49 |         images = tf.stack(values=image_tensor_list, axis=0)
50 |         return images, boxes
51 | 
52 |     def __get_image_information(self, single_line):
53 |         """
54 |         :param single_line: tensor, one line of data.txt
55 |         :return:
56 |         image_file: string, image file dir
57 |         boxes_array: numpy array, shape = (max_boxes_per_image, 5(xmin, ymin, xmax, ymax, class_id))
58 |         """
59 |         line_string = bytes.decode(single_line.numpy(), encoding="utf-8")
60 |         line_list = line_string.strip().split(" ")
61 |         image_file, image_height, image_width = line_list[:3]
62 |         image_height, image_width = int(float(image_height)), int(float(image_width))
63 |         boxes = []
64 |         num_of_boxes = (len(line_list) - 3) / 5
65 |         if int(num_of_boxes) == num_of_boxes:
66 |             num_of_boxes = int(num_of_boxes)
67 |         else:
68 |             raise ValueError("Every box must be described by exactly 5 values: xmin, ymin, xmax, ymax, class_id.")
69 |         for index in range(num_of_boxes):
70 |             if index < self.max_boxes_per_image:  # boxes beyond max_boxes_per_image are dropped
71 |                 xmin = int(float(line_list[3 + index * 5]))
72 |                 ymin = int(float(line_list[3 + index * 5 + 1]))
73 |                 xmax = int(float(line_list[3 + index * 5 + 2]))
74 |                 ymax = int(float(line_list[3 + index * 5 + 3]))
75 |                 class_id = int(line_list[3 + index * 5 + 4])
76 |                 xmin, ymin, xmax, ymax = DataLoader.box_preprocess(image_height, image_width, xmin, ymin, xmax, ymax)
77 |                 boxes.append([xmin, ymin, xmax, ymax, class_id])
78 |         num_padding_boxes = self.max_boxes_per_image - num_of_boxes
79 |         if num_padding_boxes > 0:
80 |             for i in range(num_padding_boxes):
81 |                 boxes.append([0, 0, 0, 0, -1])  # padding boxes carry class_id -1 and are filtered out later
82 |
boxes_array = np.array(boxes, dtype=np.float32) 83 | return image_file, boxes_array 84 | 85 | @classmethod 86 | def box_preprocess(cls, h, w, xmin, ymin, xmax, ymax): 87 | resize_ratio = [DataLoader.input_image_height / h, DataLoader.input_image_width / w] 88 | xmin = int(resize_ratio[1] * xmin) 89 | xmax = int(resize_ratio[1] * xmax) 90 | ymin = int(resize_ratio[0] * ymin) 91 | ymax = int(resize_ratio[0] * ymax) 92 | return xmin, ymin, xmax, ymax 93 | 94 | @classmethod 95 | def image_preprocess(cls, is_training, image_dir): 96 | image_raw = tf.io.read_file(filename=image_dir) 97 | decoded_image = tf.io.decode_image(contents=image_raw, channels=DataLoader.input_image_channels, dtype=tf.dtypes.float32) 98 | decoded_image = tf.image.resize(images=decoded_image, size=(DataLoader.input_image_height, DataLoader.input_image_width)) 99 | return decoded_image 100 | 101 | 102 | class GT: 103 | def __init__(self, batch_labels): 104 | self.downsampling_ratio = Config.downsampling_ratio 105 | self.features_shape = np.array(Config.get_image_size(), dtype=np.int32) // self.downsampling_ratio 106 | self.batch_labels = batch_labels 107 | self.batch_size = batch_labels.shape[0] 108 | 109 | def get_gt_values(self): 110 | gt_heatmap = np.zeros(shape=(self.batch_size, self.features_shape[0], self.features_shape[1], Config.num_classes), dtype=np.float32) 111 | gt_reg = np.zeros(shape=(self.batch_size, Config.max_boxes_per_image, 2), dtype=np.float32) 112 | gt_wh = np.zeros(shape=(self.batch_size, Config.max_boxes_per_image, 2), dtype=np.float32) 113 | gt_reg_mask = np.zeros(shape=(self.batch_size, Config.max_boxes_per_image), dtype=np.float32) 114 | gt_indices = np.zeros(shape=(self.batch_size, Config.max_boxes_per_image), dtype=np.float32) 115 | for i, label in enumerate(self.batch_labels): 116 | label = label[label[:, 4] != -1] 117 | hm, reg, wh, reg_mask, ind = self.__decode_label(label) 118 | gt_heatmap[i, :, :, :] = hm 119 | gt_reg[i, :, :] = reg 120 | gt_wh[i, :, :] = wh 121 | gt_reg_mask[i, :] = reg_mask 122 | gt_indices[i, :] = ind 123 | return gt_heatmap, gt_reg, gt_wh, gt_reg_mask, gt_indices 124 | 125 | def __decode_label(self, label): 126 | hm = np.zeros(shape=(self.features_shape[0], self.features_shape[1], Config.num_classes), dtype=np.float32) 127 | reg = np.zeros(shape=(Config.max_boxes_per_image, 2), dtype=np.float32) 128 | wh = np.zeros(shape=(Config.max_boxes_per_image, 2), dtype=np.float32) 129 | reg_mask = np.zeros(shape=(Config.max_boxes_per_image), dtype=np.float32) 130 | ind = np.zeros(shape=(Config.max_boxes_per_image), dtype=np.float32) 131 | for j, item in enumerate(label): 132 | item[:4] = item[:4] / self.downsampling_ratio 133 | xmin, ymin, xmax, ymax, class_id = item 134 | class_id = class_id.astype(np.int32) 135 | h, w = int(ymax - ymin), int(xmax - xmin) 136 | radius = gaussian_radius((h, w)) 137 | radius = max(0, int(radius)) 138 | ctr_x, ctr_y = (xmin + xmax) / 2, (ymin + ymax) / 2 139 | center_point = np.array([ctr_x, ctr_y], dtype=np.float32) 140 | center_point_int = center_point.astype(np.int32) 141 | draw_umich_gaussian(hm[:, :, class_id], center_point_int, radius) 142 | reg[j] = center_point - center_point_int 143 | wh[j] = 1. * w, 1. 
* h 144 | reg_mask[j] = 1 145 | ind[j] = center_point_int[1] * self.features_shape[1] + center_point_int[0] 146 | return hm, reg, wh, reg_mask, ind 147 | -------------------------------------------------------------------------------- /data/datasets/README.md: -------------------------------------------------------------------------------- 1 | Put your datasets here. -------------------------------------------------------------------------------- /data/voc.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | 3 | from pathlib import Path 4 | from configuration import Config 5 | 6 | 7 | class VOC: 8 | def __init__(self): 9 | self.annotations_dir = Config.pascal_voc_labels 10 | self.images_dir = Config.pascal_voc_images 11 | self.label_names = VOC.get_filenames(self.annotations_dir, "*.xml") 12 | 13 | @staticmethod 14 | def get_filenames(root_dir, pattern): 15 | p = Path(root_dir) 16 | filenames = [x for x in p.glob(pattern)] 17 | return filenames 18 | 19 | def __len__(self): 20 | return len(self.label_names) 21 | 22 | def __getitem__(self, item): 23 | label_file = str(self.label_names[item]) 24 | tree = ET.parse(label_file) 25 | image_name = tree.find("filename").text 26 | image_width = float(tree.find("size").find("width").text) 27 | image_height = float(tree.find("size").find("height").text) 28 | objects = tree.findall("object") 29 | class_ids = [] 30 | bboxes = [] 31 | for i, obj in enumerate(objects): 32 | class_id = Config.pascal_voc_classes[obj.find("name").text] 33 | bbox = obj.find("bndbox") 34 | xmin = float(bbox.find("xmin").text) 35 | ymin = float(bbox.find("ymin").text) 36 | xmax = float(bbox.find("xmax").text) 37 | ymax = float(bbox.find("ymax").text) 38 | class_ids.append(class_id) 39 | bboxes.append([xmin, ymin, xmax, ymax]) 40 | sample = { 41 | "image_file_dir": self.images_dir + "/" + image_name, 42 | "image_height": image_height, 43 | "image_width": image_width, 44 | "class_ids": class_ids, 45 | "bboxes": bboxes 46 | } 47 | return sample 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /saved_model/README.md: -------------------------------------------------------------------------------- 1 | The model will be saved here. 
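2 | 
3 | Weights are written as TensorFlow checkpoints: `epoch-{n}` every `Config.save_frequency`
4 | epochs during training, plus a final `saved_model` checkpoint. A minimal loading sketch
5 | (mirroring what `test.py` does):
6 | 
7 | ```python
8 | from configuration import Config
9 | from core.centernet import CenterNet
10 | 
11 | centernet = CenterNet()
12 | centernet.load_weights(filepath=Config.save_model_dir + "saved_model")
13 | ```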
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import cv2
3 | import numpy as np
4 | 
5 | from configuration import Config
6 | from core.centernet import CenterNet, PostProcessing
7 | from data.dataloader import DataLoader
8 | 
9 | 
10 | def idx2class():
11 |     # Invert the {class_name: class_id} mapping for readable labels.
12 |     return dict((v, k) for k, v in Config.pascal_voc_classes.items())
13 | 
14 | 
15 | def draw_boxes_on_image(image, boxes, scores, classes):
16 |     idx2class_dict = idx2class()
17 |     num_boxes = boxes.shape[0]
18 |     for i in range(num_boxes):
19 |         class_and_score = "{}: {:.3f}".format(str(idx2class_dict[classes[i]]), scores[i])
20 |         cv2.rectangle(img=image, pt1=(boxes[i, 0], boxes[i, 1]), pt2=(boxes[i, 2], boxes[i, 3]), color=(250, 206, 135), thickness=2)
21 |         # Draw a filled background behind the label text so that it stays readable.
22 |         text_size = cv2.getTextSize(text=class_and_score, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, thickness=1)
23 |         text_width, text_height = text_size[0][0], text_size[0][1]
24 |         cv2.rectangle(img=image, pt1=(boxes[i, 0], boxes[i, 1]), pt2=(boxes[i, 0] + text_width, boxes[i, 1] - text_height), color=(203, 192, 255), thickness=-1)
25 |         cv2.putText(img=image, text=class_and_score, org=(boxes[i, 0], boxes[i, 1] - 2), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, color=(0, 0, 0), thickness=1)
26 |     return image
27 | 
28 | 
29 | def test_single_picture(picture_dir, model):
30 |     # Keep the original image for drawing; the network itself runs on the resized tensor.
31 |     image_array = cv2.imread(picture_dir)
32 |     image = DataLoader.image_preprocess(is_training=False, image_dir=picture_dir)
33 |     image = tf.expand_dims(input=image, axis=0)
34 | 
35 |     outputs = model(image, training=False)
36 |     post_process = PostProcessing()
37 |     boxes, scores, classes = post_process.testing_procedure(outputs, [image_array.shape[0], image_array.shape[1]])
38 |     image_with_boxes = draw_boxes_on_image(image_array, boxes.astype(np.int32), scores, classes)
39 |     return image_with_boxes
40 | 
41 | 
42 | if __name__ == '__main__':
43 |     # GPU settings
44 |     gpus = tf.config.list_physical_devices("GPU")
45 |     if gpus:
46 |         for gpu in gpus:
47 |             tf.config.experimental.set_memory_growth(gpu, True)
48 | 
49 |     centernet = CenterNet()
50 |     centernet.load_weights(filepath=Config.save_model_dir + "saved_model")
51 | 
52 |     image = test_single_picture(picture_dir=Config.test_single_image_dir, model=centernet)
53 | 
54 |     cv2.namedWindow("detect result", flags=cv2.WINDOW_NORMAL)
55 |     cv2.imshow("detect result", image)
56 |     cv2.waitKey(0)
--------------------------------------------------------------------------------
/test_pictures/README.md:
--------------------------------------------------------------------------------
1 | Put the test pictures here.
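2 | 
3 | If `Config.test_images_during_training` is enabled, detection results for the images in
4 | `Config.test_images_dir_list` are also written to this folder as `epoch-{e}-picture-{i}.jpg`.
5 | A short sketch for testing one picture from this folder (`example.jpg` is a placeholder
6 | file name):
7 | 
8 | ```python
9 | import cv2
10 | 
11 | from configuration import Config
12 | from core.centernet import CenterNet
13 | from test import test_single_picture
14 | 
15 | centernet = CenterNet()
16 | centernet.load_weights(filepath=Config.save_model_dir + "saved_model")
17 | result = test_single_picture(picture_dir="./test_pictures/example.jpg", model=centernet)
18 | cv2.imwrite(filename="./test_pictures/example_result.jpg", img=result)
19 | ```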
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import time
3 | 
4 | from core.centernet import PostProcessing, CenterNet
5 | from data.dataloader import DetectionDataset, DataLoader
6 | from configuration import Config
7 | from utils.visualize import visualize_training_results
8 | 
9 | 
10 | def print_model_summary(network):
11 |     sample_inputs = tf.random.normal(shape=(Config.batch_size, Config.get_image_size()[0], Config.get_image_size()[1], Config.image_channels))
12 |     # A forward pass builds the network so that summary() can report the shapes.
13 |     sample_outputs = network(sample_inputs, training=True)
14 |     network.summary()
15 | 
16 | 
17 | if __name__ == '__main__':
18 |     # GPU settings
19 |     gpus = tf.config.list_physical_devices("GPU")
20 |     if gpus:
21 |         for gpu in gpus:
22 |             tf.config.experimental.set_memory_growth(gpu, True)
23 | 
24 |     # dataset
25 |     train_dataset = DetectionDataset()
26 |     train_data, train_size = train_dataset.generate_dataset()
27 |     data_loader = DataLoader()
28 |     steps_per_epoch = int(tf.math.ceil(train_size / Config.batch_size))
29 | 
30 |     # model
31 |     centernet = CenterNet()
32 |     print_model_summary(centernet)
33 |     load_weights_from_epoch = Config.load_weights_from_epoch
34 |     if Config.load_weights_before_training:
35 |         centernet.load_weights(filepath=Config.save_model_dir + "epoch-{}".format(load_weights_from_epoch))
36 |         print("Successfully loaded weights!")
37 |     else:
38 |         load_weights_from_epoch = -1
39 | 
40 |     # optimizer
41 |     lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=1e-4,
42 |                                                                  decay_steps=steps_per_epoch * Config.learning_rate_decay_epochs,
43 |                                                                  decay_rate=0.96)
44 |     optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
45 | 
46 |     # metrics
47 |     loss_metric = tf.metrics.Mean()
48 | 
49 |     post_process = PostProcessing()
50 | 
51 |     def train_step(batch_images, batch_labels):
52 |         # One optimization step: forward pass, loss, backward pass, weight update.
53 |         with tf.GradientTape() as tape:
54 |             pred = centernet(batch_images, training=True)
55 |             loss_value = post_process.training_procedure(batch_labels=batch_labels, pred=pred)
56 |         gradients = tape.gradient(target=loss_value, sources=centernet.trainable_variables)
57 |         optimizer.apply_gradients(grads_and_vars=zip(gradients, centernet.trainable_variables))
58 |         loss_metric.update_state(values=loss_value)
59 | 
60 |     for epoch in range(load_weights_from_epoch + 1, Config.epochs):
61 |         for step, batch_data in enumerate(train_data):
62 |             step_start_time = time.time()
63 |             images, labels = data_loader.read_batch_data(batch_data)
64 |             train_step(images, labels)
65 |             step_end_time = time.time()
66 |             print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, time_cost: {:.3f}s".format(epoch,
67 |                                                                                        Config.epochs,
68 |                                                                                        step,
69 |                                                                                        steps_per_epoch,
70 |                                                                                        loss_metric.result().numpy(),
71 |                                                                                        step_end_time - step_start_time))
72 |         loss_metric.reset_states()
73 | 
74 |         if epoch % Config.save_frequency == 0:
75 |             centernet.save_weights(filepath=Config.save_model_dir + "epoch-{}".format(epoch), save_format="tf")
76 | 
77 |         if Config.test_images_during_training:
78 |             visualize_training_results(pictures=Config.test_images_dir_list, model=centernet, epoch=epoch)
79 | 
80 |     centernet.save_weights(filepath=Config.save_model_dir + "saved_model", save_format="tf")
81 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calmiLovesAI/CenterNet_TensorFlow2/a30c9b4243c7c5f45a0cf47df655c171dcfef256/utils/__init__.py -------------------------------------------------------------------------------- /utils/gaussian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def gaussian_radius(det_size, min_overlap=0.7): 5 | height, width = det_size 6 | 7 | a1 = 1 8 | b1 = (height + width) 9 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 10 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 11 | r1 = (b1 + sq1) / 2 12 | 13 | a2 = 4 14 | b2 = 2 * (height + width) 15 | c2 = (1 - min_overlap) * width * height 16 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 17 | r2 = (b2 + sq2) / 2 18 | 19 | a3 = 4 * min_overlap 20 | b3 = -2 * min_overlap * (height + width) 21 | c3 = (min_overlap - 1) * width * height 22 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 23 | r3 = (b3 + sq3) / 2 24 | return min(r1, r2, r3) 25 | 26 | 27 | def gaussian2D(shape, sigma=1): 28 | m, n = [(ss - 1.) / 2. for ss in shape] 29 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 30 | 31 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 32 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 33 | return h 34 | 35 | 36 | def draw_umich_gaussian(heatmap, center, radius, k=1): 37 | diameter = 2 * radius + 1 38 | gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) 39 | 40 | x, y = int(center[0]), int(center[1]) 41 | 42 | height, width = heatmap.shape[0:2] 43 | 44 | left, right = min(x, radius), min(width - x, radius + 1) 45 | top, bottom = min(y, radius), min(height - y, radius + 1) 46 | 47 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 48 | masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] 49 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: 50 | np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) 51 | return heatmap 52 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from test import test_single_picture 3 | from configuration import Config 4 | 5 | 6 | def visualize_training_results(pictures, model, epoch): 7 | """ 8 | :param pictures: List of image directories. 
9 |     :param model: the CenterNet model used for inference
10 |     :param epoch: current epoch number, used in the names of the result files
11 |     :return: None
12 |     """
13 |     index = 0
14 |     for picture in pictures:
15 |         index += 1
16 |         result = test_single_picture(picture_dir=picture, model=model)
17 |         cv2.imwrite(filename=Config.training_results_save_dir + "epoch-{}-picture-{}.jpg".format(epoch, index), img=result)
18 | 
--------------------------------------------------------------------------------
/write_to_txt.py:
--------------------------------------------------------------------------------
1 | from data.voc import VOC
2 | from configuration import Config
3 | 
4 | 
5 | if __name__ == '__main__':
6 |     voc_dataset = VOC()
7 |     # Open in mode "w" so that rerunning the script regenerates data.txt instead of
8 |     # appending duplicate lines to an existing file.
9 |     with open(file=Config.txt_file_dir, mode="w", encoding="utf-8") as f:
10 |         # Each line of data.txt: image_file_dir image_height image_width
11 |         # followed by "xmin ymin xmax ymax class_id" repeated for every box.
12 |         for i, sample in enumerate(voc_dataset):
13 |             num_bboxes = len(sample["bboxes"])
14 |             line_text = sample["image_file_dir"] + " " + str(sample["image_height"]) + " " + str(sample["image_width"]) + " "
15 |             for j in range(num_bboxes):
16 |                 bbox = list(map(str, sample["bboxes"][j]))
17 |                 cls = str(sample["class_ids"][j])
18 |                 bbox.append(cls)
19 |                 line_text += " ".join(bbox)
20 |                 line_text += " "
21 |             line_text = line_text.strip()
22 |             line_text += "\n"
23 |             print("Writing information of picture {} to {}".format(sample["image_file_dir"], Config.txt_file_dir))
24 |             f.write(line_text)
--------------------------------------------------------------------------------