├── .gitignore ├── LICENSE ├── README.md ├── argument.py ├── backbone ├── __init__.py └── vovnet.py ├── boxlist.py ├── checkpoint └── .gitignore ├── coco_meta.py ├── config ├── dataset.py ├── distributed.py ├── evaluate.py ├── loss.py ├── model.py ├── postprocess.py ├── train.py ├── transform.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | *.pth 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fcos-pytorch 2 | Re-implementation of FCOS: Fully Convolutional One-Stage Object Detection (https://arxiv.org/abs/1904.01355) 3 | 4 | I have implemented FCOS to study object detection. Most of the code came from: 5 | 6 | * https://github.com/tianzhi0549/FCOS 7 | * https://github.com/yqyao/FCOS_PLUS 8 | * https://github.com/vov-net/VoVNet-Classification 9 | 10 | I think FCOS is very interesting approach to object detection! -------------------------------------------------------------------------------- /argument.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_argparser(): 5 | parser = argparse.ArgumentParser() 6 | 7 | parser.add_argument('--local_rank', type=int, default=0) 8 | parser.add_argument('--lr', type=float, default=0.01) 9 | parser.add_argument('--l2', type=float, default=0.0001) 10 | parser.add_argument('--batch', type=int, default=16) 11 | parser.add_argument('--epoch', type=int, default=24) 12 | parser.add_argument('--n_save_sample', type=int, default=5) 13 | parser.add_argument('--ckpt', type=str) 14 | parser.add_argument('path', type=str) 15 | 16 | return parser 17 | 18 | 19 | def get_args(): 20 | parser = get_argparser() 21 | args = parser.parse_args() 22 | 23 | args.feat_channels = [0, 0, 512, 768, 1024] 24 | args.out_channel = 256 25 | args.use_p5 = True 26 | args.n_class = 81 27 | args.n_conv = 4 28 | args.prior = 0.01 29 | args.threshold = 0.05 30 | args.top_n = 1000 31 | args.nms_threshold = 0.6 32 | args.post_top_n = 100 33 | args.min_size = 0 34 | args.fpn_strides = [8, 16, 32, 64, 128] 35 | args.gamma = 2.0 36 | args.alpha = 0.25 37 | args.sizes = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, 100000000]] 38 | args.train_min_size_range = (640, 800) 39 | args.train_max_size = 1333 40 | args.test_min_size = 800 41 | args.test_max_size = 1333 42 | args.pixel_mean = [0.40789654, 0.44719302, 0.47026115] 43 | args.pixel_std = [0.28863828, 0.27408164, 0.27809835] 44 | args.size_divisible = 32 45 | args.center_sample = True 46 | args.pos_radius = 1.5 47 | args.iou_loss_type = 'giou' 48 | 49 | return args 50 | -------------------------------------------------------------------------------- /backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .vovnet import vovnet27_slim, vovnet39, vovnet57 2 | -------------------------------------------------------------------------------- /backbone/vovnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | 7 | __all__ = ['VoVNet', 'vovnet27_slim', 'vovnet39', 'vovnet57'] 8 | 9 | 10 | model_urls = { 11 | 'vovnet39': './vovnet39_torchvision.pth', 12 | 'vovnet57': './vovnet57_torchvision.pth', 13 | } 14 | 15 | 16 | def conv3x3( 17 | in_channels, 18 | out_channels, 19 | module_name, 20 | postfix, 21 | stride=1, 22 | groups=1, 23 | kernel_size=3, 24 | padding=1, 25 | ): 26 | """3x3 convolution with padding""" 27 | return [ 28 | ( 29 | '{}_{}/conv'.format(module_name, postfix), 30 | nn.Conv2d( 31 | in_channels, 32 | out_channels, 33 | kernel_size=kernel_size, 34 | stride=stride, 35 | padding=padding, 36 | groups=groups, 37 | bias=False, 38 | ), 39 | ), 40 | 
('{}_{}/norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)), 41 | ('{}_{}/relu'.format(module_name, postfix), nn.ReLU(inplace=True)), 42 | ] 43 | 44 | 45 | def conv1x1( 46 | in_channels, 47 | out_channels, 48 | module_name, 49 | postfix, 50 | stride=1, 51 | groups=1, 52 | kernel_size=1, 53 | padding=0, 54 | ): 55 | """1x1 convolution""" 56 | return [ 57 | ( 58 | '{}_{}/conv'.format(module_name, postfix), 59 | nn.Conv2d( 60 | in_channels, 61 | out_channels, 62 | kernel_size=kernel_size, 63 | stride=stride, 64 | padding=padding, 65 | groups=groups, 66 | bias=False, 67 | ), 68 | ), 69 | ('{}_{}/norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)), 70 | ('{}_{}/relu'.format(module_name, postfix), nn.ReLU(inplace=True)), 71 | ] 72 | 73 | 74 | class _OSA_module(nn.Module): 75 | def __init__( 76 | self, in_ch, stage_ch, concat_ch, layer_per_block, module_name, identity=False 77 | ): 78 | super(_OSA_module, self).__init__() 79 | 80 | self.identity = identity 81 | self.layers = nn.ModuleList() 82 | in_channel = in_ch 83 | for i in range(layer_per_block): 84 | self.layers.append( 85 | nn.Sequential( 86 | OrderedDict(conv3x3(in_channel, stage_ch, module_name, i)) 87 | ) 88 | ) 89 | in_channel = stage_ch 90 | 91 | # feature aggregation 92 | in_channel = in_ch + layer_per_block * stage_ch 93 | self.concat = nn.Sequential( 94 | OrderedDict(conv1x1(in_channel, concat_ch, module_name, 'concat')) 95 | ) 96 | 97 | def forward(self, x): 98 | identity_feat = x 99 | output = [] 100 | output.append(x) 101 | for layer in self.layers: 102 | x = layer(x) 103 | output.append(x) 104 | 105 | x = torch.cat(output, dim=1) 106 | xt = self.concat(x) 107 | 108 | if self.identity: 109 | xt = xt + identity_feat 110 | 111 | return xt 112 | 113 | 114 | class _OSA_stage(nn.Sequential): 115 | def __init__( 116 | self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block, stage_num 117 | ): 118 | super(_OSA_stage, self).__init__() 119 | 120 | if not stage_num == 2: 121 | self.add_module( 122 | 'Pooling', nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) 123 | ) 124 | 125 | module_name = f'OSA{stage_num}_1' 126 | self.add_module( 127 | module_name, 128 | _OSA_module(in_ch, stage_ch, concat_ch, layer_per_block, module_name), 129 | ) 130 | for i in range(block_per_stage - 1): 131 | module_name = f'OSA{stage_num}_{i+2}' 132 | self.add_module( 133 | module_name, 134 | _OSA_module( 135 | concat_ch, 136 | stage_ch, 137 | concat_ch, 138 | layer_per_block, 139 | module_name, 140 | identity=True, 141 | ), 142 | ) 143 | 144 | 145 | class VoVNet(nn.Module): 146 | def __init__( 147 | self, 148 | config_stage_ch, 149 | config_concat_ch, 150 | block_per_stage, 151 | layer_per_block, 152 | num_classes=1000, 153 | ): 154 | super(VoVNet, self).__init__() 155 | 156 | # Stem module 157 | stem = conv3x3(3, 64, 'stem', '1', 2) 158 | stem += conv3x3(64, 64, 'stem', '2', 1) 159 | stem += conv3x3(64, 128, 'stem', '3', 2) 160 | self.add_module('stem', nn.Sequential(OrderedDict(stem))) 161 | 162 | stem_out_ch = [128] 163 | in_ch_list = stem_out_ch + config_concat_ch[:-1] 164 | self.stage_names = [] 165 | for i in range(4): # num_stages 166 | name = 'stage%d' % (i + 2) 167 | self.stage_names.append(name) 168 | self.add_module( 169 | name, 170 | _OSA_stage( 171 | in_ch_list[i], 172 | config_stage_ch[i], 173 | config_concat_ch[i], 174 | block_per_stage[i], 175 | layer_per_block, 176 | i + 2, 177 | ), 178 | ) 179 | 180 | # self.classifier = nn.Linear(config_concat_ch[-1], num_classes) 181 | 182 | for m in self.modules(): 
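# default weight init: Kaiming-normal for Conv2d weights; weight=1 / bias=0 for BatchNorm2d and GroupNorm; bias=0 for Linear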
183 | if isinstance(m, nn.Conv2d): 184 | nn.init.kaiming_normal_(m.weight) 185 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 186 | nn.init.constant_(m.weight, 1) 187 | nn.init.constant_(m.bias, 0) 188 | elif isinstance(m, nn.Linear): 189 | nn.init.constant_(m.bias, 0) 190 | 191 | def forward(self, x): 192 | features = [] 193 | x = self.stem[:6](x) 194 | features.append(x) 195 | x = self.stem[6:](x) 196 | for name in self.stage_names: 197 | x = getattr(self, name)(x) 198 | features.append(x) 199 | # x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1) 200 | # x = self.classifier(x) 201 | return features 202 | 203 | 204 | def _vovnet( 205 | arch, 206 | config_stage_ch, 207 | config_concat_ch, 208 | block_per_stage, 209 | layer_per_block, 210 | pretrained, 211 | progress, 212 | **kwargs, 213 | ): 214 | model = VoVNet( 215 | config_stage_ch, config_concat_ch, block_per_stage, layer_per_block, **kwargs 216 | ) 217 | if pretrained: 218 | state_dict = torch.load(model_urls[arch]) 219 | new_dict = OrderedDict() 220 | 221 | for k, v in state_dict.items(): 222 | key = k.replace('module.', '') 223 | new_dict[key] = v 224 | 225 | model.load_state_dict(new_dict, strict=False) 226 | return model 227 | 228 | 229 | def vovnet57(pretrained=False, progress=True, **kwargs): 230 | r"""Constructs a VoVNet-57 model as described in 231 | `"An Energy and GPU-Computation Efficient Backbone Networks" 232 | `_. 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | progress (bool): If True, displays a progress bar of the download to stderr 236 | """ 237 | return _vovnet( 238 | 'vovnet57', 239 | [128, 160, 192, 224], 240 | [256, 512, 768, 1024], 241 | [1, 1, 4, 3], 242 | 5, 243 | pretrained, 244 | progress, 245 | **kwargs, 246 | ) 247 | 248 | 249 | def vovnet39(pretrained=False, progress=True, **kwargs): 250 | r"""Constructs a VoVNet-39 model as described in 251 | `"An Energy and GPU-Computation Efficient Backbone Networks" 252 | `_. 253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | progress (bool): If True, displays a progress bar of the download to stderr 256 | """ 257 | return _vovnet( 258 | 'vovnet39', 259 | [128, 160, 192, 224], 260 | [256, 512, 768, 1024], 261 | [1, 1, 2, 2], 262 | 5, 263 | pretrained, 264 | progress, 265 | **kwargs, 266 | ) 267 | 268 | 269 | def vovnet27_slim(pretrained=False, progress=True, **kwargs): 270 | r"""Constructs a VoVNet-39 model as described in 271 | `"An Energy and GPU-Computation Efficient Backbone Networks" 272 | `_. 
273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _vovnet( 278 | 'vovnet27_slim', 279 | [64, 80, 96, 112], 280 | [128, 256, 384, 512], 281 | [1, 1, 1, 1], 282 | 5, 283 | pretrained, 284 | progress, 285 | **kwargs, 286 | ) 287 | 288 | -------------------------------------------------------------------------------- /boxlist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import ops 3 | 4 | 5 | FLIP_LEFT_RIGHT = 0 6 | FLIP_TOP_BOTTOM = 1 7 | 8 | 9 | class BoxList: 10 | def __init__(self, box, image_size, mode='xyxy'): 11 | device = box.device if hasattr(box, 'device') else 'cpu' 12 | box = torch.as_tensor(box, dtype=torch.float32, device=device) 13 | 14 | self.box = box 15 | self.size = image_size 16 | self.mode = mode 17 | 18 | self.fields = {} 19 | 20 | def convert(self, mode): 21 | if mode == self.mode: 22 | return self 23 | 24 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 25 | 26 | if mode == 'xyxy': 27 | box = torch.cat([x_min, y_min, x_max, y_max], -1) 28 | box = BoxList(box, self.size, mode=mode) 29 | 30 | elif mode == 'xywh': 31 | remove = 1 32 | box = torch.cat( 33 | [x_min, y_min, x_max - x_min + remove, y_max - y_min + remove], -1 34 | ) 35 | box = BoxList(box, self.size, mode=mode) 36 | 37 | box.copy_field(self) 38 | 39 | return box 40 | 41 | def copy_field(self, box): 42 | for k, v in box.fields.items(): 43 | self.fields[k] = v 44 | 45 | def area(self): 46 | box = self.box 47 | 48 | if self.mode == 'xyxy': 49 | remove = 1 50 | 51 | area = (box[:, 2] - box[:, 0] + remove) * (box[:, 3] - box[:, 1] + remove) 52 | 53 | elif self.mode == 'xywh': 54 | area = box[:, 2] * box[:, 3] 55 | 56 | return area 57 | 58 | def split_to_xyxy(self): 59 | if self.mode == 'xyxy': 60 | x_min, y_min, x_max, y_max = self.box.split(1, dim=-1) 61 | 62 | return x_min, y_min, x_max, y_max 63 | 64 | elif self.mode == 'xywh': 65 | remove = 1 66 | x_min, y_min, w, h = self.box.split(1, dim=-1) 67 | 68 | return ( 69 | x_min, 70 | y_min, 71 | x_min + (w - remove).clamp(min=0), 72 | y_min + (h - remove).clamp(min=0), 73 | ) 74 | 75 | def __len__(self): 76 | return self.box.shape[0] 77 | 78 | def __getitem__(self, index): 79 | box = BoxList(self.box[index], self.size, self.mode) 80 | 81 | for k, v in self.fields.items(): 82 | box.fields[k] = v[index] 83 | 84 | return box 85 | 86 | def resize(self, size, *args, **kwargs): 87 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) 88 | 89 | if ratios[0] == ratios[1]: 90 | ratio = ratios[0] 91 | scaled = self.box * ratio 92 | box = BoxList(scaled, size, mode=self.mode) 93 | 94 | for k, v in self.fields.items(): 95 | if not isinstance(v, torch.Tensor): 96 | v = v.resize(size, *args, **kwargs) 97 | 98 | box.fields[k] = v 99 | 100 | return box 101 | 102 | ratio_w, ratio_h = ratios 103 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 104 | scaled_x_min = x_min * ratio_w 105 | scaled_x_max = x_max * ratio_w 106 | scaled_y_min = y_min * ratio_h 107 | scaled_y_max = y_max * ratio_h 108 | scaled = torch.cat([scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max], -1) 109 | box = BoxList(scaled, size, mode='xyxy') 110 | 111 | for k, v in self.fields.items(): 112 | if not isinstance(v, torch.Tensor): 113 | v = v.resize(size, *args, **kwargs) 114 | 115 | box.fields[k] = v 116 | 117 | return box.convert(self.mode) 118 | 119 | def 
transpose(self, method): 120 | width, height = self.size 121 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 122 | 123 | if method == FLIP_LEFT_RIGHT: 124 | remove = 1 125 | 126 | transpose_x_min = width - x_max - remove 127 | transpose_x_max = width - x_min - remove 128 | transpose_y_min = y_min 129 | transpose_y_max = y_max 130 | 131 | elif method == FLIP_TOP_BOTTOM: 132 | transpose_x_min = x_min 133 | transpose_x_max = x_max 134 | transpose_y_min = height - y_max 135 | transpose_y_max = height - y_min 136 | 137 | transpose_box = torch.cat( 138 | [transpose_x_min, transpose_y_min, transpose_x_max, transpose_y_max], -1 139 | ) 140 | box = BoxList(transpose_box, self.size, mode='xyxy') 141 | 142 | for k, v in self.fields.items(): 143 | if not isinstance(v, torch.Tensor): 144 | v = v.transpose(method) 145 | 146 | box.fields[k] = v 147 | 148 | return box.convert(self.mode) 149 | 150 | def clip(self, remove_empty=True): 151 | remove = 1 152 | 153 | max_width = self.size[0] - remove 154 | max_height = self.size[1] - remove 155 | 156 | self.box[:, 0].clamp_(min=0, max=max_width) 157 | self.box[:, 1].clamp_(min=0, max=max_height) 158 | self.box[:, 2].clamp_(min=0, max=max_width) 159 | self.box[:, 3].clamp_(min=0, max=max_height) 160 | 161 | if remove_empty: 162 | box = self.box 163 | keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) 164 | 165 | return self[keep] 166 | 167 | else: 168 | return self 169 | 170 | def to(self, device): 171 | box = BoxList(self.box.to(device), self.size, self.mode) 172 | 173 | for k, v in self.fields.items(): 174 | if hasattr(v, 'to'): 175 | v = v.to(device) 176 | 177 | box.fields[k] = v 178 | 179 | return box 180 | 181 | 182 | def remove_small_box(boxlist, min_size): 183 | box = boxlist.convert('xywh').box 184 | _, _, w, h = box.unbind(dim=1) 185 | keep = (w >= min_size) & (h >= min_size) 186 | keep = keep.nonzero().squeeze(1) 187 | 188 | return boxlist[keep] 189 | 190 | 191 | def cat_boxlist(boxlists): 192 | size = boxlists[0].size 193 | mode = boxlists[0].mode 194 | field_keys = boxlists[0].fields.keys() 195 | 196 | box_cat = torch.cat([boxlist.box for boxlist in boxlists], 0) 197 | new_boxlist = BoxList(box_cat, size, mode) 198 | 199 | for field in field_keys: 200 | data = torch.cat([boxlist.fields[field] for boxlist in boxlists], 0) 201 | new_boxlist.fields[field] = data 202 | 203 | return new_boxlist 204 | 205 | 206 | def boxlist_nms(boxlist, scores, threshold, max_proposal=-1): 207 | if threshold <= 0: 208 | return boxlist 209 | 210 | mode = boxlist.mode 211 | boxlist = boxlist.convert('xyxy') 212 | box = boxlist.box 213 | keep = ops.nms(box, scores, threshold) 214 | 215 | if max_proposal > 0: 216 | keep = keep[:max_proposal] 217 | 218 | boxlist = boxlist[keep] 219 | 220 | return boxlist.convert(mode) 221 | -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /coco_meta.py: -------------------------------------------------------------------------------- 1 | CLASS_NAME = [ 2 | '__background__', 3 | 'person', 4 | 'bicycle', 5 | 'car', 6 | 'motorcycle', 7 | 'airplane', 8 | 'bus', 9 | 'train', 10 | 'truck', 11 | 'boat', 12 | 'traffic light', 13 | 'fire hydrant', 14 | 'stop sign', 15 | 'parking meter', 16 | 'bench', 17 | 'bird', 18 | 'cat', 19 | 'dog', 20 | 'horse', 21 | 'sheep', 22 | 'cow', 23 | 'elephant', 24 | 'bear', 25 | 
'zebra', 26 | 'giraffe', 27 | 'backpack', 28 | 'umbrella', 29 | 'handbag', 30 | 'tie', 31 | 'suitcase', 32 | 'frisbee', 33 | 'skis', 34 | 'snowboard', 35 | 'sports ball', 36 | 'kite', 37 | 'baseball bat', 38 | 'baseball glove', 39 | 'skateboard', 40 | 'surfboard', 41 | 'tennis racket', 42 | 'bottle', 43 | 'wine glass', 44 | 'cup', 45 | 'fork', 46 | 'knife', 47 | 'spoon', 48 | 'bowl', 49 | 'banana', 50 | 'apple', 51 | 'sandwich', 52 | 'orange', 53 | 'broccoli', 54 | 'carrot', 55 | 'hot dog', 56 | 'pizza', 57 | 'donut', 58 | 'cake', 59 | 'chair', 60 | 'couch', 61 | 'potted plant', 62 | 'bed', 63 | 'dining table', 64 | 'toilet', 65 | 'tv', 66 | 'laptop', 67 | 'mouse', 68 | 'remote', 69 | 'keyboard', 70 | 'cell phone', 71 | 'microwave', 72 | 'oven', 73 | 'toaster', 74 | 'sink', 75 | 'refrigerator', 76 | 'book', 77 | 'clock', 78 | 'vase', 79 | 'scissors', 80 | 'teddy bear', 81 | 'hair drier', 82 | 'toothbrush', 83 | ] 84 | -------------------------------------------------------------------------------- /config: -------------------------------------------------------------------------------- 1 | class Config: 2 | pass 3 | 4 | config = Config() 5 | config.feat_channels = [0, 0, 512, 768, 1024] 6 | config.out_channel = 256 7 | config.use_p5 = True 8 | config.n_class = 80 9 | config.n_conv = 4 10 | config.prior = 0.01 11 | config.threshold = 0.05 12 | config.top_n = 1000 13 | config.nms_threshold = 0.6 14 | config.post_top_n = 100 15 | config.min_size = 0 16 | config.fpn_strides = [8, 16, 32, 64, 128] 17 | config.gamma = 2.0 18 | config.alpha = 0.25 19 | config.sizes = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, 100000000]] 20 | config.train_min_size_range = (-1, -1) 21 | config.train_min_size = (800,) 22 | config.train_max_size = 1333 23 | config.test_min_size = 800 24 | config.test_max_size = 1333 25 | config.pixel_mean = [0.40789654, 0.44719302, 0.47026115] 26 | config.pixel_std = [0.28863828, 0.27408164, 0.27809835] 27 | config.size_divisible = 32 28 | config.center_sample = True 29 | config.pos_radius = 1.5 30 | config.loc_loss_type = 'giou' -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchvision import datasets 5 | 6 | from boxlist import BoxList 7 | 8 | 9 | def has_only_empty_bbox(annot): 10 | return all(any(o <= 1 for o in obj['bbox'][2:]) for obj in annot) 11 | 12 | 13 | def has_valid_annotation(annot): 14 | if len(annot) == 0: 15 | return False 16 | 17 | if has_only_empty_bbox(annot): 18 | return False 19 | 20 | return True 21 | 22 | 23 | class COCODataset(datasets.CocoDetection): 24 | def __init__(self, path, split, transform=None): 25 | root = os.path.join(path, f'{split}2017') 26 | annot = os.path.join(path, 'annotations', f'instances_{split}2017.json') 27 | 28 | super().__init__(root, annot) 29 | 30 | self.ids = sorted(self.ids) 31 | 32 | if split == 'train': 33 | ids = [] 34 | 35 | for id in self.ids: 36 | ann_ids = self.coco.getAnnIds(imgIds=id, iscrowd=None) 37 | annot = self.coco.loadAnns(ann_ids) 38 | 39 | if has_valid_annotation(annot): 40 | ids.append(id) 41 | 42 | self.ids = ids 43 | 44 | self.category2id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())} 45 | self.id2category = {v: k for k, v in self.category2id.items()} 46 | self.id2img = {k: v for k, v in enumerate(self.ids)} 47 | 48 | self.transform = transform 49 | 50 | def __getitem__(self, index): 51 | img, annot = 
super().__getitem__(index) 52 | 53 | annot = [o for o in annot if o['iscrowd'] == 0] 54 | 55 | boxes = [o['bbox'] for o in annot] 56 | boxes = torch.as_tensor(boxes).reshape(-1, 4) 57 | target = BoxList(boxes, img.size, mode='xywh').convert('xyxy') 58 | 59 | classes = [o['category_id'] for o in annot] 60 | classes = [self.category2id[c] for c in classes] 61 | classes = torch.tensor(classes) 62 | target.fields['labels'] = classes 63 | 64 | target.clip(remove_empty=True) 65 | 66 | if self.transform is not None: 67 | img, target = self.transform(img, target) 68 | 69 | return img, target, index 70 | 71 | def get_image_meta(self, index): 72 | id = self.id2img[index] 73 | img_data = self.coco.imgs[id] 74 | 75 | return img_data 76 | 77 | 78 | class ImageList: 79 | def __init__(self, tensors, sizes): 80 | self.tensors = tensors 81 | self.sizes = sizes 82 | 83 | def to(self, *args, **kwargs): 84 | tensor = self.tensors.to(*args, **kwargs) 85 | 86 | return ImageList(tensor, self.sizes) 87 | 88 | 89 | def image_list(tensors, size_divisible=0): 90 | max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) 91 | 92 | if size_divisible > 0: 93 | stride = size_divisible 94 | max_size = list(max_size) 95 | max_size[1] = (max_size[1] | (stride - 1)) + 1 96 | max_size[2] = (max_size[2] | (stride - 1)) + 1 97 | max_size = tuple(max_size) 98 | 99 | shape = (len(tensors),) + max_size 100 | batch = tensors[0].new(*shape).zero_() 101 | 102 | for img, pad_img in zip(tensors, batch): 103 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 104 | 105 | sizes = [img.shape[-2:] for img in tensors] 106 | 107 | return ImageList(batch, sizes) 108 | 109 | 110 | def collate_fn(config): 111 | def collate_data(batch): 112 | batch = list(zip(*batch)) 113 | imgs = image_list(batch[0], config.size_divisible) 114 | targets = batch[1] 115 | ids = batch[2] 116 | 117 | return imgs, targets, ids 118 | 119 | return collate_data 120 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def all_gather(data): 45 | world_size = get_world_size() 46 | 47 | if world_size == 1: 48 | return [data] 49 | 50 | buffer = pickle.dumps(data) 51 | storage = torch.ByteStorage.from_buffer(buffer) 52 | tensor = torch.ByteTensor(storage).to('cuda') 53 | 54 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 55 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 56 | dist.all_gather(size_list, local_size) 57 | size_list = [int(size.item()) for size in size_list] 58 | max_size = max(size_list) 59 | 60 | tensor_list = [] 61 | for _ in size_list: 62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 63 | 64 | if local_size != 
max_size: 65 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 66 | tensor = torch.cat((tensor, padding), 0) 67 | 68 | dist.all_gather(tensor_list, tensor) 69 | 70 | data_list = [] 71 | 72 | for size, tensor in zip(size_list, tensor_list): 73 | buffer = tensor.cpu().numpy().tobytes()[:size] 74 | data_list.append(pickle.loads(buffer)) 75 | 76 | return data_list 77 | 78 | 79 | def reduce_loss_dict(loss_dict): 80 | world_size = get_world_size() 81 | 82 | if world_size < 2: 83 | return loss_dict 84 | 85 | with torch.no_grad(): 86 | keys = [] 87 | losses = [] 88 | 89 | for k in sorted(loss_dict.keys()): 90 | keys.append(k) 91 | losses.append(loss_dict[k]) 92 | 93 | losses = torch.stack(losses, 0) 94 | dist.reduce(losses, dst=0) 95 | 96 | if dist.get_rank() == 0: 97 | losses /= world_size 98 | 99 | reduced_losses = {k: v for k, v in zip(keys, losses)} 100 | 101 | return reduced_losses 102 | 103 | 104 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 105 | # Code is copy-pasted exactly as in torch.utils.data.distributed. 106 | # FIXME remove this once c10d fixes the bug it has 107 | 108 | 109 | class DistributedSampler(Sampler): 110 | """Sampler that restricts data loading to a subset of the dataset. 111 | It is especially useful in conjunction with 112 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 113 | process can pass a DistributedSampler instance as a DataLoader sampler, 114 | and load a subset of the original dataset that is exclusive to it. 115 | .. note:: 116 | Dataset is assumed to be of constant size. 117 | Arguments: 118 | dataset: Dataset used for sampling. 119 | num_replicas (optional): Number of processes participating in 120 | distributed training. 121 | rank (optional): Rank of the current process within num_replicas. 
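shuffle (optional): If ``True`` (the default), the sampler shuffles the indices, seeding the shuffle deterministically from the current epoch (see ``set_epoch``).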
122 | """ 123 | 124 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 125 | if num_replicas is None: 126 | if not dist.is_available(): 127 | raise RuntimeError("Requires distributed package to be available") 128 | num_replicas = dist.get_world_size() 129 | if rank is None: 130 | if not dist.is_available(): 131 | raise RuntimeError("Requires distributed package to be available") 132 | rank = dist.get_rank() 133 | self.dataset = dataset 134 | self.num_replicas = num_replicas 135 | self.rank = rank 136 | self.epoch = 0 137 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 138 | self.total_size = self.num_samples * self.num_replicas 139 | self.shuffle = shuffle 140 | 141 | def __iter__(self): 142 | if self.shuffle: 143 | # deterministically shuffle based on epoch 144 | g = torch.Generator() 145 | g.manual_seed(self.epoch) 146 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 147 | else: 148 | indices = torch.arange(len(self.dataset)).tolist() 149 | 150 | # add extra samples to make it evenly divisible 151 | indices += indices[: (self.total_size - len(indices))] 152 | assert len(indices) == self.total_size 153 | 154 | # subsample 155 | offset = self.num_samples * self.rank 156 | indices = indices[offset : offset + self.num_samples] 157 | assert len(indices) == self.num_samples 158 | 159 | return iter(indices) 160 | 161 | def __len__(self): 162 | return self.num_samples 163 | 164 | def set_epoch(self, epoch): 165 | self.epoch = epoch 166 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | from pycocotools.coco import COCO 7 | from pycocotools.cocoeval import COCOeval 8 | 9 | 10 | def evaluate(dataset, predictions): 11 | coco_results = {} 12 | coco_results['bbox'] = make_coco_detection(predictions, dataset) 13 | 14 | results = COCOResult('bbox') 15 | 16 | with tempfile.NamedTemporaryFile() as f: 17 | path = f.name 18 | res = evaluate_predictions_on_coco( 19 | dataset.coco, coco_results['bbox'], path, 'bbox' 20 | ) 21 | results.update(res) 22 | 23 | print(results) 24 | 25 | return res 26 | 27 | 28 | def evaluate_predictions_on_coco(coco_gt, results, result_file, iou_type): 29 | with open(result_file, 'w') as f: 30 | json.dump(results, f) 31 | 32 | coco_dt = coco_gt.loadRes(str(result_file)) if results else COCO() 33 | 34 | coco_eval = COCOeval(coco_gt, coco_dt, iou_type) 35 | coco_eval.evaluate() 36 | coco_eval.accumulate() 37 | coco_eval.summarize() 38 | 39 | # compute_thresholds_for_classes(coco_eval) 40 | 41 | return coco_eval 42 | 43 | 44 | def compute_thresholds_for_classes(coco_eval): 45 | precision = coco_eval.eval['precision'] 46 | precision = precision[0, :, :, 0, -1] 47 | scores = coco_eval.eval['scores'] 48 | scores = scores[0, :, :, 0, -1] 49 | 50 | recall = np.linspace(0, 1, num=precision.shape[0]) 51 | recall = recall[:, None] 52 | 53 | f1 = (2 * precision * recall) / (np.maximum(precision + recall, 1e-6)) 54 | max_f1 = f1.max(0) 55 | max_f1_id = f1.argmax(0) 56 | scores = scores[max_f1_id, range(len(max_f1_id))] 57 | 58 | print('Maximum f1 for classes:') 59 | print(list(max_f1)) 60 | print('Score thresholds for classes') 61 | print(list(scores)) 62 | 63 | 64 | def make_coco_detection(predictions, dataset): 65 | coco_results = [] 66 | 67 | for id, pred in enumerate(predictions): 68 
| orig_id = dataset.id2img[id] 69 | 70 | if len(pred) == 0: 71 | continue 72 | 73 | img_meta = dataset.get_image_meta(id) 74 | width = img_meta['width'] 75 | height = img_meta['height'] 76 | pred = pred.resize((width, height)) 77 | pred = pred.convert('xywh') 78 | 79 | boxes = pred.box.tolist() 80 | scores = pred.fields['scores'].tolist() 81 | labels = pred.fields['labels'].tolist() 82 | 83 | labels = [dataset.id2category[i] for i in labels] 84 | 85 | coco_results.extend( 86 | [ 87 | { 88 | 'image_id': orig_id, 89 | 'category_id': labels[k], 90 | 'bbox': box, 91 | 'score': scores[k], 92 | } 93 | for k, box in enumerate(boxes) 94 | ] 95 | ) 96 | 97 | return coco_results 98 | 99 | 100 | class COCOResult: 101 | METRICS = { 102 | 'bbox': ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl'], 103 | 'segm': ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl'], 104 | 'box_proposal': [ 105 | 'AR@100', 106 | 'ARs@100', 107 | 'ARm@100', 108 | 'ARl@100', 109 | 'AR@1000', 110 | 'ARs@1000', 111 | 'ARm@1000', 112 | 'ARl@1000', 113 | ], 114 | 'keypoints': ['AP', 'AP50', 'AP75', 'APm', 'APl'], 115 | } 116 | 117 | def __init__(self, *iou_types): 118 | allowed_types = ("box_proposal", "bbox", "segm", "keypoints") 119 | assert all(iou_type in allowed_types for iou_type in iou_types) 120 | results = OrderedDict() 121 | for iou_type in iou_types: 122 | results[iou_type] = OrderedDict( 123 | [(metric, -1) for metric in COCOResult.METRICS[iou_type]] 124 | ) 125 | self.results = results 126 | 127 | def update(self, coco_eval): 128 | if coco_eval is None: 129 | return 130 | 131 | assert isinstance(coco_eval, COCOeval) 132 | s = coco_eval.stats 133 | iou_type = coco_eval.params.iouType 134 | res = self.results[iou_type] 135 | metrics = COCOResult.METRICS[iou_type] 136 | for idx, metric in enumerate(metrics): 137 | res[metric] = s[idx] 138 | 139 | def __repr__(self): 140 | return repr(self.results) 141 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | INF = 100000000 6 | 7 | 8 | class IOULoss(nn.Module): 9 | def __init__(self, loc_loss_type): 10 | super().__init__() 11 | 12 | self.loc_loss_type = loc_loss_type 13 | 14 | def forward(self, out, target, weight=None): 15 | pred_left, pred_top, pred_right, pred_bottom = out.unbind(1) 16 | target_left, target_top, target_right, target_bottom = target.unbind(1) 17 | 18 | target_area = (target_left + target_right) * (target_top + target_bottom) 19 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom) 20 | 21 | w_intersect = torch.min(pred_left, target_left) + torch.min( 22 | pred_right, target_right 23 | ) 24 | h_intersect = torch.min(pred_bottom, target_bottom) + torch.min( 25 | pred_top, target_top 26 | ) 27 | 28 | area_intersect = w_intersect * h_intersect 29 | area_union = target_area + pred_area - area_intersect 30 | 31 | ious = (area_intersect + 1) / (area_union + 1) 32 | 33 | if self.loc_loss_type == 'iou': 34 | loss = -torch.log(ious) 35 | 36 | elif self.loc_loss_type == 'giou': 37 | g_w_intersect = torch.max(pred_left, target_left) + torch.max( 38 | pred_right, target_right 39 | ) 40 | g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max( 41 | pred_top, target_top 42 | ) 43 | g_intersect = g_w_intersect * g_h_intersect + 1e-7 44 | gious = ious - (g_intersect - area_union) / g_intersect 45 | 46 | loss = 1 - gious 47 | 48 | if weight is not None and weight.sum() > 0: 49 | return (loss * 
weight).sum() / weight.sum() 50 | 51 | else: 52 | return loss.mean() 53 | 54 | 55 | def clip_sigmoid(input): 56 | out = torch.clamp(torch.sigmoid(input), min=1e-4, max=1 - 1e-4) 57 | 58 | return out 59 | 60 | 61 | class SigmoidFocalLoss(nn.Module): 62 | def __init__(self, gamma, alpha): 63 | super().__init__() 64 | 65 | self.gamma = gamma 66 | self.alpha = alpha 67 | 68 | def forward(self, out, target): 69 | n_class = out.shape[1] 70 | class_ids = torch.arange( 71 | 1, n_class + 1, dtype=target.dtype, device=target.device 72 | ).unsqueeze(0) 73 | 74 | t = target.unsqueeze(1) 75 | p = torch.sigmoid(out) 76 | 77 | gamma = self.gamma 78 | alpha = self.alpha 79 | 80 | term1 = (1 - p) ** gamma * torch.log(p) 81 | term2 = p ** gamma * torch.log(1 - p) 82 | 83 | # print(term1.sum(), term2.sum()) 84 | 85 | loss = ( 86 | -(t == class_ids).float() * alpha * term1 87 | - ((t != class_ids) * (t >= 0)).float() * (1 - alpha) * term2 88 | ) 89 | 90 | return loss.sum() 91 | 92 | 93 | class FCOSLoss(nn.Module): 94 | def __init__( 95 | self, sizes, gamma, alpha, iou_loss_type, center_sample, fpn_strides, pos_radius 96 | ): 97 | super().__init__() 98 | 99 | self.sizes = sizes 100 | 101 | self.cls_loss = SigmoidFocalLoss(gamma, alpha) 102 | self.box_loss = IOULoss(iou_loss_type) 103 | self.center_loss = nn.BCEWithLogitsLoss() 104 | 105 | self.center_sample = center_sample 106 | self.strides = fpn_strides 107 | self.radius = pos_radius 108 | 109 | def prepare_target(self, points, targets): 110 | ex_size_of_interest = [] 111 | 112 | for i, point_per_level in enumerate(points): 113 | size_of_interest_per_level = point_per_level.new_tensor(self.sizes[i]) 114 | ex_size_of_interest.append( 115 | size_of_interest_per_level[None].expand(len(point_per_level), -1) 116 | ) 117 | 118 | ex_size_of_interest = torch.cat(ex_size_of_interest, 0) 119 | n_point_per_level = [len(point_per_level) for point_per_level in points] 120 | point_all = torch.cat(points, dim=0) 121 | label, box_target = self.compute_target_for_location( 122 | point_all, targets, ex_size_of_interest, n_point_per_level 123 | ) 124 | 125 | for i in range(len(label)): 126 | label[i] = torch.split(label[i], n_point_per_level, 0) 127 | box_target[i] = torch.split(box_target[i], n_point_per_level, 0) 128 | 129 | label_level_first = [] 130 | box_target_level_first = [] 131 | 132 | for level in range(len(points)): 133 | label_level_first.append( 134 | torch.cat([label_per_img[level] for label_per_img in label], 0) 135 | ) 136 | box_target_level_first.append( 137 | torch.cat( 138 | [box_target_per_img[level] for box_target_per_img in box_target], 0 139 | ) 140 | ) 141 | 142 | return label_level_first, box_target_level_first 143 | 144 | def get_sample_region(self, gt, strides, n_point_per_level, xs, ys, radius=1): 145 | n_gt = gt.shape[0] 146 | n_loc = len(xs) 147 | gt = gt[None].expand(n_loc, n_gt, 4) 148 | center_x = (gt[..., 0] + gt[..., 2]) / 2 149 | center_y = (gt[..., 1] + gt[..., 3]) / 2 150 | 151 | if center_x[..., 0].sum() == 0: 152 | return xs.new_zeros(xs.shape, dtype=torch.uint8) 153 | 154 | begin = 0 155 | 156 | center_gt = gt.new_zeros(gt.shape) 157 | 158 | for level, n_p in enumerate(n_point_per_level): 159 | end = begin + n_p 160 | stride = strides[level] * radius 161 | 162 | x_min = center_x[begin:end] - stride 163 | y_min = center_y[begin:end] - stride 164 | x_max = center_x[begin:end] + stride 165 | y_max = center_y[begin:end] + stride 166 | 167 | center_gt[begin:end, :, 0] = torch.where( 168 | x_min > gt[begin:end, :, 0], x_min, gt[begin:end, :, 
0] 169 | ) 170 | center_gt[begin:end, :, 1] = torch.where( 171 | y_min > gt[begin:end, :, 1], y_min, gt[begin:end, :, 1] 172 | ) 173 | center_gt[begin:end, :, 2] = torch.where( 174 | x_max > gt[begin:end, :, 2], gt[begin:end, :, 2], x_max 175 | ) 176 | center_gt[begin:end, :, 3] = torch.where( 177 | y_max > gt[begin:end, :, 3], gt[begin:end, :, 3], y_max 178 | ) 179 | 180 | begin = end 181 | 182 | left = xs[:, None] - center_gt[..., 0] 183 | right = center_gt[..., 2] - xs[:, None] 184 | top = ys[:, None] - center_gt[..., 1] 185 | bottom = center_gt[..., 3] - ys[:, None] 186 | 187 | center_bbox = torch.stack((left, top, right, bottom), -1) 188 | is_in_boxes = center_bbox.min(-1)[0] > 0 189 | 190 | return is_in_boxes 191 | 192 | def compute_target_for_location( 193 | self, locations, targets, sizes_of_interest, n_point_per_level 194 | ): 195 | labels = [] 196 | box_targets = [] 197 | xs, ys = locations[:, 0], locations[:, 1] 198 | 199 | for i in range(len(targets)): 200 | targets_per_img = targets[i] 201 | assert targets_per_img.mode == 'xyxy' 202 | bboxes = targets_per_img.box 203 | labels_per_img = targets_per_img.fields['labels'] 204 | area = targets_per_img.area() 205 | 206 | l = xs[:, None] - bboxes[:, 0][None] 207 | t = ys[:, None] - bboxes[:, 1][None] 208 | r = bboxes[:, 2][None] - xs[:, None] 209 | b = bboxes[:, 3][None] - ys[:, None] 210 | 211 | box_targets_per_img = torch.stack([l, t, r, b], 2) 212 | 213 | if self.center_sample: 214 | is_in_boxes = self.get_sample_region( 215 | bboxes, self.strides, n_point_per_level, xs, ys, radius=self.radius 216 | ) 217 | 218 | else: 219 | is_in_boxes = box_targets_per_img.min(2)[0] > 0 220 | 221 | max_box_targets_per_img = box_targets_per_img.max(2)[0] 222 | 223 | is_cared_in_level = ( 224 | max_box_targets_per_img >= sizes_of_interest[:, [0]] 225 | ) & (max_box_targets_per_img <= sizes_of_interest[:, [1]]) 226 | 227 | locations_to_gt_area = area[None].repeat(len(locations), 1) 228 | locations_to_gt_area[is_in_boxes == 0] = INF 229 | locations_to_gt_area[is_cared_in_level == 0] = INF 230 | 231 | locations_to_min_area, locations_to_gt_id = locations_to_gt_area.min(1) 232 | 233 | box_targets_per_img = box_targets_per_img[ 234 | range(len(locations)), locations_to_gt_id 235 | ] 236 | labels_per_img = labels_per_img[locations_to_gt_id] 237 | labels_per_img[locations_to_min_area == INF] = 0 238 | 239 | labels.append(labels_per_img) 240 | box_targets.append(box_targets_per_img) 241 | 242 | return labels, box_targets 243 | 244 | def compute_centerness_targets(self, box_targets): 245 | left_right = box_targets[:, [0, 2]] 246 | top_bottom = box_targets[:, [1, 3]] 247 | centerness = (left_right.min(-1)[0] / left_right.max(-1)[0]) * ( 248 | top_bottom.min(-1)[0] / top_bottom.max(-1)[0] 249 | ) 250 | 251 | return torch.sqrt(centerness) 252 | 253 | def forward(self, locations, cls_pred, box_pred, center_pred, targets): 254 | batch = cls_pred[0].shape[0] 255 | n_class = cls_pred[0].shape[1] 256 | 257 | labels, box_targets = self.prepare_target(locations, targets) 258 | 259 | cls_flat = [] 260 | box_flat = [] 261 | center_flat = [] 262 | 263 | labels_flat = [] 264 | box_targets_flat = [] 265 | 266 | for i in range(len(labels)): 267 | cls_flat.append(cls_pred[i].permute(0, 2, 3, 1).reshape(-1, n_class)) 268 | box_flat.append(box_pred[i].permute(0, 2, 3, 1).reshape(-1, 4)) 269 | center_flat.append(center_pred[i].permute(0, 2, 3, 1).reshape(-1)) 270 | 271 | labels_flat.append(labels[i].reshape(-1)) 272 | box_targets_flat.append(box_targets[i].reshape(-1, 4)) 273 
| 274 | cls_flat = torch.cat(cls_flat, 0) 275 | box_flat = torch.cat(box_flat, 0) 276 | center_flat = torch.cat(center_flat, 0) 277 | 278 | labels_flat = torch.cat(labels_flat, 0) 279 | box_targets_flat = torch.cat(box_targets_flat, 0) 280 | 281 | pos_id = torch.nonzero(labels_flat > 0).squeeze(1) 282 | 283 | cls_loss = self.cls_loss(cls_flat, labels_flat.int()) / (pos_id.numel() + batch) 284 | 285 | box_flat = box_flat[pos_id] 286 | center_flat = center_flat[pos_id] 287 | 288 | box_targets_flat = box_targets_flat[pos_id] 289 | 290 | if pos_id.numel() > 0: 291 | center_targets = self.compute_centerness_targets(box_targets_flat) 292 | 293 | box_loss = self.box_loss(box_flat, box_targets_flat, center_targets) 294 | center_loss = self.center_loss(center_flat, center_targets) 295 | 296 | else: 297 | box_loss = box_flat.sum() 298 | center_loss = center_flat.sum() 299 | 300 | return cls_loss, box_loss, center_loss 301 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from loss import FCOSLoss 8 | from postprocess import FCOSPostprocessor 9 | 10 | 11 | class Scale(nn.Module): 12 | def __init__(self, init=1.0): 13 | super().__init__() 14 | 15 | self.scale = nn.Parameter(torch.tensor([init], dtype=torch.float32)) 16 | 17 | def forward(self, input): 18 | return input * self.scale 19 | 20 | 21 | def init_conv_kaiming(module): 22 | if isinstance(module, nn.Conv2d): 23 | nn.init.kaiming_uniform_(module.weight, a=1) 24 | 25 | if module.bias is not None: 26 | nn.init.constant_(module.bias, 0) 27 | 28 | 29 | def init_conv_std(module, std=0.01): 30 | if isinstance(module, nn.Conv2d): 31 | nn.init.normal_(module.weight, std=std) 32 | 33 | if module.bias is not None: 34 | nn.init.constant_(module.bias, 0) 35 | 36 | 37 | class FPN(nn.Module): 38 | def __init__(self, in_channels, out_channel, top_blocks=None): 39 | super().__init__() 40 | 41 | self.inner_convs = nn.ModuleList() 42 | self.out_convs = nn.ModuleList() 43 | 44 | for i, in_channel in enumerate(in_channels, 1): 45 | if in_channel == 0: 46 | self.inner_convs.append(None) 47 | self.out_convs.append(None) 48 | 49 | continue 50 | 51 | inner_conv = nn.Conv2d(in_channel, out_channel, 1) 52 | feat_conv = nn.Conv2d(out_channel, out_channel, 3, padding=1) 53 | 54 | self.inner_convs.append(inner_conv) 55 | self.out_convs.append(feat_conv) 56 | 57 | self.apply(init_conv_kaiming) 58 | 59 | self.top_blocks = top_blocks 60 | 61 | def forward(self, inputs): 62 | inner = self.inner_convs[-1](inputs[-1]) 63 | outs = [self.out_convs[-1](inner)] 64 | 65 | for feat, inner_conv, out_conv in zip( 66 | inputs[:-1][::-1], self.inner_convs[:-1][::-1], self.out_convs[:-1][::-1] 67 | ): 68 | if inner_conv is None: 69 | continue 70 | 71 | upsample = F.interpolate(inner, scale_factor=2, mode='nearest') 72 | inner_feat = inner_conv(feat) 73 | inner = inner_feat + upsample 74 | outs.insert(0, out_conv(inner)) 75 | 76 | if self.top_blocks is not None: 77 | top_outs = self.top_blocks(outs[-1], inputs[-1]) 78 | outs.extend(top_outs) 79 | 80 | return outs 81 | 82 | 83 | class FPNTopP6P7(nn.Module): 84 | def __init__(self, in_channel, out_channel, use_p5=True): 85 | super().__init__() 86 | 87 | self.p6 = nn.Conv2d(in_channel, out_channel, 3, stride=2, padding=1) 88 | self.p7 = nn.Conv2d(out_channel, out_channel, 3, stride=2, padding=1) 89 | 90 | 
self.apply(init_conv_kaiming) 91 | 92 | self.use_p5 = use_p5 93 | 94 | def forward(self, f5, p5): 95 | input = p5 if self.use_p5 else f5 96 | 97 | p6 = self.p6(input) 98 | p7 = self.p7(F.relu(p6)) 99 | 100 | return p6, p7 101 | 102 | 103 | class FCOSHead(nn.Module): 104 | def __init__(self, in_channel, n_class, n_conv, prior): 105 | super().__init__() 106 | 107 | n_class = n_class - 1 108 | 109 | cls_tower = [] 110 | bbox_tower = [] 111 | 112 | for i in range(n_conv): 113 | cls_tower.append( 114 | nn.Conv2d(in_channel, in_channel, 3, padding=1, bias=False) 115 | ) 116 | cls_tower.append(nn.GroupNorm(32, in_channel)) 117 | cls_tower.append(nn.ReLU()) 118 | 119 | bbox_tower.append( 120 | nn.Conv2d(in_channel, in_channel, 3, padding=1, bias=False) 121 | ) 122 | bbox_tower.append(nn.GroupNorm(32, in_channel)) 123 | bbox_tower.append(nn.ReLU()) 124 | 125 | self.cls_tower = nn.Sequential(*cls_tower) 126 | self.bbox_tower = nn.Sequential(*bbox_tower) 127 | 128 | self.cls_pred = nn.Conv2d(in_channel, n_class, 3, padding=1) 129 | self.bbox_pred = nn.Conv2d(in_channel, 4, 3, padding=1) 130 | self.center_pred = nn.Conv2d(in_channel, 1, 3, padding=1) 131 | 132 | self.apply(init_conv_std) 133 | 134 | prior_bias = -math.log((1 - prior) / prior) 135 | nn.init.constant_(self.cls_pred.bias, prior_bias) 136 | 137 | self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)]) 138 | 139 | def forward(self, input): 140 | logits = [] 141 | bboxes = [] 142 | centers = [] 143 | 144 | for feat, scale in zip(input, self.scales): 145 | cls_out = self.cls_tower(feat) 146 | 147 | logits.append(self.cls_pred(cls_out)) 148 | centers.append(self.center_pred(cls_out)) 149 | 150 | bbox_out = self.bbox_tower(feat) 151 | bbox_out = torch.exp(scale(self.bbox_pred(bbox_out))) 152 | 153 | bboxes.append(bbox_out) 154 | 155 | return logits, bboxes, centers 156 | 157 | 158 | class FCOS(nn.Module): 159 | def __init__(self, config, backbone): 160 | super().__init__() 161 | 162 | self.backbone = backbone 163 | fpn_top = FPNTopP6P7( 164 | config.feat_channels[-1], config.out_channel, use_p5=config.use_p5 165 | ) 166 | self.fpn = FPN(config.feat_channels, config.out_channel, fpn_top) 167 | self.head = FCOSHead( 168 | config.out_channel, config.n_class, config.n_conv, config.prior 169 | ) 170 | self.postprocessor = FCOSPostprocessor( 171 | config.threshold, 172 | config.top_n, 173 | config.nms_threshold, 174 | config.post_top_n, 175 | config.min_size, 176 | config.n_class, 177 | ) 178 | self.loss = FCOSLoss( 179 | config.sizes, 180 | config.gamma, 181 | config.alpha, 182 | config.iou_loss_type, 183 | config.center_sample, 184 | config.fpn_strides, 185 | config.pos_radius, 186 | ) 187 | 188 | self.fpn_strides = config.fpn_strides 189 | 190 | def train(self, mode=True): 191 | super().train(mode) 192 | 193 | def freeze_bn(module): 194 | if isinstance(module, nn.BatchNorm2d): 195 | module.eval() 196 | 197 | self.apply(freeze_bn) 198 | 199 | def forward(self, input, image_sizes=None, targets=None): 200 | features = self.backbone(input) 201 | features = self.fpn(features) 202 | cls_pred, box_pred, center_pred = self.head(features) 203 | # print(cls_pred, box_pred, center_pred) 204 | location = self.compute_location(features) 205 | 206 | if self.training: 207 | loss_cls, loss_box, loss_center = self.loss( 208 | location, cls_pred, box_pred, center_pred, targets 209 | ) 210 | losses = { 211 | 'loss_cls': loss_cls, 212 | 'loss_box': loss_box, 213 | 'loss_center': loss_center, 214 | } 215 | 216 | return None, losses 217 | 218 | else: 219 | boxes = 
self.postprocessor( 220 | location, cls_pred, box_pred, center_pred, image_sizes 221 | ) 222 | 223 | return boxes, None 224 | 225 | def compute_location(self, features): 226 | locations = [] 227 | 228 | for i, feat in enumerate(features): 229 | _, _, height, width = feat.shape 230 | location_per_level = self.compute_location_per_level( 231 | height, width, self.fpn_strides[i], feat.device 232 | ) 233 | locations.append(location_per_level) 234 | 235 | return locations 236 | 237 | def compute_location_per_level(self, height, width, stride, device): 238 | shift_x = torch.arange( 239 | 0, width * stride, step=stride, dtype=torch.float32, device=device 240 | ) 241 | shift_y = torch.arange( 242 | 0, height * stride, step=stride, dtype=torch.float32, device=device 243 | ) 244 | shift_y, shift_x = torch.meshgrid(shift_y, shift_x) 245 | shift_x = shift_x.reshape(-1) 246 | shift_y = shift_y.reshape(-1) 247 | location = torch.stack((shift_x, shift_y), 1) + stride // 2 248 | 249 | return location 250 | -------------------------------------------------------------------------------- /postprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from boxlist import BoxList, boxlist_nms, remove_small_box, cat_boxlist 5 | 6 | 7 | class FCOSPostprocessor(nn.Module): 8 | def __init__(self, threshold, top_n, nms_threshold, post_top_n, min_size, n_class): 9 | super().__init__() 10 | 11 | self.threshold = threshold 12 | self.top_n = top_n 13 | self.nms_threshold = nms_threshold 14 | self.post_top_n = post_top_n 15 | self.min_size = min_size 16 | self.n_class = n_class 17 | 18 | def forward_single_feature_map( 19 | self, location, cls_pred, box_pred, center_pred, image_sizes 20 | ): 21 | batch, channel, height, width = cls_pred.shape 22 | 23 | cls_pred = cls_pred.view(batch, channel, height, width).permute(0, 2, 3, 1) 24 | cls_pred = cls_pred.reshape(batch, -1, channel).sigmoid() 25 | 26 | box_pred = box_pred.view(batch, 4, height, width).permute(0, 2, 3, 1) 27 | box_pred = box_pred.reshape(batch, -1, 4) 28 | 29 | center_pred = center_pred.view(batch, 1, height, width).permute(0, 2, 3, 1) 30 | center_pred = center_pred.reshape(batch, -1).sigmoid() 31 | 32 | candid_ids = cls_pred > self.threshold 33 | top_ns = candid_ids.view(batch, -1).sum(1) 34 | top_ns = top_ns.clamp(max=self.top_n) 35 | 36 | cls_pred = cls_pred * center_pred[:, :, None] 37 | 38 | results = [] 39 | 40 | for i in range(batch): 41 | cls_p = cls_pred[i] 42 | candid_id = candid_ids[i] 43 | cls_p = cls_p[candid_id] 44 | candid_nonzero = candid_id.nonzero() 45 | box_loc = candid_nonzero[:, 0] 46 | class_id = candid_nonzero[:, 1] + 1 47 | 48 | box_p = box_pred[i] 49 | box_p = box_p[box_loc] 50 | loc = location[box_loc] 51 | 52 | top_n = top_ns[i] 53 | 54 | if candid_id.sum().item() > top_n.item(): 55 | cls_p, top_k_id = cls_p.topk(top_n, sorted=False) 56 | class_id = class_id[top_k_id] 57 | box_p = box_p[top_k_id] 58 | loc = loc[top_k_id] 59 | 60 | detections = torch.stack( 61 | [ 62 | loc[:, 0] - box_p[:, 0], 63 | loc[:, 1] - box_p[:, 1], 64 | loc[:, 0] + box_p[:, 2], 65 | loc[:, 1] + box_p[:, 3], 66 | ], 67 | 1, 68 | ) 69 | 70 | height, width = image_sizes[i] 71 | 72 | boxlist = BoxList(detections, (int(width), int(height)), mode='xyxy') 73 | boxlist.fields['labels'] = class_id 74 | boxlist.fields['scores'] = torch.sqrt(cls_p) 75 | boxlist = boxlist.clip(remove_empty=False) 76 | boxlist = remove_small_box(boxlist, self.min_size) 77 | 78 | results.append(boxlist) 
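# note: 'scores' is sqrt(cls_prob * centerness) at this point; per-class NMS is applied later in select_over_scales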
79 | 80 | return results 81 | 82 | def forward(self, location, cls_pred, box_pred, center_pred, image_sizes): 83 | boxes = [] 84 | 85 | for loc, cls_p, box_p, center_p in zip( 86 | location, cls_pred, box_pred, center_pred 87 | ): 88 | boxes.append( 89 | self.forward_single_feature_map( 90 | loc, cls_p, box_p, center_p, image_sizes 91 | ) 92 | ) 93 | 94 | boxlists = list(zip(*boxes)) 95 | boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] 96 | boxlists = self.select_over_scales(boxlists) 97 | 98 | return boxlists 99 | 100 | def select_over_scales(self, boxlists): 101 | results = [] 102 | 103 | for boxlist in boxlists: 104 | scores = boxlist.fields['scores'] 105 | labels = boxlist.fields['labels'] 106 | box = boxlist.box 107 | 108 | result = [] 109 | 110 | for j in range(1, self.n_class): 111 | id = (labels == j).nonzero().view(-1) 112 | score_j = scores[id] 113 | box_j = box[id, :].view(-1, 4) 114 | box_by_class = BoxList(box_j, boxlist.size, mode='xyxy') 115 | box_by_class.fields['scores'] = score_j 116 | box_by_class = boxlist_nms(box_by_class, score_j, self.nms_threshold) 117 | n_label = len(box_by_class) 118 | box_by_class.fields['labels'] = torch.full( 119 | (n_label,), j, dtype=torch.int64, device=scores.device 120 | ) 121 | result.append(box_by_class) 122 | 123 | result = cat_boxlist(result) 124 | n_detection = len(result) 125 | 126 | if n_detection > self.post_top_n > 0: 127 | scores = result.fields['scores'] 128 | img_threshold, _ = torch.kthvalue( 129 | scores.cpu(), n_detection - self.post_top_n + 1 130 | ) 131 | keep = scores >= img_threshold.item() 132 | keep = torch.nonzero(keep).squeeze(1) 133 | result = result[keep] 134 | 135 | results.append(result) 136 | 137 | return results 138 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn, optim 5 | from torch.utils.data import DataLoader, sampler 6 | from tqdm import tqdm 7 | 8 | from argument import get_args 9 | from backbone import vovnet57 10 | from dataset import COCODataset, collate_fn 11 | from model import FCOS 12 | from transform import preset_transform 13 | from evaluate import evaluate 14 | from distributed import ( 15 | get_rank, 16 | synchronize, 17 | reduce_loss_dict, 18 | DistributedSampler, 19 | all_gather, 20 | ) 21 | 22 | 23 | def accumulate_predictions(predictions): 24 | all_predictions = all_gather(predictions) 25 | 26 | if get_rank() != 0: 27 | return 28 | 29 | predictions = {} 30 | 31 | for p in all_predictions: 32 | predictions.update(p) 33 | 34 | ids = list(sorted(predictions.keys())) 35 | 36 | if len(ids) != ids[-1] + 1: 37 | print('Evaluation results is not contiguous') 38 | 39 | predictions = [predictions[i] for i in ids] 40 | 41 | return predictions 42 | 43 | 44 | @torch.no_grad() 45 | def valid(args, epoch, loader, dataset, model, device): 46 | if args.distributed: 47 | model = model.module 48 | 49 | torch.cuda.empty_cache() 50 | 51 | model.eval() 52 | 53 | pbar = tqdm(loader, dynamic_ncols=True) 54 | 55 | preds = {} 56 | 57 | for images, targets, ids in pbar: 58 | model.zero_grad() 59 | 60 | images = images.to(device) 61 | targets = [target.to(device) for target in targets] 62 | 63 | pred, _ = model(images.tensors, images.sizes) 64 | 65 | pred = [p.to('cpu') for p in pred] 66 | 67 | preds.update({id: p for id, p in zip(ids, pred)}) 68 | 69 | preds = accumulate_predictions(preds) 70 | 71 | if get_rank() != 0: 72 | 
return 73 | 74 | evaluate(dataset, preds) 75 | 76 | 77 | def train(args, epoch, loader, model, optimizer, device): 78 | model.train() 79 | 80 | if get_rank() == 0: 81 | pbar = tqdm(loader, dynamic_ncols=True) 82 | 83 | else: 84 | pbar = loader 85 | 86 | for images, targets, _ in pbar: 87 | model.zero_grad() 88 | 89 | images = images.to(device) 90 | targets = [target.to(device) for target in targets] 91 | 92 | _, loss_dict = model(images.tensors, targets=targets) 93 | loss_cls = loss_dict['loss_cls'].mean() 94 | loss_box = loss_dict['loss_box'].mean() 95 | loss_center = loss_dict['loss_center'].mean() 96 | 97 | loss = loss_cls + loss_box + loss_center 98 | loss.backward() 99 | nn.utils.clip_grad_norm_(model.parameters(), 10) 100 | optimizer.step() 101 | 102 | loss_reduced = reduce_loss_dict(loss_dict) 103 | loss_cls = loss_reduced['loss_cls'].mean().item() 104 | loss_box = loss_reduced['loss_box'].mean().item() 105 | loss_center = loss_reduced['loss_center'].mean().item() 106 | 107 | if get_rank() == 0: 108 | pbar.set_description( 109 | ( 110 | f'epoch: {epoch + 1}; cls: {loss_cls:.4f}; ' 111 | f'box: {loss_box:.4f}; center: {loss_center:.4f}' 112 | ) 113 | ) 114 | 115 | 116 | def data_sampler(dataset, shuffle, distributed): 117 | if distributed: 118 | return DistributedSampler(dataset, shuffle=shuffle) 119 | 120 | if shuffle: 121 | return sampler.RandomSampler(dataset) 122 | 123 | else: 124 | return sampler.SequentialSampler(dataset) 125 | 126 | 127 | if __name__ == '__main__': 128 | args = get_args() 129 | 130 | n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 131 | args.distributed = n_gpu > 1 132 | 133 | if args.distributed: 134 | torch.cuda.set_device(args.local_rank) 135 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 136 | synchronize() 137 | 138 | device = 'cuda' 139 | 140 | train_set = COCODataset(args.path, 'train', preset_transform(args, train=True)) 141 | valid_set = COCODataset(args.path, 'val', preset_transform(args, train=False)) 142 | 143 | backbone = vovnet57(pretrained=True) 144 | model = FCOS(args, backbone) 145 | model = model.to(device) 146 | 147 | optimizer = optim.SGD( 148 | model.parameters(), 149 | lr=args.lr, 150 | momentum=0.9, 151 | weight_decay=args.l2, 152 | nesterov=True, 153 | ) 154 | scheduler = optim.lr_scheduler.MultiStepLR( 155 | optimizer, milestones=[16, 22], gamma=0.1 156 | ) 157 | 158 | if args.distributed: 159 | model = nn.parallel.DistributedDataParallel( 160 | model, 161 | device_ids=[args.local_rank], 162 | output_device=args.local_rank, 163 | broadcast_buffers=False, 164 | ) 165 | 166 | train_loader = DataLoader( 167 | train_set, 168 | batch_size=args.batch, 169 | sampler=data_sampler(train_set, shuffle=True, distributed=args.distributed), 170 | num_workers=2, 171 | collate_fn=collate_fn(args), 172 | ) 173 | valid_loader = DataLoader( 174 | valid_set, 175 | batch_size=args.batch, 176 | sampler=data_sampler(valid_set, shuffle=False, distributed=args.distributed), 177 | num_workers=2, 178 | collate_fn=collate_fn(args), 179 | ) 180 | 181 | for epoch in range(args.epoch): 182 | train(args, epoch, train_loader, model, optimizer, device) 183 | valid(args, epoch, valid_loader, valid_set, model, device) 184 | 185 | scheduler.step() 186 | 187 | if get_rank() == 0: 188 | torch.save( 189 | {'model': (model.module if args.distributed else model).state_dict(), 'optim': optimizer.state_dict()},  # unwrap DDP only when distributed 190 | f'checkpoint/epoch-{epoch + 1}.pt', 191 | ) 192 | 193 | --------------------------------------------------------------------------------
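A note on the per-epoch checkpoints written by train.py above: each file stores a plain dict with 'model' and 'optim' keys. The snippet below is a minimal restore sketch, not part of the repository; the helper name load_checkpoint and the choice of map_location='cpu' are assumptions made here for illustration.

import torch


def load_checkpoint(path, model, optimizer=None):
    # train.py writes {'model': state_dict, 'optim': state_dict} once per epoch,
    # so the two entries can be restored independently.
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])

    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optim'])

    return model, optimizer

Loading onto the CPU first keeps the restore independent of the GPU layout used during training; the weights follow the model when it is moved to the target device afterwards.

--------------------------------------------------------------------------------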
/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torchvision 5 | from torchvision.transforms import functional as F 6 | 7 | 8 | class Compose: 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | 12 | def __call__(self, img, target): 13 | for t in self.transforms: 14 | img, target = t(img, target) 15 | 16 | return img, target 17 | 18 | def __repr__(self): 19 | format_str = self.__class__.__name__ + '(' 20 | for t in self.transforms: 21 | format_str += '\n' 22 | format_str += f' {t}' 23 | format_str += '\n)' 24 | 25 | return format_str 26 | 27 | 28 | class Resize: 29 | def __init__(self, min_size, max_size): 30 | if not isinstance(min_size, (list, tuple)): 31 | min_size = (min_size,) 32 | 33 | self.min_size = min_size 34 | self.max_size = max_size 35 | 36 | def get_size(self, img_size): 37 | w, h = img_size 38 | size = random.choice(self.min_size) 39 | max_size = self.max_size 40 | 41 | if max_size is not None: 42 | min_orig = float(min((w, h))) 43 | max_orig = float(max((w, h))) 44 | 45 | if max_orig / min_orig * size > max_size: 46 | size = int(round(max_size * min_orig / max_orig)) 47 | 48 | if (w <= h and w == size) or (h <= w and h == size): 49 | return h, w 50 | 51 | if w < h: 52 | ow = size 53 | oh = int(size * h / w) 54 | 55 | else: 56 | oh = size 57 | ow = int(size * w / h) 58 | 59 | return oh, ow 60 | 61 | def __call__(self, img, target): 62 | size = self.get_size(img.size) 63 | img = F.resize(img, size) 64 | target = target.resize(img.size) 65 | 66 | return img, target 67 | 68 | 69 | class RandomHorizontalFlip: 70 | def __init__(self, p=0.5): 71 | self.p = p 72 | 73 | def __call__(self, img, target): 74 | if random.random() < self.p: 75 | img = F.hflip(img) 76 | target = target.transpose(0) 77 | 78 | return img, target 79 | 80 | 81 | class ToTensor: 82 | def __call__(self, img, target): 83 | return F.to_tensor(img), target 84 | 85 | 86 | class Normalize: 87 | def __init__(self, mean, std): 88 | self.mean = mean 89 | self.std = std 90 | 91 | def __call__(self, img, target): 92 | img = F.normalize(img, mean=self.mean, std=self.std) 93 | 94 | return img, target 95 | 96 | 97 | def preset_transform(config, train=True): 98 | if train: 99 | if config.train_min_size_range[0] == -1: 100 | min_size = config.train_min_size 101 | 102 | else: 103 | min_size = list( 104 | range( 105 | config.train_min_size_range[0], config.train_min_size_range[1] + 1 106 | ) 107 | ) 108 | 109 | max_size = config.train_max_size 110 | flip = 0.5 111 | 112 | else: 113 | min_size = config.test_min_size 114 | max_size = config.test_max_size 115 | flip = 0 116 | 117 | normalize = Normalize(mean=config.pixel_mean, std=config.pixel_std) 118 | 119 | transform = Compose( 120 | [Resize(min_size, max_size), RandomHorizontalFlip(flip), ToTensor(), normalize] 121 | ) 122 | 123 | return transform 124 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | from matplotlib import pyplot as plt 5 | from matplotlib import patheffects, patches 6 | 7 | 8 | def show_img(img, figsize=None, fig=None, ax=None): 9 | if not ax: 10 | fig, ax = plt.subplots(figsize=figsize) 11 | 12 | ax.imshow(img) 13 | ax.get_xaxis().set_visible(False) 14 | ax.get_yaxis().set_visible(False) 15 | 16 | return fig, ax 17 | 18 | 19 | def draw_outline(obj, 
line_width): 20 | obj.set_path_effects( 21 | [ 22 | patheffects.Stroke(linewidth=line_width, foreground='black'), 23 | patheffects.Normal(), 24 | ] 25 | ) 26 | 27 | 28 | def draw_rect(ax, box): 29 | patch = ax.add_patch( 30 | patches.Rectangle(box[:2], *box[-2:], fill=False, edgecolor='white', lw=2) 31 | ) 32 | draw_outline(patch, 4) 33 | 34 | 35 | def draw_text(ax, xy, txt, sz=14): 36 | text = ax.text( 37 | *xy, txt, verticalalignment='top', color='white', fontsize=sz, weight='bold' 38 | ) 39 | draw_outline(text, 1) 40 | --------------------------------------------------------------------------------
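The helpers in visualize.py compose into a small plotting pipeline: show_img sets up a figure with hidden axes, draw_rect adds an outlined rectangle, and draw_text adds an outlined label. The usage sketch below is not part of the repository; the blank image, the [x, y, width, height] box, and the label string are made up for illustration. Note that draw_rect passes box[:2] as the corner and box[-2:] as the width and height of the Rectangle patch, so xyxy boxes coming out of the postprocessor would need to be converted to xywh first.

import numpy as np

from visualize import show_img, draw_rect, draw_text

# A blank image and one hypothetical detection, drawn and saved to disk
# (visualize.py selects the Agg backend, so this works without a display).
img = np.zeros((256, 256, 3), dtype=np.uint8)
fig, ax = show_img(img, figsize=(6, 6))
draw_rect(ax, [40, 60, 120, 80])        # x, y, width, height
draw_text(ax, (40, 60), 'person 0.87')  # hypothetical label and score
fig.savefig('example_detection.png', bbox_inches='tight')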