├── utils ├── __init__.py ├── optimizer.py ├── lr_scheduler.py ├── vis_tools.py ├── distributed_utils.py ├── weight_init.py └── box_ops.py ├── datasets ├── demo │ ├── videos │ │ └── 000006.mp4 │ └── images │ │ ├── 000000000632.jpg │ │ ├── 000000000785.jpg │ │ ├── 000000000872.jpg │ │ ├── 000000000885.jpg │ │ ├── 000000001000.jpg │ │ ├── 000000001268.jpg │ │ ├── 000000001296.jpg │ │ └── 000000001532.jpg ├── __init__.py ├── coco.py └── transforms.py ├── .gitignore ├── evaluator ├── __init__.py └── coco_evaluator.py ├── models ├── backbone │ ├── __init__.py │ └── resnet.py ├── head │ ├── __init__.py │ ├── yolof_head.py │ └── fcos_head.py ├── neck │ ├── __init__.py │ ├── fpn.py │ └── dilated_encoder.py ├── detectors │ ├── yolof │ │ ├── build.py │ │ ├── README.md │ │ ├── yolof.py │ │ ├── matcher.py │ │ └── criterion.py │ ├── __init__.py │ └── fcos │ │ ├── build.py │ │ ├── README.md │ │ ├── fcos.py │ │ └── criterion.py └── basic │ ├── norm.py │ ├── mlp.py │ ├── conv.py │ ├── attn.py │ └── transformer.py ├── README.md ├── config ├── __init__.py ├── fcos_config.py └── yolof_config.py ├── train.sh ├── benchmark.py ├── engine.py ├── test.py ├── train.py ├── demo.py └── LICENSE /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /datasets/demo/videos/000006.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/videos/000006.mp4 -------------------------------------------------------------------------------- /datasets/demo/images/000000000632.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000000632.jpg -------------------------------------------------------------------------------- /datasets/demo/images/000000000785.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000000785.jpg -------------------------------------------------------------------------------- /datasets/demo/images/000000000872.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000000872.jpg -------------------------------------------------------------------------------- /datasets/demo/images/000000000885.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000000885.jpg -------------------------------------------------------------------------------- /datasets/demo/images/000000001000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000001000.jpg -------------------------------------------------------------------------------- /datasets/demo/images/000000001268.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000001268.jpg -------------------------------------------------------------------------------- /datasets/demo/images/000000001296.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000001296.jpg -------------------------------------------------------------------------------- /datasets/demo/images/000000001532.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjh0410/ODLab/HEAD/datasets/demo/images/000000001532.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.pth 3 | *.pkl 4 | *.onnx 5 | *.pyc 6 | *.zip 7 | weights 8 | __pycache__ 9 | det_results 10 | .vscode 11 | -------------------------------------------------------------------------------- /evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluator.coco_evaluator import COCOAPIEvaluator 2 | 3 | 4 | def build_evluator(args, cfg, device): 5 | evaluator = None 6 | # COCO Evaluator 7 | if args.dataset == 'coco': 8 | evaluator = COCOAPIEvaluator(args, cfg, device) 9 | 10 | return evaluator 11 | -------------------------------------------------------------------------------- /models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import build_resnet 2 | 3 | 4 | def build_backbone(cfg): 5 | print('==============================') 6 | print('Backbone: {}'.format(cfg.backbone)) 7 | # ResNet 8 | if "resnet" in cfg.backbone: 9 | return build_resnet(cfg) 10 | else: 11 | raise NotImplementedError("unknown backbone: {}".format(cfg.backbone)) 12 | 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # General Object Detection Laboratory 2 | The codebase of my research of General Object Detection 3 | ## Requirements 4 | - We recommend you to use Anaconda to create a conda environment: 5 | ```Shell 6 | conda create -n odlab python=3.8 7 | ``` 8 | 9 | - Then, activate the environment: 10 | ```Shell 11 | conda activate odlab 12 | ``` 13 | 14 | - Requirements: 15 | ```Shell 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | My torch environment: 20 | - PyTorch = 2.2.0+cu121 21 | - Torchvision = 0.17.0+cu121 22 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # ----------------------- Model Config ----------------------- 2 | from .fcos_config import build_fcos_config 3 | from .yolof_config import build_yolof_config 4 | 5 | def build_config(args): 6 | # FCOS 7 | if "fcos" in args.model: 8 | cfg = build_fcos_config(args) 9 | # YOLOF 10 | elif "yolof" in args.model: 11 | cfg = build_yolof_config(args) 12 | else: 13 | raise NotImplementedError('Unknown Model: {}'.format(args.model)) 14 | 15 | # Print model config 16 | cfg.print_config() 17 | 18 | return cfg 19 | -------------------------------------------------------------------------------- /models/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .yolof_head import YolofHead 2 | from .fcos_head import FcosHead, FcosRTHead 3 | 4 | 5 | # build head 6 | def build_head(cfg, in_dim, out_dim): 7 | print('==============================') 8 | print('Head: 
{}'.format(cfg.head)) 9 | 10 | if cfg.head == 'fcos_head': 11 | model = FcosHead(cfg, in_dim, out_dim) 12 | elif cfg.head == 'fcos_rt_head': 13 | model = FcosRTHead(cfg, in_dim, out_dim) 14 | elif cfg.head == 'yolof_head': 15 | model = YolofHead(cfg, in_dim, out_dim) 16 | 17 | return model 18 | -------------------------------------------------------------------------------- /models/neck/__init__.py: -------------------------------------------------------------------------------- 1 | from .dilated_encoder import DilatedEncoder 2 | from .fpn import BasicFPN 3 | from typing import List 4 | 5 | # build neck 6 | def build_neck(cfg, in_dim, out_dim): 7 | print('==============================') 8 | print('Neck: {}'.format(cfg.neck)) 9 | 10 | # ----------------------- Neck module ----------------------- 11 | if cfg.neck == 'dilated_encoder': 12 | model = DilatedEncoder(cfg, in_dim, out_dim) 13 | 14 | # ----------------------- FPN Neck ----------------------- 15 | elif cfg.neck == 'basic_fpn': 16 | assert isinstance(in_dim, List) 17 | model = BasicFPN(cfg, in_dim, out_dim) 18 | else: 19 | raise NotImplementedError("Unknown Neck: <{}>".format(cfg.neck)) 20 | 21 | return model 22 | -------------------------------------------------------------------------------- /models/detectors/yolof/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | from .criterion import SetCriterion 5 | from .yolof import YOLOF 6 | 7 | 8 | # build YOLOF 9 | def build_yolof(cfg, is_val=False): 10 | # -------------- Build YOLOF -------------- 11 | model = YOLOF(cfg = cfg, 12 | conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh, 13 | nms_thresh = cfg.train_nms_thresh if is_val else cfg.test_nms_thresh, 14 | topk = cfg.train_topk if is_val else cfg.test_topk, 15 | ) 16 | 17 | # -------------- Build Criterion -------------- 18 | criterion = None 19 | if is_val: 20 | # build criterion for training 21 | criterion = SetCriterion(cfg) 22 | 23 | return model, criterion -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # Args setting 2 | MODEL=$1 3 | DATASET=$2 4 | DATA_ROOT=$3 5 | BATCH_SIZE=$4 6 | WORLD_SIZE=$5 7 | MASTER_PORT=$6 8 | 9 | # -------------------------- Train Pipeline -------------------------- 10 | if [ $WORLD_SIZE == 1 ]; then 11 | python train.py \ 12 | --cuda \ 13 | --dataset ${DATASET} \ 14 | --root ${DATA_ROOT} \ 15 | --model ${MODEL} \ 16 | --batch_size ${BATCH_SIZE} 17 | elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then 18 | python -m torch.distributed.run --nproc_per_node=$WORLD_SIZE --master_port ${MASTER_PORT} \ 19 | train.py \ 20 | --cuda \ 21 | --distributed \ 22 | --dataset ${DATASET} \ 23 | --root ${DATA_ROOT} \ 24 | --model ${MODEL} \ 25 | --batch_size ${BATCH_SIZE} 26 | else 27 | echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \ 28 | multi-card training mode, which is currently unsupported." 29 | exit 1 30 | fi -------------------------------------------------------------------------------- /models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | import torch 3 | 4 | from .fcos.build import build_fcos, build_fcos_rt 5 | from .yolof.build import build_yolof 6 | 7 | 8 | def build_model(args, cfg, is_val=False): 9 | # ------------ build object detector ------------ 10 | ## RT-FCOS 11 | if 'fcos_rt' in args.model: 12 | model, criterion = build_fcos_rt(cfg, is_val) 13 | ## FCOS 14 | elif 'fcos' in args.model: 15 | model, criterion = build_fcos(cfg, is_val) 16 | ## YOLOF 17 | elif 'yolof' in args.model: 18 | model, criterion = build_yolof(cfg, is_val) 19 | else: 20 | raise NotImplementedError("Unknown detector: {}".format(args.model)) 21 | 22 | if is_val: 23 | # ------------ Keep training from the given weight ------------ 24 | if args.resume is not None: 25 | print('Load model from the checkpoint: ', args.resume) 26 | checkpoint = torch.load(args.resume, map_location='cpu') 27 | # checkpoint state dict 28 | checkpoint_state_dict = checkpoint.pop("model") 29 | model.load_state_dict(checkpoint_state_dict) 30 | 31 | return model, criterion 32 | 33 | else: 34 | return model 35 | -------------------------------------------------------------------------------- /models/detectors/fcos/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | from .criterion import SetCriterion 5 | from .fcos import FCOS, FcosRT 6 | 7 | 8 | # build FCOS 9 | def build_fcos(cfg, is_val=False): 10 | # -------------- Build FCOS -------------- 11 | model = FCOS(cfg = cfg, 12 | num_classes = cfg.num_classes, 13 | conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh, 14 | nms_thresh = cfg.train_nms_thresh if is_val else cfg.test_nms_thresh, 15 | topk = cfg.train_topk if is_val else cfg.test_topk, 16 | ) 17 | 18 | # -------------- Build Criterion -------------- 19 | criterion = None 20 | if is_val: 21 | # build criterion for training 22 | criterion = SetCriterion(cfg) 23 | 24 | return model, criterion 25 | 26 | # build FCOS-RT 27 | def build_fcos_rt(cfg, is_val=False): 28 | # -------------- Build FCOS-RT -------------- 29 | model = FcosRT(cfg = cfg, 30 | num_classes = cfg.num_classes, 31 | conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh, 32 | nms_thresh = cfg.train_nms_thresh if is_val else cfg.test_nms_thresh, 33 | topk = cfg.train_topk if is_val else cfg.test_topk, 34 | ) 35 | 36 | # -------------- Build Criterion -------------- 37 | criterion = None 38 | if is_val: 39 | # build criterion for training 40 | criterion = SetCriterion(cfg) 41 | 42 | return model, criterion -------------------------------------------------------------------------------- /models/basic/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FrozenBatchNorm2d(torch.nn.Module): 6 | """ 7 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 8 | Copy-paste from torchvision.misc.ops with added eps before rsqrt, 9 | without which any other models than torchvision.models.resnet[18,34,50,101] 10 | produce nans.
11 | """ 12 | 13 | def __init__(self, n): 14 | super(FrozenBatchNorm2d, self).__init__() 15 | self.register_buffer("weight", torch.ones(n)) 16 | self.register_buffer("bias", torch.zeros(n)) 17 | self.register_buffer("running_mean", torch.zeros(n)) 18 | self.register_buffer("running_var", torch.ones(n)) 19 | 20 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 21 | missing_keys, unexpected_keys, error_msgs): 22 | num_batches_tracked_key = prefix + 'num_batches_tracked' 23 | if num_batches_tracked_key in state_dict: 24 | del state_dict[num_batches_tracked_key] 25 | 26 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 27 | state_dict, prefix, local_metadata, strict, 28 | missing_keys, unexpected_keys, error_msgs) 29 | 30 | def forward(self, x): 31 | # move reshapes to the beginning 32 | # to make it fuser-friendly 33 | w = self.weight.reshape(1, -1, 1, 1) 34 | b = self.bias.reshape(1, -1, 1, 1) 35 | rv = self.running_var.reshape(1, -1, 1, 1) 36 | rm = self.running_mean.reshape(1, -1, 1, 1) 37 | eps = 1e-5 38 | scale = w * (rv + eps).rsqrt() 39 | bias = b - rm * scale 40 | return x * scale + bias 41 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | 4 | 5 | def build_optimizer(cfg, model, resume=None): 6 | print('==============================') 7 | print('Optimizer: {}'.format(cfg.optimizer)) 8 | print('--base_lr: {}'.format(cfg.base_lr)) 9 | print('--backbone_lr_ratio: {}'.format(cfg.bk_lr_ratio)) 10 | print('--momentum: {}'.format(cfg.momentum)) 11 | print('--weight_decay: {}'.format(cfg.weight_decay)) 12 | 13 | param_dicts = [ 14 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 15 | { 16 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 17 | "lr": cfg.base_lr * cfg.bk_lr_ratio, 18 | }, 19 | ] 20 | 21 | if cfg.optimizer == 'sgd': 22 | optimizer = optim.SGD( 23 | params=param_dicts, 24 | lr=cfg.base_lr, 25 | momentum=cfg.momentum, 26 | weight_decay=cfg.weight_decay 27 | ) 28 | 29 | elif cfg.optimizer == 'adamw': 30 | optimizer = optim.AdamW( 31 | params=param_dicts, 32 | lr=cfg.base_lr, 33 | weight_decay=cfg.weight_decay 34 | ) 35 | 36 | start_epoch = 0 37 | if resume is not None: 38 | print('Load optimzier from the checkpoint: ', resume) 39 | checkpoint = torch.load(resume) 40 | # checkpoint state dict 41 | checkpoint_state_dict = checkpoint.pop("optimizer") 42 | optimizer.load_state_dict(checkpoint_state_dict) 43 | start_epoch = checkpoint.pop("epoch") + 1 44 | 45 | return optimizer, start_epoch 46 | -------------------------------------------------------------------------------- /models/basic/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_activation(act_type=None): 5 | if act_type == 'relu': 6 | return nn.ReLU(inplace=True) 7 | elif act_type == 'lrelu': 8 | return nn.LeakyReLU(0.1, inplace=True) 9 | elif act_type == 'gelu': 10 | return nn.GELU() 11 | elif act_type == 'mish': 12 | return nn.Mish(inplace=True) 13 | elif act_type == 'silu': 14 | return nn.SiLU(inplace=True) 15 | elif act_type is None: 16 | return nn.Identity() 17 | 18 | class MLP(nn.Module): 19 | def __init__(self, in_dim, hidden_dim, out_dim, num_layers): 20 | super().__init__() 21 | self.num_layers = num_layers 22 | h = [hidden_dim] 
* (num_layers - 1) 23 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim])) 24 | 25 | def forward(self, x): 26 | for i, layer in enumerate(self.layers): 27 | x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 28 | return x 29 | 30 | class FFN(nn.Module): 31 | def __init__(self, d_model=256, ffn_dim=1024, dropout=0., act_type='relu', pre_norm=False): 32 | super().__init__() 33 | # ----------- Basic parameters ----------- 34 | self.pre_norm = pre_norm 35 | self.ffn_dim = ffn_dim 36 | # ----------- Network parameters ----------- 37 | self.linear1 = nn.Linear(d_model, self.ffn_dim) 38 | self.activation = get_activation(act_type) 39 | self.dropout2 = nn.Dropout(dropout) 40 | self.linear2 = nn.Linear(self.ffn_dim, d_model) 41 | self.dropout3 = nn.Dropout(dropout) 42 | self.norm = nn.LayerNorm(d_model) 43 | 44 | def forward(self, src): 45 | if self.pre_norm: 46 | src = self.norm(src) 47 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 48 | src = src + self.dropout3(src2) 49 | else: 50 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 51 | src = src + self.dropout3(src2) 52 | src = self.norm(src) 53 | 54 | return src 55 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # ------------------------- WarmUp LR Scheduler ------------------------- 5 | ## Warmup LR Scheduler 6 | class LinearWarmUpScheduler(object): 7 | def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667): 8 | self.base_lr = base_lr 9 | self.wp_iter = wp_iter 10 | self.warmup_factor = warmup_factor 11 | 12 | def set_lr(self, optimizer, lr): 13 | for param_group in optimizer.param_groups: 14 | init_lr = param_group['initial_lr'] 15 | ratio = init_lr / self.base_lr 16 | param_group['lr'] = lr * ratio 17 | 18 | def __call__(self, iter, optimizer): 19 | # warmup 20 | alpha = iter / self.wp_iter 21 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 22 | tmp_lr = self.base_lr * warmup_factor 23 | self.set_lr(optimizer, tmp_lr) 24 | 25 | ## Build WP LR Scheduler 26 | def build_wp_lr_scheduler(cfg): 27 | print('==============================') 28 | print('WarmUpScheduler: {}'.format(cfg.warmup)) 29 | print('--base_lr: {}'.format(cfg.base_lr)) 30 | print('--warmup_iters: {} ({})'.format(cfg.warmup_iters, cfg.warmup_iters * cfg.grad_accumulate)) 31 | print('--warmup_factor: {}'.format(cfg.warmup_factor)) 32 | 33 | if cfg.warmup == 'linear': 34 | wp_lr_scheduler = LinearWarmUpScheduler(cfg.base_lr, cfg.warmup_iters, cfg.warmup_factor) 35 | 36 | return wp_lr_scheduler 37 | 38 | 39 | # ------------------------- LR Scheduler ------------------------- 40 | def build_lr_scheduler(cfg, optimizer, resume=None): 41 | print('==============================') 42 | print('LR Scheduler: {}'.format(cfg.lr_scheduler)) 43 | 44 | if cfg.lr_scheduler == 'step': 45 | assert hasattr(cfg, 'lr_epoch') 46 | print('--lr_epoch: {}'.format(cfg.lr_epoch)) 47 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=cfg.lr_epoch) 48 | elif cfg.lr_scheduler == 'cosine': 49 | pass 50 | 51 | if resume is not None: 52 | print('Load lr scheduler from the checkpoint: ', resume) 53 | checkpoint = torch.load(resume) 54 | # checkpoint state dict 55 | checkpoint_state_dict = checkpoint.pop("lr_scheduler") 56 | lr_scheduler.load_state_dict(checkpoint_state_dict) 
57 | 58 | return lr_scheduler 59 | -------------------------------------------------------------------------------- /models/detectors/yolof/README.md: -------------------------------------------------------------------------------- 1 | # YOLOF: You Only Look One-level Feature 2 | 3 | - COCO 4 | 5 | | Model | scale | FPS<sup>FP32</sup><br>RTX 4060 | AP<sup>val</sup><br>0.5:0.95 | AP<sup>val</sup><br>0.5 | Weight | Logs | 6 | | ---------------- | ---------- | ---------------------- | ---------------------- | --------------- | ------ | ----- | 7 | | YOLOF_R18_C5_1x | 800,1333 | 54 | 32.8 | 51.4 | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/yolof_r18_c5_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/YOLOF-R18-C5-1x.txt) | 8 | | YOLOF_R50_C5_1x | 800,1333 | 21 | 37.7 | 57.2 | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/yolof_r50_c5_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/YOLOF-R50-C5-1x.txt) | 9 | 10 | 11 | ## Train YOLOF 12 | ### Single GPU 13 | Taking training **YOLOF_R18_C5_1x** on COCO as an example: 14 | ```Shell 15 | python train.py --cuda -d coco --root path/to/coco -m yolof_r18_c5_1x --batch_size 16 --eval_epoch 2 16 | ``` 17 | 18 | ### Multi GPU 19 | Taking training **YOLOF_R18_C5_1x** on COCO as an example: 20 | ```Shell 21 | python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root path/to/coco -m yolof_r18_c5_1x --batch_size 16 --eval_epoch 2 22 | ``` 23 | 24 | ## Test YOLOF 25 | Taking testing **YOLOF_R18_C5_1x** on COCO-val as an example: 26 | ```Shell 27 | python test.py --cuda -d coco --root path/to/coco -m yolof_r18_c5_1x --weight path/to/yolof_r18_c5_1x.pth -vt 0.4 --show 28 | ``` 29 | 30 | ## Evaluate YOLOF 31 | Taking evaluating **YOLOF_R18_C5_1x** on COCO-val as an example: 32 | ```Shell 33 | python train.py --cuda -d coco --root path/to/coco -m yolof_r18_c5_1x --resume path/to/yolof_r18_c5_1x.pth --eval_first 34 | ``` 35 | 36 | ## Demo 37 | ### Detect with Image 38 | ```Shell 39 | python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolof_r18_c5_1x --weight path/to/weight -vt 0.4 --show 40 | ``` 41 | 42 | ### Detect with Video 43 | ```Shell 44 | python demo.py --mode video --path_to_vid path/to/video --cuda -m yolof_r18_c5_1x --weight path/to/weight -vt 0.4 --show --gif 45 | ``` 46 | 47 | ### Detect with Camera 48 | ```Shell 49 | python demo.py --mode camera --cuda -m yolof_r18_c5_1x --weight path/to/weight -vt 0.4 --show --gif 50 | ``` -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | import torch.utils.data 3 | from torch.utils.data import DataLoader, DistributedSampler 4 | 5 | from .coco import build_coco, coco_indexs 6 | from .transforms import build_transform 7 | 8 | 9 | def build_dataset(args, cfg, transform=None, is_train=False): 10 | if args.dataset == 'coco': 11 | dataset = build_coco(args, transform, is_train) 12 | class_labels = dataset.coco_labels 13 | num_classes = 80 14 | cfg.class_labels = class_labels 15 | cfg.num_classes = num_classes 16 | 17 | return dataset 18 | 19 | def build_dataloader(args, dataset, batch_size, collate_fn, is_train=False): 20 | if args.distributed: 21 | sampler = DistributedSampler(dataset) if is_train else DistributedSampler(dataset, shuffle=False) 22 | else: 23 | sampler = torch.utils.data.RandomSampler(dataset) if is_train else torch.utils.data.SequentialSampler(dataset) 24 | 25 | if is_train: 26 | batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True) 27 | dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=args.num_workers) 28 | else: 29 | dataloader = DataLoader(dataset, batch_size, sampler=sampler, drop_last=False, collate_fn=collate_fn, num_workers=args.num_workers) 30 | 31 | return dataloader 32 | 33 | 34 | coco_labels = {1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'} 35 | -------------------------------------------------------------------------------- /models/basic/conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # --------------------- Basic modules --------------------- 5 | def get_conv2d(c1, c2, k, p, s, d, g, bias=False): 6 | conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias) 7 | 8 | return conv 9 | 10 | def get_activation(act_type=None): 11 | if act_type == 'relu': 12 | return nn.ReLU(inplace=True) 13 | elif act_type == 'lrelu': 14 | return nn.LeakyReLU(0.1, inplace=True) 15 | elif act_type == 'mish': 16 | return nn.Mish(inplace=True) 17 | elif act_type == 'silu': 18 | return nn.SiLU(inplace=True) 19 | elif act_type is None: 20 | return nn.Identity() 21 | else: 22 | raise NotImplementedError 23 | 24 | def get_norm(norm_type, dim): 25 | if norm_type == 'BN': 26 | return nn.BatchNorm2d(dim) 27 | elif norm_type == 'GN': 28 | return nn.GroupNorm(num_groups=32, num_channels=dim) 29 | elif norm_type is None: 30 | 
return nn.Identity() 31 | else: 32 | raise NotImplementedError 33 | 34 | class BasicConv(nn.Module): 35 | def __init__(self, 36 | in_dim, # in channels 37 | out_dim, # out channels 38 | kernel_size=1, # kernel size 39 | padding=0, # padding 40 | stride=1, # stride 41 | dilation=1, # dilation 42 | act_type :str = 'lrelu', # activation 43 | norm_type :str = 'BN', # normalization 44 | depthwise :bool = False 45 | ): 46 | super(BasicConv, self).__init__() 47 | self.depthwise = depthwise 48 | use_bias = False if norm_type is not None else True 49 | if not depthwise: 50 | self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias) 51 | self.norm = get_norm(norm_type, out_dim) 52 | else: 53 | self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias) 54 | self.norm1 = get_norm(norm_type, in_dim) 55 | self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1) 56 | self.norm2 = get_norm(norm_type, out_dim) 57 | self.act = get_activation(act_type) 58 | 59 | def forward(self, x): 60 | if not self.depthwise: 61 | return self.act(self.norm(self.conv(x))) 62 | else: 63 | # Depthwise conv 64 | x = self.act(self.norm1(self.conv1(x))) 65 | # Pointwise conv 66 | x = self.act(self.norm2(self.conv2(x))) 67 | return x 68 | -------------------------------------------------------------------------------- /models/detectors/fcos/README.md: -------------------------------------------------------------------------------- 1 | # FCOS: Fully Convolutional One-Stage Object Detector 2 | 3 | 4 | - COCO 5 | 6 | | Model | scale | FPS<sup>FP32</sup><br>RTX 4060 | AP<sup>val</sup><br>0.5:0.95 | AP<sup>val</sup><br>0.5 | Weight | Logs | 7 | | ---------------| ---------- | -------------------------- | ---------------------- | --------------- | ------ | ----- | 8 | | FCOS_R18_1x | 800,1333 | 24 | 34.0 | 52.2 | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_r18_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-R18-1x.txt) | 9 | | FCOS_R50_1x | 800,1333 | 9 | 39.0 | 58.3 | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_r50_1x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-R50-1x.txt) | 10 | | FCOS_RT_R18_3x | 512,736 | 56 | 35.8 | 53.3 | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_rt_r18_3x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-RT-R18-3x.txt) | 11 | | FCOS_RT_R50_3x | 512,736 | 34 | 40.7 | 59.3 | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/fcos_rt_r50_3x_coco.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/yolo_tutorial_ckpt/FCOS-RT-R50-3x.txt) | 12 | 13 | ## Train FCOS 14 | ### Single GPU 15 | Taking training **FCOS_R18_1x** on COCO as an example: 16 | ```Shell 17 | python train.py --cuda -d coco --root path/to/coco -m fcos_r18_1x --batch_size 16 --eval_epoch 2 18 | ``` 19 | 20 | ### Multi GPU 21 | Taking training **FCOS_R18_1x** on COCO as an example: 22 | ```Shell 23 | python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root path/to/coco -m fcos_r18_1x --batch_size 16 --eval_epoch 2 24 | ``` 25 | 26 | ## Test FCOS 27 | Taking testing **FCOS_R18_1x** on COCO-val as an example: 28 | ```Shell 29 | python test.py --cuda -d coco --root path/to/coco -m fcos_r18_1x --weight path/to/fcos_r18_1x.pth -vt 0.4 --show 30 | ``` 31 | 32 | ## Evaluate FCOS 33 | Taking evaluating **FCOS_R18_1x** on COCO-val as an example: 34 | ```Shell 35 | python train.py --cuda -d coco --root path/to/coco -m fcos_r18_1x --resume path/to/fcos_r18_1x.pth --eval_first 36 | ``` 37 | 38 | ## Demo 39 | ### Detect with Image 40 | ```Shell 41 | python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m fcos_r18_1x --weight path/to/weight -vt 0.4 --show 42 | ``` 43 | 44 | ### Detect with Video 45 | ```Shell 46 | python demo.py --mode video --path_to_vid path/to/video --cuda -m fcos_r18_1x --weight path/to/weight -vt 0.4 --show --gif 47 | ``` 48 | 49 | ### Detect with Camera 50 | ```Shell 51 | python demo.py --mode camera --cuda -m fcos_r18_1x --weight path/to/weight -vt 0.4 --show --gif 52 | ``` -------------------------------------------------------------------------------- /models/neck/fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from utils import weight_init 5 | 6 | 7 | # ------------------ Basic Feature Pyramid Network ------------------ 8 | class BasicFPN(nn.Module): 9 | def __init__(self, cfg, 10 | in_dims=[512, 1024, 2048], 11 | out_dim=256, 12 | ): 13 | super().__init__() 14 | # ------------------ Basic parameters ------------------- 15 | self.p6_feat = cfg.fpn_p6_feat 16 | self.p7_feat = cfg.fpn_p7_feat 17 | self.from_c5 = cfg.fpn_p6_from_c5 18 | 19 | # ------------------ Network parameters ------------------- 20 | ## lateral layers 21 | self.input_projs = nn.ModuleList() 22 |
self.smooth_layers = nn.ModuleList() 23 | for in_dim in in_dims[::-1]: 24 | self.input_projs.append(nn.Conv2d(in_dim, out_dim, kernel_size=1)) 25 | self.smooth_layers.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)) 26 | 27 | ## P6/P7 layers 28 | if self.p6_feat: 29 | if self.from_c5: 30 | self.p6_conv = nn.Conv2d(in_dims[-1], out_dim, kernel_size=3, stride=2, padding=1) 31 | else: # from p5 32 | self.p6_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1) 33 | if self.p7_feat: 34 | self.p7_conv = nn.Sequential( 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1) 37 | ) 38 | 39 | self._init_weight() 40 | 41 | def _init_weight(self): 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | weight_init.c2_xavier_fill(m) 45 | 46 | def forward(self, feats): 47 | """ 48 | feats: (List of Tensor) [C3, C4, C5], C_i ∈ R^(B x C_i x H_i x W_i) 49 | """ 50 | outputs = [] 51 | # [C3, C4, C5] -> [C5, C4, C3] 52 | feats = feats[::-1] 53 | top_level_feat = feats[0] 54 | prev_feat = self.input_projs[0](top_level_feat) 55 | outputs.append(self.smooth_layers[0](prev_feat)) 56 | 57 | for feat, input_proj, smooth_layer in zip(feats[1:], self.input_projs[1:], self.smooth_layers[1:]): 58 | feat = input_proj(feat) 59 | top_down_feat = F.interpolate(prev_feat, size=feat.shape[2:], mode='nearest') 60 | prev_feat = feat + top_down_feat 61 | outputs.insert(0, smooth_layer(prev_feat)) 62 | 63 | if self.p6_feat: 64 | if self.from_c5: 65 | p6_feat = self.p6_conv(feats[0]) 66 | else: 67 | p6_feat = self.p6_conv(outputs[-1]) 68 | # [P3, P4, P5] -> [P3, P4, P5, P6] 69 | outputs.append(p6_feat) 70 | 71 | if self.p7_feat: 72 | p7_feat = self.p7_conv(p6_feat) 73 | # [P3, P4, P5, P6] -> [P3, P4, P5, P6, P7] 74 | outputs.append(p7_feat) 75 | 76 | # [P3, P4, P5] or [P3, P4, P5, P6, P7] 77 | return outputs 78 | -------------------------------------------------------------------------------- /models/neck/dilated_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils import weight_init 3 | 4 | from ..basic.conv import BasicConv 5 | 6 | 7 | # BottleNeck 8 | class Bottleneck(nn.Module): 9 | def __init__(self, in_dim, dilation, expand_ratio, act_type='relu', norm_type='BN'): 10 | super(Bottleneck, self).__init__() 11 | # ------------------ Basic parameters ------------------- 12 | self.in_dim = in_dim 13 | self.dilation = dilation 14 | self.expand_ratio = expand_ratio 15 | inter_dim = round(in_dim * expand_ratio) 16 | # ------------------ Network parameters ------------------- 17 | self.branch = nn.Sequential( 18 | BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type), 19 | BasicConv(inter_dim, inter_dim, kernel_size=3, padding=dilation, dilation=dilation, act_type=act_type, norm_type=norm_type), 20 | BasicConv(inter_dim, in_dim, kernel_size=1, act_type=act_type, norm_type=norm_type) 21 | ) 22 | 23 | def forward(self, x): 24 | return x + self.branch(x) 25 | 26 | # Dilated Encoder 27 | class DilatedEncoder(nn.Module): 28 | def __init__(self, cfg, in_dim, out_dim): 29 | super(DilatedEncoder, self).__init__() 30 | # ------------------ Basic parameters ------------------- 31 | self.in_dim = in_dim 32 | self.out_dim = out_dim 33 | self.expand_ratio = cfg.neck_expand_ratio 34 | self.dilations = cfg.neck_dilations 35 | self.act_type = cfg.neck_act 36 | self.norm_type = cfg.neck_norm 37 | # ------------------ Network parameters ------------------- 
38 | ## proj layer 39 | self.projector = nn.Sequential( 40 | BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=self.norm_type), 41 | BasicConv(out_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=self.norm_type) 42 | ) 43 | ## encoder layers 44 | self.encoders = nn.Sequential( 45 | *[Bottleneck(out_dim, d, self.expand_ratio, self.act_type, self.norm_type) for d in self.dilations]) 46 | 47 | self._init_weight() 48 | 49 | def _init_weight(self): 50 | for m in self.projector: 51 | if isinstance(m, nn.Conv2d): 52 | weight_init.c2_xavier_fill(m) 53 | weight_init.c2_xavier_fill(m) 54 | if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): 55 | nn.init.constant_(m.weight, 1) 56 | nn.init.constant_(m.bias, 0) 57 | 58 | for m in self.encoders.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.normal_(m.weight, mean=0, std=0.01) 61 | if hasattr(m, 'bias') and m.bias is not None: 62 | nn.init.constant_(m.bias, 0) 63 | 64 | if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): 65 | nn.init.constant_(m.weight, 1) 66 | nn.init.constant_(m.bias, 0) 67 | 68 | def forward(self, x): 69 | x = self.projector(x) 70 | x = self.encoders(x) 71 | 72 | return x 73 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | 5 | # load transform 6 | from datasets import build_dataset, build_transform 7 | 8 | # load some utils 9 | from utils.misc import compute_flops, load_weight 10 | 11 | from config import build_config 12 | from models.detectors import build_model 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Benchmark') 16 | # Model 17 | parser.add_argument('-m', '--model', default='fcos_r18_1x', 18 | help='build detector') 19 | parser.add_argument('--fuse_conv_bn', action='store_true', default=False, 20 | help='fuse conv and bn') 21 | parser.add_argument('--weight', default=None, type=str, 22 | help='Trained state_dict file path to open') 23 | # Data root 24 | parser.add_argument('--root', default='/data/datasets/COCO', 25 | help='data root') 26 | # cuda 27 | parser.add_argument('--cuda', action='store_true', default=False, 28 | help='use cuda.') 29 | 30 | args = parser.parse_args() 31 | 32 | 33 | def test(cfg, model, device, dataset, transform): 34 | # Step-1: Compute FLOPs and Params 35 | compute_flops( 36 | model=model, 37 | min_size=cfg.test_min_size, 38 | max_size=cfg.test_max_size, 39 | device=device) 40 | 41 | # Step-2: Compute FPS 42 | num_images = 2002 43 | total_time = 0 44 | count = 0 45 | with torch.no_grad(): 46 | for index in range(num_images): 47 | if index % 500 == 0: 48 | print('Testing image {:d}/{:d}....'.format(index+1, num_images)) 49 | 50 | # Load an image 51 | image, _ = dataset[index] 52 | 53 | # Preprocess 54 | x, _ = transform(image) 55 | x = x.unsqueeze(0).to(device) 56 | 57 | # Star 58 | torch.cuda.synchronize() 59 | start_time = time.perf_counter() 60 | 61 | # Inference 62 | outputs = model(x) 63 | 64 | # End 65 | torch.cuda.synchronize() 66 | elapsed = time.perf_counter() - start_time 67 | 68 | if index > 1: 69 | total_time += elapsed 70 | count += 1 71 | 72 | print('- FPS :', 1.0 / (total_time / count)) 73 | 74 | 75 | 76 | if __name__ == '__main__': 77 | # get device 78 | if args.cuda: 79 | print('use cuda') 80 | device = torch.device("cuda") 81 | else: 82 | device = torch.device("cpu") 83 | 84 | # Dataset & Model Config 85 | cfg 
= build_config(args) 86 | 87 | # Transform 88 | transform = build_transform(cfg, is_train=False) 89 | 90 | # Dataset 91 | args.dataset = 'coco' 92 | dataset = build_dataset(args, cfg, is_train=False) 93 | 94 | # Model 95 | model = build_model(args, cfg, is_val=False) 96 | model = load_weight(model, args.weight, args.fuse_conv_bn) 97 | model.to(device).eval() 98 | 99 | print("================= DETECT =================") 100 | # Run 101 | test(cfg, model, device, dataset, transform) 102 | -------------------------------------------------------------------------------- /evaluator/coco_evaluator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import contextlib 4 | import torch 5 | from pycocotools.cocoeval import COCOeval 6 | 7 | from datasets import build_transform 8 | from datasets.coco import build_coco 9 | 10 | class COCOAPIEvaluator(): 11 | def __init__(self, args, cfg, device): 12 | # ----------------- Basic parameters ----------------- 13 | self.image_set = 'val2017' 14 | self.device = device 15 | # ----------------- Metrics ----------------- 16 | self.map = 0. 17 | self.ap50_95 = 0. 18 | self.ap50 = 0. 19 | # ----------------- Dataset ----------------- 20 | self.transform = build_transform(cfg, is_train=False) 21 | self.dataset = build_coco(args, self.transform, is_train=False) 22 | 23 | 24 | @torch.no_grad() 25 | def evaluate(self, model): 26 | ids = [] 27 | coco_results = [] 28 | model.eval() 29 | model.trainable = False 30 | 31 | # start testing 32 | for index, (image, target) in enumerate(self.dataset): 33 | if index % 500 == 0: 34 | print('[Eval: %d / %d]'%(index, len(self.dataset))) 35 | # image id 36 | id_ = int(target['image_id']) 37 | ids.append(id_) 38 | 39 | # inference 40 | image = image.unsqueeze(0).to(self.device) 41 | outputs = model(image) 42 | scores = outputs['scores'] 43 | labels = outputs['labels'] 44 | bboxes = outputs['bboxes'] 45 | 46 | # rescale bbox 47 | orig_h, orig_w = target["orig_size"].tolist() 48 | bboxes[..., 0::2] *= orig_w 49 | bboxes[..., 1::2] *= orig_h 50 | 51 | # reformat results 52 | for i, box in enumerate(bboxes): 53 | x1 = float(box[0]) 54 | y1 = float(box[1]) 55 | x2 = float(box[2]) 56 | y2 = float(box[3]) 57 | label = self.dataset.coco_indexs[int(labels[i])] 58 | 59 | # COCO json format 60 | bbox = [x1, y1, x2 - x1, y2 - y1] 61 | score = float(scores[i]) 62 | A = {"image_id": id_, 63 | "category_id": label, 64 | "bbox": bbox, 65 | "score": score} 66 | coco_results.append(A) 67 | 68 | model.train() 69 | model.trainable = True 70 | annType = ['segm', 'bbox', 'keypoints'] 71 | # Evaluate the Dt (detection) json comparing with the ground truth 72 | if len(coco_results) > 0: 73 | print('evaluating ......') 74 | cocoGt = self.dataset.coco 75 | # suppress pycocotools prints 76 | with open(os.devnull, 'w') as devnull: 77 | with contextlib.redirect_stdout(devnull): 78 | cocoDt = cocoGt.loadRes(coco_results) 79 | cocoEval = COCOeval(self.dataset.coco, cocoDt, annType[1]) 80 | cocoEval.params.imgIds = ids 81 | cocoEval.evaluate() 82 | cocoEval.accumulate() 83 | cocoEval.summarize() 84 | # update mAP 85 | ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1] 86 | print('ap50_95 : ', ap50_95) 87 | print('ap50 : ', ap50) 88 | self.map = ap50_95 89 | self.ap50_95 = ap50_95 90 | self.ap50 = ap50 91 | del coco_results 92 | else: 93 | print('No coco detection results !') 94 | 95 | -------------------------------------------------------------------------------- /engine.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Train and eval functions used in main.py 4 | """ 5 | import math 6 | import sys 7 | from typing import Iterable 8 | 9 | import torch 10 | from utils import distributed_utils 11 | from utils.misc import MetricLogger, SmoothedValue 12 | from utils.vis_tools import vis_data 13 | 14 | 15 | def train_one_epoch(cfg, 16 | model : torch.nn.Module, 17 | criterion : torch.nn.Module, 18 | data_loader : Iterable, 19 | optimizer : torch.optim.Optimizer, 20 | device : torch.device, 21 | epoch : int, 22 | vis_target : bool, 23 | warmup_lr_scheduler, 24 | debug :bool = False 25 | ): 26 | model.train() 27 | criterion.train() 28 | metric_logger = MetricLogger(delimiter=" ") 29 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 30 | header = 'Epoch: [{} / {}]'.format(epoch, cfg.max_epoch) 31 | epoch_size = len(data_loader) 32 | print_freq = 10 33 | 34 | optimizer.zero_grad() 35 | 36 | for iter_i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 37 | ni = iter_i + epoch * epoch_size 38 | # WarmUp 39 | if ni % cfg.grad_accumulate == 0: 40 | ni = ni // cfg.grad_accumulate 41 | if ni < cfg.warmup_iters: 42 | warmup_lr_scheduler(ni, optimizer) 43 | elif ni == cfg.warmup_iters: 44 | print('Warmup stage is over.') 45 | warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr) 46 | 47 | # To device 48 | images, masks = samples 49 | images = images.to(device) 50 | masks = masks.to(device) 51 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 52 | 53 | # Visualize train targets 54 | if vis_target: 55 | vis_data(images, targets, masks, cfg.class_labels, cfg.normalize_coords, cfg.box_format) 56 | 57 | # Inference 58 | outputs = model(images, masks) 59 | 60 | # Compute loss 61 | loss_dict = criterion(outputs, targets) 62 | loss_weight_dict = criterion.weight_dict 63 | losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict) 64 | loss_value = losses.item() 65 | losses /= cfg.grad_accumulate 66 | 67 | # Reduce losses over all GPUs for logging purposes 68 | loss_dict_reduced = distributed_utils.reduce_dict(loss_dict) 69 | 70 | # Check loss 71 | if not math.isfinite(loss_value): 72 | print("Loss is {}, stopping training".format(loss_value)) 73 | print(loss_dict_reduced) 74 | sys.exit(1) 75 | 76 | # Backward 77 | losses.backward() 78 | 79 | # Optimize 80 | if (iter_i + 1) % cfg.grad_accumulate == 0: 81 | if cfg.clip_max_norm > 0: 82 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm) 83 | optimizer.step() 84 | optimizer.zero_grad() 85 | 86 | metric_logger.update(loss=loss_value, **loss_dict_reduced) 87 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 88 | 89 | if debug: 90 | print("For debug mode, we only train the model with 1 iteration.") 91 | break 92 | 93 | # gather the stats from all processes 94 | metric_logger.synchronize_between_processes() 95 | print("Averaged stats:", metric_logger) 96 | 97 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 98 | -------------------------------------------------------------------------------- /utils/vis_tools.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import torch 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | # -------------------------- For 
Detection Task -------------------------- 9 | ## visualize the input data during the training stage 10 | def vis_data(images, targets, masks=None, class_labels=None, normalized_coord=False, box_format='xyxy'): 11 | """ 12 | images: (tensor) [B, 3, H, W] 13 | masks: (Tensor) [B, H, W] 14 | targets: (list) a list of targets 15 | """ 16 | batch_size = images.size(0) 17 | np.random.seed(0) 18 | class_colors = [(np.random.randint(255), 19 | np.random.randint(255), 20 | np.random.randint(255)) for _ in range(80)] 21 | pixel_means = [0.485, 0.456, 0.406] 22 | pixel_std = [0.229, 0.224, 0.225] 23 | 24 | for bi in range(batch_size): 25 | target = targets[bi] 26 | # to numpy 27 | image = images[bi].permute(1, 2, 0).cpu().numpy() 28 | not_mask = ~masks[bi] 29 | img_h = not_mask.cumsum(0, dtype=torch.int32)[-1, 0] 30 | img_w = not_mask.cumsum(1, dtype=torch.int32)[0, -1] 31 | # denormalize 32 | image = (image * pixel_std + pixel_means) * 255 33 | image = image[:, :, (2, 1, 0)].astype(np.uint8) 34 | image = image.copy() 35 | 36 | tgt_boxes = target['boxes'].float() 37 | tgt_labels = target['labels'].long() 38 | for box, label in zip(tgt_boxes, tgt_labels): 39 | box_ = box.clone() 40 | if normalized_coord: 41 | box_[..., [0, 2]] *= img_w 42 | box_[..., [1, 3]] *= img_h 43 | if box_format == 'xywh': 44 | box_x1y1 = box_[..., :2] - box_[..., 2:] * 0.5 45 | box_x2y2 = box_[..., :2] + box_[..., 2:] * 0.5 46 | box_ = torch.cat([box_x1y1, box_x2y2], dim=-1) 47 | x1, y1, x2, y2 = box_.long().cpu().numpy() 48 | 49 | cls_id = label.item() 50 | color = class_colors[cls_id] 51 | # draw box 52 | cv2.rectangle(image, (x1, y1), (x2, y2), color, 2) 53 | if class_labels is not None: 54 | class_name = class_labels[cls_id] 55 | # plot title bbox 56 | t_size = cv2.getTextSize(class_name, 0, fontScale=1, thickness=2)[0] 57 | cv2.rectangle(image, (x1, y1-t_size[1]), (int(x1 + t_size[0] * 0.4), y1), color, -1) 58 | # put the test on the title bbox 59 | cv2.putText(image, class_name, (x1, y1 - 5), 0, 0.4, (0, 0, 0), 1, lineType=cv2.LINE_AA) 60 | 61 | cv2.imshow('train target', image) 62 | cv2.waitKey(0) 63 | 64 | ## Draw bbox & label on the image 65 | def plot_bbox_labels(img, bbox, label=None, cls_color=None, text_scale=0.4): 66 | x1, y1, x2, y2 = bbox 67 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 68 | t_size = cv2.getTextSize(label, 0, fontScale=1, thickness=2)[0] 69 | # plot bbox 70 | cv2.rectangle(img, (x1, y1), (x2, y2), cls_color, 2) 71 | 72 | if label is not None: 73 | # plot title bbox 74 | cv2.rectangle(img, (x1, y1-t_size[1]), (int(x1 + t_size[0] * text_scale), y1), cls_color, -1) 75 | # put the test on the title bbox 76 | cv2.putText(img, label, (int(x1), int(y1 - 5)), 0, text_scale, (0, 0, 0), 1, lineType=cv2.LINE_AA) 77 | 78 | return img 79 | 80 | ## Visualize the detection results 81 | def visualize(image, bboxes, scores, labels, class_colors, class_names): 82 | ts = 0.4 83 | for i, bbox in enumerate(bboxes): 84 | cls_id = int(labels[i]) 85 | cls_color = class_colors[cls_id] 86 | 87 | mess = '%s: %.2f' % (class_names[cls_id], scores[i]) 88 | image = plot_bbox_labels(image, bbox, mess, cls_color, text_scale=ts) 89 | 90 | return image -------------------------------------------------------------------------------- /models/basic/attn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | # ----------------- BoxRPM Cross Attention Ops ----------------- 8 | class 
GlobalCrossAttention(nn.Module): 9 | def __init__( 10 | self, 11 | dim :int = 256, 12 | num_heads :int = 8, 13 | qkv_bias :bool = True, 14 | qk_scale :float = None, 15 | attn_drop :float = 0.0, 16 | proj_drop :float = 0.0, 17 | rpe_hidden_dim :int = 512, 18 | feature_stride :int = 16, 19 | ): 20 | super().__init__() 21 | # --------- Basic parameters --------- 22 | self.dim = dim 23 | self.num_heads = num_heads 24 | head_dim = dim // num_heads 25 | self.scale = qk_scale or head_dim ** -0.5 26 | self.feature_stride = feature_stride 27 | 28 | # --------- Network parameters --------- 29 | self.cpb_mlp1 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads) 30 | self.cpb_mlp2 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads) 31 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 32 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 33 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 34 | self.attn_drop = nn.Dropout(attn_drop) 35 | self.proj = nn.Linear(dim, dim) 36 | self.proj_drop = nn.Dropout(proj_drop) 37 | self.softmax = nn.Softmax(dim=-1) 38 | 39 | def build_cpb_mlp(self, in_dim, hidden_dim, out_dim): 40 | cpb_mlp = nn.Sequential(nn.Linear(in_dim, hidden_dim, bias=True), 41 | nn.ReLU(inplace=True), 42 | nn.Linear(hidden_dim, out_dim, bias=False)) 43 | return cpb_mlp 44 | 45 | def forward(self, 46 | query, 47 | reference_points, 48 | k_input_flatten, 49 | v_input_flatten, 50 | input_spatial_shapes, 51 | input_padding_mask=None, 52 | ): 53 | assert input_spatial_shapes.size(0) == 1, 'This is designed for single-scale decoder.' 54 | h, w = input_spatial_shapes[0] 55 | stride = self.feature_stride 56 | 57 | ref_pts = torch.cat([ 58 | reference_points[:, :, :, :2] - reference_points[:, :, :, 2:] / 2, 59 | reference_points[:, :, :, :2] + reference_points[:, :, :, 2:] / 2, 60 | ], dim=-1) # B, nQ, 1, 4 61 | 62 | pos_x = torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=w.device)[None, None, :, None] * stride # 1, 1, w, 1 63 | pos_y = torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=h.device)[None, None, :, None] * stride # 1, 1, h, 1 64 | 65 | delta_x = ref_pts[..., 0::2] - pos_x # B, nQ, w, 2 66 | delta_y = ref_pts[..., 1::2] - pos_y # B, nQ, h, 2 67 | 68 | rpe_x, rpe_y = self.cpb_mlp1(delta_x), self.cpb_mlp2(delta_y) # B, nQ, w/h, nheads 69 | rpe = (rpe_x[:, :, None] + rpe_y[:, :, :, None]).flatten(2, 3) # B, nQ, h, w, nheads -> B, nQ, h*w, nheads 70 | rpe = rpe.permute(0, 3, 1, 2) 71 | 72 | B_, N, C = k_input_flatten.shape 73 | k = self.k(k_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 74 | v = self.v(v_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 75 | B_, N, C = query.shape 76 | q = self.q(query).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 77 | q = q * self.scale 78 | 79 | attn = q @ k.transpose(-2, -1) 80 | attn += rpe 81 | if input_padding_mask is not None: 82 | attn += input_padding_mask[:, None, None] * -100 83 | 84 | fmin, fmax = torch.finfo(attn.dtype).min, torch.finfo(attn.dtype).max 85 | torch.clip_(attn, min=fmin, max=fmax) 86 | 87 | attn = self.softmax(attn) 88 | attn = self.attn_drop(attn) 89 | x = attn @ v 90 | 91 | x = x.transpose(1, 2).reshape(B_, N, C) 92 | x = self.proj(x) 93 | x = self.proj_drop(x) 94 | 95 | return x 96 | -------------------------------------------------------------------------------- /models/detectors/yolof/yolof.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 
| # --------------- Model components --------------- 5 | from ...backbone import build_backbone 6 | from ...neck import build_neck 7 | from ...head import build_head 8 | 9 | # --------------- External components --------------- 10 | from utils.misc import multiclass_nms 11 | 12 | 13 | # ------------------------ You Only Look One-level Feature ------------------------ 14 | class YOLOF(nn.Module): 15 | def __init__(self, 16 | cfg, 17 | num_classes :int = 80, 18 | conf_thresh :float = 0.05, 19 | nms_thresh :float = 0.6, 20 | topk :int = 1000, 21 | ca_nms :bool = False): 22 | super(YOLOF, self).__init__() 23 | # ---------------------- Basic Parameters ---------------------- 24 | self.cfg = cfg 25 | self.topk = topk 26 | self.num_classes = num_classes 27 | self.conf_thresh = conf_thresh 28 | self.nms_thresh = nms_thresh 29 | self.ca_nms = ca_nms 30 | 31 | # ---------------------- Network Parameters ---------------------- 32 | ## Backbone 33 | self.backbone, feat_dims = build_backbone(cfg) 34 | 35 | ## Neck 36 | self.neck = build_neck(cfg, feat_dims[-1], cfg.head_dim) 37 | 38 | ## Heads 39 | self.head = build_head(cfg, cfg.head_dim, cfg.head_dim) 40 | 41 | def post_process(self, cls_pred, box_pred): 42 | """ 43 | Input: 44 | cls_pred: (Tensor) [[H x W x KA, C] 45 | box_pred: (Tensor) [H x W x KA, 4] 46 | """ 47 | cls_pred = cls_pred[0] 48 | box_pred = box_pred[0] 49 | 50 | # (H x W x KA x C,) 51 | scores_i = cls_pred.sigmoid().flatten() 52 | 53 | # Keep top k top scoring indices only. 54 | num_topk = min(self.topk, box_pred.size(0)) 55 | 56 | # torch.sort is actually faster than .topk (at least on GPUs) 57 | predicted_prob, topk_idxs = scores_i.sort(descending=True) 58 | topk_scores = predicted_prob[:num_topk] 59 | topk_idxs = topk_idxs[:num_topk] 60 | 61 | # filter out the proposals with low confidence score 62 | keep_idxs = topk_scores > self.conf_thresh 63 | topk_idxs = topk_idxs[keep_idxs] 64 | 65 | # final scores 66 | scores = topk_scores[keep_idxs] 67 | # final labels 68 | labels = topk_idxs % self.num_classes 69 | # final bboxes 70 | anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor') 71 | bboxes = box_pred[anchor_idxs] 72 | 73 | # to cpu & numpy 74 | scores = scores.cpu().numpy() 75 | labels = labels.cpu().numpy() 76 | bboxes = bboxes.cpu().numpy() 77 | 78 | # nms 79 | scores, labels, bboxes = multiclass_nms( 80 | scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms) 81 | 82 | return bboxes, scores, labels 83 | 84 | def forward(self, src, src_mask=None): 85 | # ---------------- Backbone ---------------- 86 | pyramid_feats = self.backbone(src) 87 | 88 | # ---------------- Neck ---------------- 89 | feat = self.neck(pyramid_feats[-1]) 90 | 91 | # ---------------- Heads ---------------- 92 | outputs = self.head(feat, src_mask) 93 | 94 | if not self.training: 95 | # ---------------- PostProcess ---------------- 96 | cls_pred = outputs["pred_cls"] 97 | box_pred = outputs["pred_box"] 98 | bboxes, scores, labels = self.post_process(cls_pred, box_pred) 99 | # normalize bbox 100 | bboxes[..., 0::2] /= src.shape[-1] 101 | bboxes[..., 1::2] /= src.shape[-2] 102 | bboxes = bboxes.clip(0., 1.) 103 | 104 | outputs = { 105 | 'scores': scores, 106 | 'labels': labels, 107 | 'bboxes': bboxes 108 | } 109 | 110 | return outputs 111 | -------------------------------------------------------------------------------- /utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved 2 | import os 3 | import subprocess 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def reduce_dict(input_dict, average=True): 10 | """ 11 | Args: 12 | input_dict (dict): all the values will be reduced 13 | average (bool): whether to do average or sum 14 | Reduce the values in the dictionary from all processes so that all processes 15 | have the averaged results. Returns a dict with the same fields as 16 | input_dict, after reduction. 17 | """ 18 | world_size = get_world_size() 19 | if world_size < 2: 20 | return input_dict 21 | with torch.no_grad(): 22 | names = [] 23 | values = [] 24 | # sort the keys so that they are consistent across processes 25 | for k in sorted(input_dict.keys()): 26 | names.append(k) 27 | values.append(input_dict[k]) 28 | values = torch.stack(values, dim=0) 29 | dist.all_reduce(values) 30 | if average: 31 | values /= world_size 32 | reduced_dict = {k: v for k, v in zip(names, values)} 33 | return reduced_dict 34 | 35 | 36 | def get_sha(): 37 | cwd = os.path.dirname(os.path.abspath(__file__)) 38 | 39 | def _run(command): 40 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 41 | sha = 'N/A' 42 | diff = "clean" 43 | branch = 'N/A' 44 | try: 45 | sha = _run(['git', 'rev-parse', 'HEAD']) 46 | subprocess.check_output(['git', 'diff'], cwd=cwd) 47 | diff = _run(['git', 'diff-index', 'HEAD']) 48 | diff = "has uncommited changes" if diff else "clean" 49 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 50 | except Exception: 51 | pass 52 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 53 | return message 54 | 55 | 56 | def setup_for_distributed(is_master): 57 | """ 58 | This function disables printing when not in master process 59 | """ 60 | import builtins as __builtin__ 61 | builtin_print = __builtin__.print 62 | 63 | def print(*args, **kwargs): 64 | force = kwargs.pop('force', False) 65 | if is_master or force: 66 | builtin_print(*args, **kwargs) 67 | 68 | __builtin__.print = print 69 | 70 | 71 | def is_dist_avail_and_initialized(): 72 | if not dist.is_available(): 73 | return False 74 | if not dist.is_initialized(): 75 | return False 76 | return True 77 | 78 | 79 | def get_world_size(): 80 | if not is_dist_avail_and_initialized(): 81 | return 1 82 | return dist.get_world_size() 83 | 84 | 85 | def get_rank(): 86 | if not is_dist_avail_and_initialized(): 87 | return 0 88 | return dist.get_rank() 89 | 90 | 91 | def is_main_process(): 92 | return get_rank() == 0 93 | 94 | 95 | def save_on_master(*args, **kwargs): 96 | if is_main_process(): 97 | torch.save(*args, **kwargs) 98 | 99 | 100 | def init_distributed_mode(args): 101 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 102 | args.rank = int(os.environ["RANK"]) 103 | args.world_size = int(os.environ['WORLD_SIZE']) 104 | args.gpu = int(os.environ['LOCAL_RANK']) 105 | elif 'SLURM_PROCID' in os.environ: 106 | args.rank = int(os.environ['SLURM_PROCID']) 107 | args.gpu = args.rank % torch.cuda.device_count() 108 | else: 109 | print('Not using distributed mode') 110 | args.distributed = False 111 | return 112 | 113 | args.distributed = True 114 | 115 | torch.cuda.set_device(args.gpu) 116 | args.dist_backend = 'nccl' 117 | print('| distributed init (rank {}): {}'.format( 118 | args.rank, args.dist_url), flush=True) 119 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 120 | world_size=args.world_size, rank=args.rank) 121 | torch.distributed.barrier() 122 | 
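    # (Added note, not part of the original file) This entry point expects the
    # standard torch.distributed environment variables. An illustrative launch,
    # assuming train.py exposes the --dist_url flag referenced above:
    #
    #   torchrun --nproc_per_node=8 train.py --dist_url env://
    #
    # torchrun exports RANK / WORLD_SIZE / LOCAL_RANK, which the first branch
    # of this function reads.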
setup_for_distributed(args.rank == 0) 123 | -------------------------------------------------------------------------------- /utils/weight_init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | import math 4 | 5 | import torch.nn as nn 6 | 7 | 8 | def constant_init(module, val, bias=0): 9 | nn.init.constant_(module.weight, val) 10 | if hasattr(module, 'bias') and module.bias is not None: 11 | nn.init.constant_(module.bias, bias) 12 | 13 | 14 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 15 | assert distribution in ['uniform', 'normal'] 16 | if distribution == 'uniform': 17 | nn.init.xavier_uniform_(module.weight, gain=gain) 18 | else: 19 | nn.init.xavier_normal_(module.weight, gain=gain) 20 | if hasattr(module, 'bias') and module.bias is not None: 21 | nn.init.constant_(module.bias, bias) 22 | 23 | 24 | def normal_init(module, mean=0, std=1, bias=0): 25 | nn.init.normal_(module.weight, mean, std) 26 | if hasattr(module, 'bias') and module.bias is not None: 27 | nn.init.constant_(module.bias, bias) 28 | 29 | 30 | def uniform_init(module, a=0, b=1, bias=0): 31 | nn.init.uniform_(module.weight, a, b) 32 | if hasattr(module, 'bias') and module.bias is not None: 33 | nn.init.constant_(module.bias, bias) 34 | 35 | 36 | def kaiming_init(module, 37 | a=0, 38 | mode='fan_out', 39 | nonlinearity='relu', 40 | bias=0, 41 | distribution='normal'): 42 | assert distribution in ['uniform', 'normal'] 43 | if distribution == 'uniform': 44 | nn.init.kaiming_uniform_(module.weight, 45 | a=a, 46 | mode=mode, 47 | nonlinearity=nonlinearity) 48 | else: 49 | nn.init.kaiming_normal_(module.weight, 50 | a=a, 51 | mode=mode, 52 | nonlinearity=nonlinearity) 53 | if hasattr(module, 'bias') and module.bias is not None: 54 | nn.init.constant_(module.bias, bias) 55 | 56 | 57 | def caffe2_xavier_init(module, bias=0): 58 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch 59 | # Acknowledgment to FAIR's internal code 60 | kaiming_init(module, 61 | a=1, 62 | mode='fan_in', 63 | nonlinearity='leaky_relu', 64 | bias=bias, 65 | distribution='uniform') 66 | 67 | 68 | def c2_xavier_fill(module: nn.Module): 69 | """ 70 | Initialize `module.weight` using the "XavierFill" implemented in Caffe2. 71 | Also initializes `module.bias` to 0. 72 | 73 | Args: 74 | module (torch.nn.Module): module to initialize. 75 | """ 76 | # Caffe2 implementation of XavierFill in fact 77 | # corresponds to kaiming_uniform_ in PyTorch 78 | nn.init.kaiming_uniform_(module.weight, a=1) 79 | if module.bias is not None: 80 | nn.init.constant_(module.bias, 0) 81 | 82 | 83 | def c2_msra_fill(module: nn.Module): 84 | """ 85 | Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. 86 | Also initializes `module.bias` to 0. 87 | 88 | Args: 89 | module (torch.nn.Module): module to initialize. 
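    Example (illustrative addition, not in the original docstring):
        conv = nn.Conv2d(64, 64, kernel_size=3)
        c2_msra_fill(conv)   # fan-out Kaiming-normal weight, zero bias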
90 | """ 91 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") 92 | if module.bias is not None: 93 | nn.init.constant_(module.bias, 0) 94 | 95 | 96 | def init_weights(m: nn.Module, zero_init_final_gamma=False): 97 | """Performs ResNet-style weight initialization.""" 98 | if isinstance(m, nn.Conv2d): 99 | # Note that there is no bias due to BN 100 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | zero_init_gamma = ( 104 | hasattr(m, "final_bn") and m.final_bn and zero_init_final_gamma 105 | ) 106 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 107 | m.bias.data.zero_() 108 | elif isinstance(m, nn.Linear): 109 | m.weight.data.normal_(mean=0.0, std=0.01) 110 | m.bias.data.zero_() 111 | -------------------------------------------------------------------------------- /utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def get_ious(bboxes1, 10 | bboxes2, 11 | box_mode="xyxy", 12 | iou_type="iou"): 13 | """ 14 | Compute iou loss of type ['iou', 'giou', 'linear_iou'] 15 | 16 | Args: 17 | inputs (tensor): pred values 18 | targets (tensor): target values 19 | weight (tensor): loss weight 20 | box_mode (str): 'xyxy' or 'ltrb', 'ltrb' is currently supported. 21 | loss_type (str): 'giou' or 'iou' or 'linear_iou' 22 | reduction (str): reduction manner 23 | 24 | Returns: 25 | loss (tensor): computed iou loss. 26 | """ 27 | if box_mode == "ltrb": 28 | bboxes1 = torch.cat((-bboxes1[..., :2], bboxes1[..., 2:]), dim=-1) 29 | bboxes2 = torch.cat((-bboxes2[..., :2], bboxes2[..., 2:]), dim=-1) 30 | elif box_mode != "xyxy": 31 | raise NotImplementedError 32 | 33 | eps = torch.finfo(torch.float32).eps 34 | 35 | bboxes1_area = (bboxes1[..., 2] - bboxes1[..., 0]).clamp_(min=0) \ 36 | * (bboxes1[..., 3] - bboxes1[..., 1]).clamp_(min=0) 37 | bboxes2_area = (bboxes2[..., 2] - bboxes2[..., 0]).clamp_(min=0) \ 38 | * (bboxes2[..., 3] - bboxes2[..., 1]).clamp_(min=0) 39 | 40 | w_intersect = (torch.min(bboxes1[..., 2], bboxes2[..., 2]) 41 | - torch.max(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0) 42 | h_intersect = (torch.min(bboxes1[..., 3], bboxes2[..., 3]) 43 | - torch.max(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0) 44 | 45 | area_intersect = w_intersect * h_intersect 46 | area_union = bboxes2_area + bboxes1_area - area_intersect 47 | ious = area_intersect / area_union.clamp(min=eps) 48 | 49 | if iou_type == "iou": 50 | return ious 51 | elif iou_type == "giou": 52 | g_w_intersect = torch.max(bboxes1[..., 2], bboxes2[..., 2]) \ 53 | - torch.min(bboxes1[..., 0], bboxes2[..., 0]) 54 | g_h_intersect = torch.max(bboxes1[..., 3], bboxes2[..., 3]) \ 55 | - torch.min(bboxes1[..., 1], bboxes2[..., 1]) 56 | ac_uion = g_w_intersect * g_h_intersect 57 | gious = ious - (ac_uion - area_union) / ac_uion.clamp(min=eps) 58 | return gious 59 | else: 60 | raise NotImplementedError 61 | 62 | def box_cxcywh_to_xyxy(x): 63 | x_c, y_c, w, h = x.unbind(-1) 64 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 65 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 66 | return torch.stack(b, dim=-1) 67 | 68 | def box_xyxy_to_cxcywh(x): 69 | x0, y0, x1, y1 = x.unbind(-1) 70 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 71 | (x1 - x0), (y1 - y0)] 
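    # (Illustrative check, added) box_xyxy_to_cxcywh(torch.tensor([0., 0., 10., 20.]))
    # returns tensor([ 5., 10., 10., 20.]): center (5, 10), width 10, height 20.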
72 | return torch.stack(b, dim=-1) 73 | 74 | # modified from torchvision to also return the union 75 | def box_iou(boxes1, boxes2): 76 | area1 = box_area(boxes1) 77 | area2 = box_area(boxes2) 78 | 79 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 80 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 81 | 82 | wh = (rb - lt).clamp(min=0) # [N,M,2] 83 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 84 | 85 | union = area1[:, None] + area2 - inter 86 | union[union == 0.0] = 1.0 87 | 88 | iou = inter / union 89 | 90 | return iou, union 91 | 92 | def generalized_box_iou(boxes1, boxes2): 93 | """ 94 | Generalized IoU from https://giou.stanford.edu/ 95 | 96 | The boxes should be in [x0, y0, x1, y1] format 97 | 98 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 99 | and M = len(boxes2) 100 | """ 101 | # degenerate boxes gives inf / nan results 102 | # so do an early check 103 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 104 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 105 | iou, union = box_iou(boxes1, boxes2) 106 | 107 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 108 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 109 | 110 | wh = (rb - lt).clamp(min=0) # [N,M,2] 111 | area = wh[:, :, 0] * wh[:, :, 1] 112 | 113 | return iou - (area - union) / area 114 | -------------------------------------------------------------------------------- /models/detectors/yolof/matcher.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------- 2 | # Copyright (c) Megvii Inc. All rights reserved. 3 | # --------------------------------------------------------------------- 4 | 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from utils.box_ops import * 10 | 11 | 12 | class UniformMatcher(nn.Module): 13 | """ 14 | This code referenced to https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/uniform_matcher.py 15 | Uniform Matching between the anchors and gt boxes, which can achieve 16 | balance in positive anchors. 17 | 18 | Args: 19 | match_times(int): Number of positive anchors for each gt box. 
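    Illustrative behaviour (added note, inferred from forward() below): with
    match_times=4, every gt box is matched to the 4 predicted boxes and the 4
    anchor boxes with the smallest L1 distance in (cx, cy, w, h) space, so each
    gt collects up to 8 positive candidates before the IoU filtering applied in
    the criterion.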
20 | """ 21 | 22 | def __init__(self, match_times: int = 4): 23 | super().__init__() 24 | self.match_times = match_times 25 | 26 | @torch.no_grad() 27 | def forward(self, pred_boxes, anchor_boxes, targets): 28 | """ 29 | pred_boxes: (Tensor) -> [B, num_queries, 4] 30 | anchor_boxes: (Tensor) -> [num_queries, 4] 31 | targets: (Dict) -> dict{'boxes': [...], 'labels': [...]} 32 | """ 33 | 34 | bs, num_queries = pred_boxes.shape[:2] 35 | 36 | # We flatten to compute the cost matrices in a batch 37 | # [B, num_queries, 4] -> [M, 4] 38 | out_bbox = pred_boxes.flatten(0, 1) 39 | # [num_queries, 4] -> [1, num_queries, 4] -> [B, num_queries, 4] -> [M, 4] 40 | anchor_boxes = anchor_boxes[None].repeat(bs, 1, 1) 41 | anchor_boxes = anchor_boxes.flatten(0, 1) 42 | 43 | # Also concat the target boxes 44 | tgt_bbox = torch.cat([v['boxes'] for v in targets]) 45 | 46 | # Compute the L1 cost between boxes 47 | # Note that we use anchors and predict boxes both 48 | cost_bbox = torch.cdist(box_xyxy_to_cxcywh(out_bbox), 49 | box_xyxy_to_cxcywh(tgt_bbox), 50 | p=1) 51 | cost_bbox_anchors = torch.cdist(anchor_boxes, 52 | box_xyxy_to_cxcywh(tgt_bbox), 53 | p=1) 54 | 55 | # Final cost matrix: [B, M, N], M=num_queries, N=num_tgt 56 | C = cost_bbox 57 | C = C.view(bs, num_queries, -1).cpu() 58 | C1 = cost_bbox_anchors 59 | C1 = C1.view(bs, num_queries, -1).cpu() 60 | 61 | sizes = [len(v['boxes']) for v in targets] # the number of object instances in each image 62 | all_indices_list = [[] for _ in range(bs)] 63 | # positive indices when matching predict boxes and gt boxes 64 | # len(indices) = batch size 65 | # len(tupe) = topk 66 | indices = [ 67 | tuple( 68 | torch.topk( 69 | c[i], 70 | k=self.match_times, 71 | dim=0, 72 | largest=False)[1].numpy().tolist() 73 | ) 74 | for i, c in enumerate(C.split(sizes, -1)) 75 | ] 76 | # positive indices when matching anchor boxes and gt boxes 77 | indices1 = [ 78 | tuple( 79 | torch.topk( 80 | c[i], 81 | k=self.match_times, 82 | dim=0, 83 | largest=False)[1].numpy().tolist()) 84 | for i, c in enumerate(C1.split(sizes, -1))] 85 | 86 | # concat the indices according to image ids 87 | # img_id = batch_id 88 | for img_id, (idx, idx1) in enumerate(zip(indices, indices1)): 89 | img_idx_i = [ 90 | np.array(idx_ + idx1_) 91 | for (idx_, idx1_) in zip(idx, idx1) 92 | ] # 'i' is the index of queris 93 | img_idx_j = [ 94 | np.array(list(range(len(idx_))) + list(range(len(idx1_)))) 95 | for (idx_, idx1_) in zip(idx, idx1) 96 | ] # 'j' is the index of tgt 97 | all_indices_list[img_id] = [*zip(img_idx_i, img_idx_j)] 98 | 99 | # re-organize the positive indices 100 | all_indices = [] 101 | for img_id in range(bs): 102 | all_idx_i = [] 103 | all_idx_j = [] 104 | for idx_list in all_indices_list[img_id]: 105 | idx_i, idx_j = idx_list 106 | all_idx_i.append(idx_i) 107 | all_idx_j.append(idx_j) 108 | all_idx_i = np.hstack(all_idx_i) 109 | all_idx_j = np.hstack(all_idx_j) 110 | all_indices.append((all_idx_i, all_idx_j)) 111 | 112 | 113 | return [(torch.as_tensor(i, dtype=torch.int64), 114 | torch.as_tensor(j, dtype=torch.int64)) for i, j in all_indices] 115 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import os 4 | import time 5 | import numpy as np 6 | from copy import deepcopy 7 | import torch 8 | 9 | # load transform 10 | from datasets import build_dataset, build_transform 11 | 12 | # load some utils 13 | from utils.misc import 
load_weight, compute_flops 14 | from utils.vis_tools import visualize 15 | 16 | from config import build_config 17 | from models.detectors import build_model 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Object Detection Lab') 22 | # Basic 23 | parser.add_argument('--cuda', action='store_true', default=False, 24 | help='use cuda.') 25 | parser.add_argument('--show', action='store_true', default=False, 26 | help='show the visulization results.') 27 | parser.add_argument('--save', action='store_true', default=False, 28 | help='save the visulization results.') 29 | parser.add_argument('--save_folder', default='det_results/', type=str, 30 | help='Dir to save results') 31 | parser.add_argument('-vt', '--visual_threshold', default=0.3, type=float, 32 | help='Final confidence threshold') 33 | parser.add_argument('-ws', '--window_scale', default=1.0, type=float, 34 | help='resize window of cv2 for visualization.') 35 | # Model 36 | parser.add_argument('-m', '--model', default='yolof_r18_c5_1x', type=str, 37 | help='build detector') 38 | parser.add_argument('--weight', default=None, 39 | type=str, help='Trained state_dict file path to open') 40 | parser.add_argument('--fuse_conv_bn', action='store_true', default=False, 41 | help='fuse Conv & BN') 42 | # Dataset 43 | parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/', 44 | help='data root') 45 | parser.add_argument('-d', '--dataset', default='coco', 46 | help='coco, voc.') 47 | 48 | return parser.parse_args() 49 | 50 | @torch.no_grad() 51 | def test_det(args, model, device, dataset, transform, class_colors, class_names): 52 | num_images = len(dataset) 53 | save_path = os.path.join('det_results/', args.dataset, args.model) 54 | os.makedirs(save_path, exist_ok=True) 55 | 56 | for index, (image, _) in enumerate(dataset): 57 | print('Testing image {:d}/{:d}....'.format(index+1, num_images)) 58 | orig_h, orig_w = image.height, image.width 59 | 60 | # PreProcess 61 | x, _ = transform(image) 62 | x = x.unsqueeze(0).to(device) 63 | 64 | # Inference 65 | t0 = time.time() 66 | outputs = model(x) 67 | scores = outputs['scores'] 68 | labels = outputs['labels'] 69 | bboxes = outputs['bboxes'] 70 | print("Infer. 
time: {}".format(time.time() - t0, "s")) 71 | 72 | # Rescale bboxes 73 | bboxes[..., 0::2] *= orig_w 74 | bboxes[..., 1::2] *= orig_h 75 | 76 | # Convert PIL.Image to numpy 77 | image = np.array(image).astype(np.uint8) 78 | image = image[..., (2, 1, 0)].copy() 79 | 80 | # Visualize results 81 | img_processed = visualize(image=image, 82 | bboxes=bboxes, 83 | scores=scores, 84 | labels=labels, 85 | class_colors=class_colors, 86 | class_names=class_names) 87 | if args.show: 88 | h, w = img_processed.shape[:2] 89 | sw, sh = int(w*args.window_scale), int(h*args.window_scale) 90 | cv2.namedWindow('detection', 0) 91 | cv2.resizeWindow('detection', sw, sh) 92 | cv2.imshow('detection', img_processed) 93 | cv2.waitKey(0) 94 | 95 | if args.save: 96 | # save result 97 | cv2.imwrite(os.path.join(save_path, str(index).zfill(6) +'.jpg'), img_processed) 98 | 99 | 100 | if __name__ == '__main__': 101 | args = parse_args() 102 | # cuda 103 | if args.cuda: 104 | print('use cuda') 105 | device = torch.device("cuda") 106 | else: 107 | device = torch.device("cpu") 108 | 109 | # Dataset & Model Config 110 | cfg = build_config(args) 111 | 112 | # Transform 113 | transform = build_transform(cfg, is_train=False) 114 | 115 | # Dataset 116 | dataset = build_dataset(args, cfg, is_train=False) 117 | 118 | np.random.seed(0) 119 | class_colors = [(np.random.randint(255), 120 | np.random.randint(255), 121 | np.random.randint(255)) for _ in range(cfg.num_classes)] 122 | 123 | # Model 124 | model = build_model(args, cfg, is_val=False) 125 | model = load_weight(model, args.weight, args.fuse_conv_bn) 126 | model.to(device).eval() 127 | 128 | # Compute FLOPs and Params 129 | model_copy = deepcopy(model) 130 | model_copy.trainable = False 131 | model_copy.eval() 132 | compute_flops( 133 | model=model_copy, 134 | min_size=cfg.test_min_size, 135 | max_size=cfg.test_max_size, 136 | device=device) 137 | del model_copy 138 | 139 | print("================= DETECT =================") 140 | # run 141 | test_det(args = args, 142 | model = model, 143 | device = device, 144 | dataset = dataset, 145 | transform = transform, 146 | class_colors = class_colors, 147 | class_names = cfg.class_labels, 148 | ) 149 | -------------------------------------------------------------------------------- /models/detectors/yolof/criterion.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------- 2 | # Copyright (c) Megvii Inc. All rights reserved. 
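# (Added summary, not part of the original header) Sketch of the pipeline in
# SetCriterion.forward() below:
#   1) UniformMatcher picks the top-k closest predicted boxes and anchors for
#      every gt box;
#   2) positives whose anchor IoU with their gt falls below 'iou_thresh' are
#      dropped, and predictions whose best IoU exceeds 'ignore_thresh' are
#      marked ignore (-1);
#   3) sigmoid focal loss is applied to classification and a GIoU loss to the
#      matched boxes, both normalized by the number of foreground samples
#      (averaged across GPUs when distributed training is on).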
3 | # --------------------------------------------------------------------- 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from utils.box_ops import * 10 | from utils.misc import sigmoid_focal_loss 11 | from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized 12 | 13 | from .matcher import UniformMatcher 14 | 15 | 16 | class SetCriterion(nn.Module): 17 | """ 18 | This code referenced to https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/yolof.py 19 | """ 20 | def __init__(self, cfg): 21 | super().__init__() 22 | # ------------- Basic parameters ------------- 23 | self.cfg = cfg 24 | self.num_classes = cfg.num_classes 25 | # ------------- Focal loss ------------- 26 | self.alpha = cfg.focal_loss_alpha 27 | self.gamma = cfg.focal_loss_gamma 28 | # ------------- Loss weight ------------- 29 | self.weight_dict = {'loss_cls': cfg.loss_cls_weight, 30 | 'loss_reg': cfg.loss_reg_weight} 31 | # ------------- Matcher ------------- 32 | self.matcher_cfg = cfg.matcher_hpy 33 | self.matcher = UniformMatcher(self.matcher_cfg['topk_candidates']) 34 | 35 | def loss_labels(self, pred_cls, tgt_cls, num_boxes): 36 | """ 37 | pred_cls: (Tensor) [N, C] 38 | tgt_cls: (Tensor) [N, C] 39 | """ 40 | # cls loss: [V, C] 41 | loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma) 42 | 43 | return loss_cls.sum() / num_boxes 44 | 45 | def loss_bboxes(self, pred_box, tgt_box, num_boxes): 46 | """ 47 | pred_box: (Tensor) [N, 4] 48 | tgt_box: (Tensor) [N, 4] 49 | """ 50 | # giou 51 | pred_giou = generalized_box_iou(pred_box, tgt_box) # [N, M] 52 | # giou loss 53 | loss_reg = 1. - torch.diag(pred_giou) 54 | 55 | return loss_reg.sum() / num_boxes 56 | 57 | def forward(self, outputs, targets): 58 | """ 59 | outputs['pred_cls']: (Tensor) [B, M, C] 60 | outputs['pred_box']: (Tensor) [B, M, 4] 61 | targets: (List) [dict{'boxes': [...], 62 | 'labels': [...], 63 | 'orig_size': ...}, ...] 
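        Shapes (added note, inferred from the code below): with batch size B and
        M = H x W x num_anchors, outputs['pred_cls'] is flattened to [B*M, C],
        outputs['anchors'] is [M, 4] in (cx, cy, w, h) format, and
        outputs['mask'] is a flattened [B*M] padding mask.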
64 | """ 65 | # -------------------- Pre-process -------------------- 66 | pred_box = outputs['pred_box'] 67 | pred_cls = outputs['pred_cls'].reshape(-1, self.num_classes) 68 | anchor_boxes = outputs['anchors'] 69 | masks = ~outputs['mask'] 70 | device = pred_box.device 71 | B = len(targets) 72 | 73 | # -------------------- Label assignment -------------------- 74 | indices = self.matcher(pred_box, anchor_boxes, targets) 75 | 76 | # [M, 4] -> [1, M, 4] -> [B, M, 4] 77 | anchor_boxes = box_cxcywh_to_xyxy(anchor_boxes) 78 | anchor_boxes = anchor_boxes[None].repeat(B, 1, 1) 79 | 80 | ious = [] 81 | pos_ious = [] 82 | for i in range(B): 83 | src_idx, tgt_idx = indices[i] 84 | # iou between predbox and tgt box 85 | iou, _ = box_iou(pred_box[i, ...], (targets[i]['boxes']).clone()) 86 | if iou.numel() == 0: 87 | max_iou = iou.new_full((iou.size(0),), 0) 88 | else: 89 | max_iou = iou.max(dim=1)[0] 90 | # iou between anchorbox and tgt box 91 | a_iou, _ = box_iou(anchor_boxes[i], (targets[i]['boxes']).clone()) 92 | if a_iou.numel() == 0: 93 | pos_iou = a_iou.new_full((0,), 0) 94 | else: 95 | pos_iou = a_iou[src_idx, tgt_idx] 96 | ious.append(max_iou) 97 | pos_ious.append(pos_iou) 98 | 99 | ious = torch.cat(ious) 100 | ignore_idx = ious > self.matcher_cfg['ignore_thresh'] 101 | pos_ious = torch.cat(pos_ious) 102 | pos_ignore_idx = pos_ious < self.matcher_cfg['iou_thresh'] 103 | 104 | src_idx = torch.cat( 105 | [src + idx * anchor_boxes[0].shape[0] for idx, (src, _) in 106 | enumerate(indices)]) 107 | # [BM,] 108 | gt_cls = torch.full(pred_cls.shape[:1], 109 | self.num_classes, 110 | dtype=torch.int64, 111 | device=device) 112 | gt_cls[ignore_idx] = -1 113 | tgt_cls_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)]) 114 | tgt_cls_o[pos_ignore_idx] = -1 115 | 116 | gt_cls[src_idx] = tgt_cls_o.to(device) 117 | 118 | foreground_idxs = (gt_cls >= 0) & (gt_cls != self.num_classes) 119 | num_foreground = foreground_idxs.sum() 120 | 121 | if is_dist_avail_and_initialized(): 122 | torch.distributed.all_reduce(num_foreground) 123 | num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item() 124 | 125 | # -------------------- Classification loss -------------------- 126 | gt_cls_target = torch.zeros_like(pred_cls) 127 | gt_cls_target[foreground_idxs, gt_cls[foreground_idxs]] = 1 128 | valid_idxs = (gt_cls >= 0) & masks 129 | loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_cls_target[valid_idxs], num_foreground) 130 | 131 | # -------------------- Regression loss -------------------- 132 | tgt_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(device) 133 | tgt_boxes = tgt_boxes[~pos_ignore_idx] 134 | matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]] 135 | loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_foreground) 136 | 137 | loss_dict = dict( 138 | loss_cls = loss_labels, 139 | loss_reg = loss_bboxes, 140 | ) 141 | 142 | return loss_dict 143 | 144 | 145 | if __name__ == "__main__": 146 | pass 147 | -------------------------------------------------------------------------------- /models/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 
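FrozenBatchNorm2d below keeps the running statistics and affine parameters
fixed and evaluates y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
as a single per-channel scale and shift (added note).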
4 | """ 5 | 6 | import torch 7 | import torchvision 8 | from torch import nn 9 | from torchvision.models._utils import IntermediateLayerGetter 10 | from torchvision.models.resnet import (ResNet18_Weights, 11 | ResNet34_Weights, 12 | ResNet50_Weights, 13 | ResNet101_Weights) 14 | 15 | model_urls = { 16 | # IN1K-Cls pretrained weights 17 | 'resnet18': ResNet18_Weights, 18 | 'resnet34': ResNet34_Weights, 19 | 'resnet50': ResNet50_Weights, 20 | 'resnet101': ResNet101_Weights, 21 | } 22 | 23 | 24 | # Frozen BatchNormazlizarion 25 | class FrozenBatchNorm2d(torch.nn.Module): 26 | """ 27 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 28 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 29 | without which any other models than torchvision.models.resnet[18,34,50,101] 30 | produce nans. 31 | """ 32 | 33 | def __init__(self, n): 34 | super(FrozenBatchNorm2d, self).__init__() 35 | self.register_buffer("weight", torch.ones(n)) 36 | self.register_buffer("bias", torch.zeros(n)) 37 | self.register_buffer("running_mean", torch.zeros(n)) 38 | self.register_buffer("running_var", torch.ones(n)) 39 | 40 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 41 | missing_keys, unexpected_keys, error_msgs): 42 | num_batches_tracked_key = prefix + 'num_batches_tracked' 43 | if num_batches_tracked_key in state_dict: 44 | del state_dict[num_batches_tracked_key] 45 | 46 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 47 | state_dict, prefix, local_metadata, strict, 48 | missing_keys, unexpected_keys, error_msgs) 49 | 50 | def forward(self, x): 51 | # move reshapes to the beginning 52 | # to make it fuser-friendly 53 | w = self.weight.reshape(1, -1, 1, 1) 54 | b = self.bias.reshape(1, -1, 1, 1) 55 | rv = self.running_var.reshape(1, -1, 1, 1) 56 | rm = self.running_mean.reshape(1, -1, 1, 1) 57 | eps = 1e-5 58 | scale = w * (rv + eps).rsqrt() 59 | bias = b - rm * scale 60 | return x * scale + bias 61 | 62 | # -------------------- ResNet series -------------------- 63 | class ResNet(nn.Module): 64 | """Standard ResNet backbone.""" 65 | def __init__(self, 66 | name :str = "resnet50", 67 | res5_dilation :bool = False, 68 | norm_type :str = "BN", 69 | freeze_at :int = 0, 70 | use_pretrained :bool = False): 71 | super().__init__() 72 | # Pretrained 73 | if use_pretrained: 74 | pretrained_weights = model_urls[name].IMAGENET1K_V1 75 | else: 76 | pretrained_weights = None 77 | 78 | # Norm layer 79 | print("- Norm layer of backbone: {}".format(norm_type)) 80 | if norm_type == 'BN': 81 | norm_layer = nn.BatchNorm2d 82 | elif norm_type == 'FrozeBN': 83 | norm_layer = FrozenBatchNorm2d 84 | else: 85 | raise NotImplementedError("Unknown norm type: {}".format(norm_type)) 86 | 87 | # Backbone 88 | backbone = getattr(torchvision.models, name)( 89 | replace_stride_with_dilation=[False, False, res5_dilation], 90 | norm_layer=norm_layer, weights=pretrained_weights) 91 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 92 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 93 | self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048] 94 | 95 | # Freeze 96 | print("- Freeze at {}".format(freeze_at)) 97 | if freeze_at >= 0: 98 | for name, parameter in backbone.named_parameters(): 99 | if freeze_at == 0: # Only freeze stem layer 100 | if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 101 | parameter.requires_grad_(False) 102 | elif freeze_at 
== 1: # Freeze stem layer + layer1 103 | if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 104 | parameter.requires_grad_(False) 105 | elif freeze_at == 2: # Freeze stem layer + layer1 + layer2 106 | if 'layer3' not in name and 'layer4' not in name: 107 | parameter.requires_grad_(False) 108 | elif freeze_at == 3: # Freeze stem layer + layer1 + layer2 + layer3 109 | if 'layer4' not in name: 110 | parameter.requires_grad_(False) 111 | else: # Freeze all resnet's layers 112 | parameter.requires_grad_(False) 113 | 114 | def forward(self, x): 115 | xs = self.body(x) 116 | fmp_list = [] 117 | for name, fmp in xs.items(): 118 | fmp_list.append(fmp) 119 | 120 | return fmp_list 121 | 122 | 123 | # build backbone 124 | def build_resnet(cfg): 125 | # ResNet series 126 | backbone = ResNet( 127 | name = cfg.backbone, 128 | res5_dilation = cfg.res5_dilation, 129 | norm_type = cfg.bk_norm, 130 | use_pretrained = cfg.use_pretrained, 131 | freeze_at = cfg.freeze_at) 132 | 133 | return backbone, backbone.feat_dims 134 | 135 | 136 | if __name__ == '__main__': 137 | 138 | class FcosBaseConfig(object): 139 | def __init__(self): 140 | self.backbone = "resnet18" 141 | self.bk_norm = "FrozeBN" 142 | self.res5_dilation = False 143 | self.use_pretrained = True 144 | self.freeze_at = 0 145 | 146 | cfg = FcosBaseConfig() 147 | model, feat_dim = build_resnet(cfg) 148 | print(feat_dim) 149 | 150 | x = torch.randn(2, 3, 320, 320) 151 | output = model(x) 152 | for y in output: 153 | print(y.size()) 154 | -------------------------------------------------------------------------------- /models/head/yolof_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..basic.conv import BasicConv 6 | 7 | 8 | class YolofHead(nn.Module): 9 | def __init__(self, cfg, in_dim, out_dim,): 10 | super().__init__() 11 | self.fmp_size = None 12 | self.ctr_clamp = cfg.center_clamp 13 | self.DEFAULT_EXP_CLAMP = math.log(1e8) 14 | self.DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16) 15 | # ------------------ Basic parameters ------------------- 16 | self.cfg = cfg 17 | self.in_dim = in_dim 18 | self.stride = cfg.out_stride 19 | self.num_classes = cfg.num_classes 20 | self.num_cls_head = cfg.num_cls_head 21 | self.num_reg_head = cfg.num_reg_head 22 | self.act_type = cfg.head_act 23 | self.norm_type = cfg.head_norm 24 | # Anchor config 25 | self.anchor_size = torch.as_tensor(cfg.anchor_size) 26 | self.num_anchors = len(cfg.anchor_size) 27 | 28 | # ------------------ Network parameters ------------------- 29 | ## cls head 30 | cls_heads = [] 31 | self.cls_head_dim = out_dim 32 | for i in range(self.num_cls_head): 33 | if i == 0: 34 | cls_heads.append( 35 | BasicConv(in_dim, self.cls_head_dim, 36 | kernel_size=3, padding=1, stride=1, 37 | act_type=self.act_type, norm_type=self.norm_type) 38 | ) 39 | else: 40 | cls_heads.append( 41 | BasicConv(self.cls_head_dim, self.cls_head_dim, 42 | kernel_size=3, padding=1, stride=1, 43 | act_type=self.act_type, norm_type=self.norm_type) 44 | ) 45 | ## reg head 46 | reg_heads = [] 47 | self.reg_head_dim = out_dim 48 | for i in range(self.num_reg_head): 49 | if i == 0: 50 | reg_heads.append( 51 | BasicConv(in_dim, self.reg_head_dim, 52 | kernel_size=3, padding=1, stride=1, 53 | act_type=self.act_type, norm_type=self.norm_type) 54 | ) 55 | else: 56 | reg_heads.append( 57 | BasicConv(self.reg_head_dim, self.reg_head_dim, 58 | kernel_size=3, padding=1, stride=1, 59 | act_type=self.act_type, 
norm_type=self.norm_type) 60 | ) 61 | self.cls_heads = nn.Sequential(*cls_heads) 62 | self.reg_heads = nn.Sequential(*reg_heads) 63 | 64 | # pred layer 65 | self.obj_pred = nn.Conv2d(self.reg_head_dim, 1 * self.num_anchors, kernel_size=3, padding=1) 66 | self.cls_pred = nn.Conv2d(self.cls_head_dim, self.num_classes * self.num_anchors, kernel_size=3, padding=1) 67 | self.reg_pred = nn.Conv2d(self.reg_head_dim, 4 * self.num_anchors, kernel_size=3, padding=1) 68 | 69 | # init bias 70 | self._init_pred_layers() 71 | 72 | def _init_pred_layers(self): 73 | # init cls pred 74 | nn.init.normal_(self.cls_pred.weight, mean=0, std=0.01) 75 | init_prob = 0.01 76 | bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob)) 77 | nn.init.constant_(self.cls_pred.bias, bias_value) 78 | # init reg pred 79 | nn.init.normal_(self.reg_pred.weight, mean=0, std=0.01) 80 | nn.init.constant_(self.reg_pred.bias, 0.0) 81 | # init obj pred 82 | nn.init.normal_(self.obj_pred.weight, mean=0, std=0.01) 83 | nn.init.constant_(self.obj_pred.bias, 0.0) 84 | 85 | def get_anchors(self, fmp_size): 86 | """fmp_size: list -> [H, W] \n 87 | stride: int -> output stride 88 | """ 89 | # check anchor boxes 90 | if self.fmp_size is not None and self.fmp_size == fmp_size: 91 | return self.anchor_boxes 92 | else: 93 | # generate grid cells 94 | fmp_h, fmp_w = fmp_size 95 | anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)]) 96 | # [H, W, 2] -> [HW, 2] 97 | anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5 98 | # [HW, 2] -> [HW, 1, 2] -> [HW, KA, 2] 99 | anchor_xy = anchor_xy[:, None, :].repeat(1, self.num_anchors, 1) 100 | anchor_xy *= self.stride 101 | 102 | # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] 103 | anchor_wh = self.anchor_size[None, :, :].repeat(fmp_h*fmp_w, 1, 1) 104 | 105 | # [HW, KA, 4] -> [M, 4] 106 | anchor_boxes = torch.cat([anchor_xy, anchor_wh], dim=-1) 107 | anchor_boxes = anchor_boxes.view(-1, 4) 108 | 109 | self.anchor_boxes = anchor_boxes 110 | self.fmp_size = fmp_size 111 | 112 | return anchor_boxes 113 | 114 | def decode_boxes(self, anchor_boxes, pred_reg): 115 | """ 116 | anchor_boxes: (List[tensor]) [1, M, 4] 117 | pred_reg: (List[tensor]) [B, M, 4] 118 | """ 119 | # x = x_anchor + dx * w_anchor 120 | # y = y_anchor + dy * h_anchor 121 | pred_ctr_offset = pred_reg[..., :2] * anchor_boxes[..., 2:] 122 | pred_ctr_offset = torch.clamp(pred_ctr_offset, min=-self.ctr_clamp, max=self.ctr_clamp) 123 | pred_ctr_xy = anchor_boxes[..., :2] + pred_ctr_offset 124 | 125 | # w = w_anchor * exp(tw) 126 | # h = h_anchor * exp(th) 127 | pred_dwdh = pred_reg[..., 2:] 128 | pred_dwdh = torch.clamp(pred_dwdh, max=self.DEFAULT_SCALE_CLAMP) 129 | pred_wh = anchor_boxes[..., 2:] * pred_dwdh.exp() 130 | 131 | # convert [x, y, w, h] -> [x1, y1, x2, y2] 132 | pred_x1y1 = pred_ctr_xy - 0.5 * pred_wh 133 | pred_x2y2 = pred_ctr_xy + 0.5 * pred_wh 134 | pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1) 135 | 136 | return pred_box 137 | 138 | def forward(self, x, mask=None): 139 | # ------------------- Decoupled head ------------------- 140 | cls_feats = self.cls_heads(x) 141 | reg_feats = self.reg_heads(x) 142 | 143 | # ------------------- Generate anchor box ------------------- 144 | fmp_size = cls_feats.shape[2:] 145 | anchor_boxes = self.get_anchors(fmp_size) # [M, 4] 146 | anchor_boxes = anchor_boxes.to(cls_feats.device) 147 | 148 | # ------------------- Predict ------------------- 149 | obj_pred = self.obj_pred(reg_feats) 150 | cls_pred = self.cls_pred(cls_feats) 151 | 
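        # (Added note) The "implicit objectness" normalization a few lines below
        # combines the two logits so that, up to the overflow clamps,
        #   sigmoid(normalized_cls_pred) = sigmoid(cls_pred) * sigmoid(obj_pred),
        # i.e. the class confidence is scaled by the objectness score without a
        # separate sigmoid product at inference time.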
reg_pred = self.reg_pred(reg_feats) 152 | 153 | # ------------------- Precoess preds ------------------- 154 | ## implicit objectness 155 | B, _, H, W = obj_pred.size() 156 | obj_pred = obj_pred.view(B, -1, 1, H, W) 157 | cls_pred = cls_pred.view(B, -1, self.num_classes, H, W) 158 | normalized_cls_pred = cls_pred + obj_pred - torch.log( 159 | 1. + 160 | torch.clamp(cls_pred, max=self.DEFAULT_EXP_CLAMP).exp() + 161 | torch.clamp(obj_pred, max=self.DEFAULT_EXP_CLAMP).exp()) 162 | # [B, KA, C, H, W] -> [B, H, W, KA, C] -> [B, M, C], M = HxWxKA 163 | normalized_cls_pred = normalized_cls_pred.permute(0, 3, 4, 1, 2).contiguous() 164 | normalized_cls_pred = normalized_cls_pred.view(B, -1, self.num_classes) 165 | # [B, KA*4, H, W] -> [B, KA, 4, H, W] -> [B, H, W, KA, 4] -> [B, M, 4] 166 | reg_pred = reg_pred.view(B, -1, 4, H, W).permute(0, 3, 4, 1, 2).contiguous() 167 | reg_pred = reg_pred.view(B, -1, 4) 168 | ## Decode bbox 169 | box_pred = self.decode_boxes(anchor_boxes[None], reg_pred) # [B, M, 4] 170 | ## adjust mask 171 | if mask is not None: 172 | # [B, H, W] 173 | mask = torch.nn.functional.interpolate(mask[None].float(), size=fmp_size).bool()[0] 174 | # [B, H, W] -> [B, HW] 175 | mask = mask.flatten(1) 176 | # [B, HW] -> [B, HW, KA] -> [BM,], M= HW x KA 177 | mask = mask[..., None].repeat(1, 1, self.num_anchors).flatten() 178 | 179 | outputs = {"pred_cls": normalized_cls_pred, 180 | "pred_reg": reg_pred, 181 | "pred_box": box_pred, 182 | "anchors": anchor_boxes, 183 | "mask": mask} 184 | 185 | return outputs 186 | -------------------------------------------------------------------------------- /config/fcos_config.py: -------------------------------------------------------------------------------- 1 | # Fully Convolutional One-Stage object detector 2 | 3 | def build_fcos_config(args): 4 | if args.model == 'fcos_r18_1x': 5 | return Fcos_R18_1x_Config() 6 | elif args.model == 'fcos_r50_1x': 7 | return Fcos_R50_1x_Config() 8 | elif args.model == 'fcos_r101_1x': 9 | return Fcos_R101_1x_Config() 10 | 11 | elif args.model == 'fcos_r18_3x': 12 | return Fcos_R18_3x_Config() 13 | elif args.model == 'fcos_r50_3x': 14 | return Fcos_R50_3x_Config() 15 | elif args.model == 'fcos_r101_3x': 16 | return Fcos_R101_3x_Config() 17 | 18 | elif args.model == 'fcos_rt_r18_3x': 19 | return FcosRT_R18_3x_Config() 20 | elif args.model == 'fcos_rt_r50_3x': 21 | return FcosRT_R50_3x_Config() 22 | 23 | else: 24 | raise NotImplementedError("No config for model: {}".format(args.model)) 25 | 26 | 27 | # -------------- Base configuration -------------- 28 | class FcosBaseConfig(object): 29 | def __init__(self): 30 | # --------- Backbone --------- 31 | self.backbone = "resnet50" 32 | self.bk_norm = "FrozeBN" 33 | self.res5_dilation = False 34 | self.use_pretrained = True 35 | self.freeze_at = 1 36 | self.max_stride = 128 37 | self.out_stride = [8, 16, 32, 64, 128] 38 | 39 | # --------- Neck --------- 40 | self.neck = 'basic_fpn' 41 | self.fpn_p6_feat = True 42 | self.fpn_p7_feat = True 43 | self.fpn_p6_from_c5 = False 44 | 45 | # --------- Head --------- 46 | self.head = 'fcos_head' 47 | self.head_dim = 256 48 | self.num_cls_head = 4 49 | self.num_reg_head = 4 50 | self.head_act = 'relu' 51 | self.head_norm = 'GN' 52 | 53 | # --------- Post-process --------- 54 | self.train_topk = 1000 55 | self.train_conf_thresh = 0.05 56 | self.train_nms_thresh = 0.6 57 | self.test_topk = 100 58 | self.test_conf_thresh = 0.5 59 | self.test_nms_thresh = 0.45 60 | self.nms_class_agnostic = True 61 | 62 | # --------- Label Assignment 
--------- 63 | self.matcher = 'fcos_matcher' 64 | self.matcher_hpy = {'center_sampling_radius': 1.5, 65 | 'object_sizes_of_interest': [[-1, 64], 66 | [64, 128], 67 | [128, 256], 68 | [256, 512], 69 | [512, float('inf')]] 70 | } 71 | 72 | # --------- Loss weight --------- 73 | self.focal_loss_alpha = 0.25 74 | self.focal_loss_gamma = 2.0 75 | self.loss_cls_weight = 1.0 76 | self.loss_reg_weight = 1.0 77 | self.loss_ctn_weight = 1.0 78 | 79 | # --------- Optimizer --------- 80 | self.optimizer = 'sgd' 81 | self.batch_size_base = 16 82 | self.per_image_lr = 0.01 / 16 83 | self.bk_lr_ratio = 1.0 / 1.0 84 | self.momentum = 0.9 85 | self.weight_decay = 1e-4 86 | self.clip_max_norm = -1.0 87 | 88 | # --------- LR Scheduler --------- 89 | self.lr_scheduler = 'step' 90 | self.warmup = 'linear' 91 | self.warmup_iters = 500 92 | self.warmup_factor = 0.00066667 93 | 94 | # --------- Train epoch --------- 95 | self.max_epoch = 12 # 1x 96 | self.lr_epoch = [8, 11] # 1x 97 | self.eval_epoch = 2 98 | 99 | # --------- Data process --------- 100 | ## input size 101 | self.train_min_size = [800] # short edge of image 102 | self.train_max_size = 1333 103 | self.test_min_size = [800] 104 | self.test_max_size = 1333 105 | ## Pixel mean & std 106 | self.pixel_mean = [0.485, 0.456, 0.406] 107 | self.pixel_std = [0.229, 0.224, 0.225] 108 | ## Transforms 109 | self.box_format = 'xyxy' 110 | self.normalize_coords = False 111 | self.detr_style = False 112 | self.trans_config = [ 113 | {'name': 'RandomHFlip'}, 114 | {'name': 'RandomResize'}, 115 | ] 116 | 117 | def print_config(self): 118 | config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')} 119 | for k, v in config_dict.items(): 120 | print("{} : {}".format(k, v)) 121 | 122 | # -------------- 1x scheduler -------------- 123 | class Fcos_R18_1x_Config(FcosBaseConfig): 124 | def __init__(self) -> None: 125 | super().__init__() 126 | ## Backbone 127 | self.backbone = "resnet18" 128 | 129 | class Fcos_R50_1x_Config(Fcos_R18_1x_Config): 130 | def __init__(self) -> None: 131 | super().__init__() 132 | ## Backbone 133 | self.backbone = "resnet50" 134 | 135 | class Fcos_R101_1x_Config(Fcos_R18_1x_Config): 136 | def __init__(self) -> None: 137 | super().__init__() 138 | ## Backbone 139 | self.backbone = "resnet101" 140 | 141 | # -------------- 3x scheduler -------------- 142 | class Fcos_R18_3x_Config(Fcos_R18_1x_Config): 143 | def __init__(self) -> None: 144 | super().__init__() 145 | # --------- Train epoch --------- 146 | self.max_epoch = 36 # 3x 147 | self.lr_epoch = [24, 33] # 3x 148 | self.eval_epoch = 2 149 | 150 | # --------- Data process --------- 151 | ## input size 152 | self.train_min_size = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] # short edge of image 153 | self.train_max_size = 1333 154 | self.test_min_size = [800] 155 | self.test_max_size = 1333 156 | 157 | class Fcos_R50_3x_Config(Fcos_R18_3x_Config): 158 | def __init__(self) -> None: 159 | super().__init__() 160 | ## Backbone 161 | self.backbone = "resnet50" 162 | 163 | class Fcos_R101_3x_Config(Fcos_R18_3x_Config): 164 | def __init__(self) -> None: 165 | super().__init__() 166 | ## Backbone 167 | self.backbone = "resnet101" 168 | 169 | # -------------- RT-FCOS series -------------- 170 | class FcosRT_R18_3x_Config(FcosBaseConfig): 171 | def __init__(self) -> None: 172 | super().__init__() 173 | ## Backbone 174 | self.backbone = "resnet18" 175 | self.max_stride = 32 176 | self.out_stride = [8, 16, 32] 177 | 178 | # --------- Neck --------- 179 | 
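        # (Added note) Relative to the base config above, this real-time variant
        # keeps only strides {8, 16, 32} (no P6/P7), switches label assignment to
        # SimOTA and tests at the smaller 512/736 resolution set further below.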
self.neck = 'basic_fpn' 180 | self.fpn_p6_feat = False 181 | self.fpn_p7_feat = False 182 | self.fpn_p6_from_c5 = False 183 | 184 | # --------- Head --------- 185 | self.head = 'fcos_rt_head' 186 | self.head_dim = 256 187 | self.num_cls_head = 4 188 | self.num_reg_head = 4 189 | self.head_act = 'relu' 190 | self.head_norm = 'GN' 191 | 192 | # --------- Post-process --------- 193 | self.train_topk = 1000 194 | self.train_conf_thresh = 0.05 195 | self.train_nms_thresh = 0.6 196 | self.test_topk = 100 197 | self.test_conf_thresh = 0.4 198 | self.test_nms_thresh = 0.45 199 | self.nms_class_agnostic = True 200 | 201 | # --------- Label Assignment --------- 202 | self.matcher = 'simota' 203 | self.matcher_hpy = {'soft_center_radius': 3.0, 204 | 'topk_candidates': 13} 205 | 206 | # --------- Loss weight --------- 207 | self.focal_loss_alpha = 0.25 208 | self.focal_loss_gamma = 2.0 209 | self.loss_cls_weight = 1.0 210 | self.loss_reg_weight = 2.0 211 | 212 | # --------- Train epoch --------- 213 | self.max_epoch = 36 # 3x 214 | self.lr_epoch = [24, 33] # 3x 215 | 216 | # --------- Data process --------- 217 | ## input size 218 | self.train_min_size = [256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608] # short edge of image 219 | self.train_max_size = 900 220 | self.test_min_size = [512] 221 | self.test_max_size = 736 222 | ## Pixel mean & std 223 | self.pixel_mean = [0.485, 0.456, 0.406] 224 | self.pixel_std = [0.229, 0.224, 0.225] 225 | ## Transforms 226 | self.box_format = 'xyxy' 227 | self.normalize_coords = False 228 | self.detr_style = False 229 | self.trans_config = [ 230 | {'name': 'RandomHFlip'}, 231 | {'name': 'RandomResize'}, 232 | ] 233 | 234 | class FcosRT_R50_3x_Config(FcosRT_R18_3x_Config): 235 | def __init__(self) -> None: 236 | super().__init__() 237 | ## Backbone 238 | self.backbone = "resnet50" 239 | -------------------------------------------------------------------------------- /config/yolof_config.py: -------------------------------------------------------------------------------- 1 | # You Only Look One-level Feature 2 | 3 | def build_yolof_config(args): 4 | if args.model == 'yolof_r18_c5_1x': 5 | return Yolof_R18_C5_1x_Config() 6 | elif args.model == 'yolof_r50_c5_1x': 7 | return Yolof_R50_C5_1x_Config() 8 | elif args.model == 'yolof_r101_c5_1x': 9 | return Yolof_R101_C5_1x_Config() 10 | 11 | elif args.model == 'yolof_r18_c5_3x': 12 | return Yolof_R18_C5_3x_Config() 13 | elif args.model == 'yolof_r50_c5_3x': 14 | return Yolof_R50_C5_3x_Config() 15 | elif args.model == 'yolof_r101_c5_3x': 16 | return Yolof_R101_C5_3x_Config() 17 | 18 | elif args.model == 'yolof_r50_dc5_1x': 19 | return Yolof_R50_DC5_1x_Config() 20 | elif args.model == 'yolof_r101_dc5_1x': 21 | return Yolof_R101_DC5_1x_Config() 22 | elif args.model == 'yolof_r50_dc5_3x': 23 | return Yolof_R50_DC5_3x_Config() 24 | elif args.model == 'yolof_r101_dc5_3x': 25 | return Yolof_R101_DC5_3x_Config() 26 | 27 | else: 28 | raise NotImplementedError("No config for model: {}".format(args.model)) 29 | 30 | 31 | # -------------- Base configuration -------------- 32 | class YolofBaseConfig(object): 33 | def __init__(self): 34 | # --------- Backbone --------- 35 | self.backbone = "resnet50" 36 | self.bk_norm = "FrozeBN" 37 | self.res5_dilation = False 38 | self.use_pretrained = True 39 | self.freeze_at = 1 40 | self.max_stride = 32 41 | self.out_stride = 32 42 | 43 | # --------- Neck --------- 44 | self.neck = 'dilated_encoder' 45 | self.neck_dilations = [2, 4, 6, 8] 46 | self.neck_expand_ratio = 0.25 47 | 
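        # (Added note; assumption about dilated_encoder.py, which is not shown
        # here) neck_expand_ratio presumably sets the bottleneck width of each
        # dilated residual block (0.25 x 512 = 128 channels), while the dilations
        # [2, 4, 6, 8] enlarge the receptive field over the single C5 feature map.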
self.neck_act = 'relu' 48 | self.neck_norm = 'BN' 49 | 50 | # --------- Head --------- 51 | self.head = 'yolof_head' 52 | self.head_dim = 512 53 | self.num_cls_head = 2 54 | self.num_reg_head = 4 55 | self.head_act = 'relu' 56 | self.head_norm = 'BN' 57 | self.center_clamp = 32 58 | self.anchor_size = [[32, 32], 59 | [64, 64], 60 | [128, 128], 61 | [256, 256], 62 | [512, 512]] 63 | 64 | # --------- Post-process --------- 65 | self.train_topk = 1000 66 | self.train_conf_thresh = 0.05 67 | self.train_nms_thresh = 0.6 68 | self.test_topk = 300 69 | self.test_conf_thresh = 0.3 70 | self.test_nms_thresh = 0.45 71 | self.nms_class_agnostic = True 72 | 73 | # --------- Label Assignment --------- 74 | self.matcher = 'yolof_matcher' 75 | self.matcher_hpy = {'topk_candidates': 4, 76 | 'iou_thresh': 0.15, 77 | 'ignore_thresh': 0.7, 78 | } 79 | 80 | # --------- Loss weight --------- 81 | self.focal_loss_alpha = 0.25 82 | self.focal_loss_gamma = 2.0 83 | self.loss_cls_weight = 1.0 84 | self.loss_reg_weight = 1.0 85 | 86 | # --------- Optimizer --------- 87 | self.optimizer = 'sgd' 88 | self.batch_size_base = 64 89 | self.per_image_lr = 0.12 / 64 90 | self.bk_lr_ratio = 1.0 / 3.0 91 | self.momentum = 0.9 92 | self.weight_decay = 1e-4 93 | self.clip_max_norm = 10.0 94 | 95 | 96 | # --------- LR Scheduler --------- 97 | self.lr_scheduler = 'step' 98 | self.warmup = 'linear' 99 | self.warmup_iters = 1500 100 | self.warmup_factor = 0.00066667 101 | 102 | # --------- Train epoch --------- 103 | self.max_epoch = 12 # 1x 104 | self.lr_epoch = [8, 11] # 1x 105 | self.eval_epoch = 2 106 | 107 | # --------- Data process --------- 108 | ## input size 109 | self.train_min_size = [800] # short edge of image 110 | self.train_max_size = 1333 111 | self.test_min_size = [800] 112 | self.test_max_size = 1333 113 | ## Pixel mean & std 114 | self.pixel_mean = [0.485, 0.456, 0.406] 115 | self.pixel_std = [0.229, 0.224, 0.225] 116 | ## Transforms 117 | self.box_format = 'xyxy' 118 | self.normalize_coords = False 119 | self.detr_style = False 120 | self.trans_config = [ 121 | {'name': 'RandomHFlip'}, 122 | {'name': 'RandomResize'}, 123 | {'name': 'RandomShift', 'max_shift': 32}, 124 | ] 125 | 126 | def print_config(self): 127 | config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')} 128 | for k, v in config_dict.items(): 129 | print("{} : {}".format(k, v)) 130 | 131 | # -------------- 1x scheduler -------------- 132 | class Yolof_R18_C5_1x_Config(YolofBaseConfig): 133 | def __init__(self) -> None: 134 | super().__init__() 135 | ## Backbone 136 | # --------- Backbone --------- 137 | self.backbone = "resnet18" 138 | 139 | class Yolof_R50_C5_1x_Config(Yolof_R18_C5_1x_Config): 140 | def __init__(self) -> None: 141 | super().__init__() 142 | ## Backbone 143 | # --------- Backbone --------- 144 | self.backbone = "resnet50" 145 | 146 | class Yolof_R101_C5_1x_Config(Yolof_R18_C5_1x_Config): 147 | def __init__(self) -> None: 148 | super().__init__() 149 | ## Backbone 150 | # --------- Backbone --------- 151 | self.backbone = "resnet101" 152 | 153 | class Yolof_R50_DC5_1x_Config(YolofBaseConfig): 154 | def __init__(self) -> None: 155 | super().__init__() 156 | ## Backbone 157 | # --------- Backbone --------- 158 | self.backbone = "resnet50" 159 | self.res5_dilation = True 160 | self.use_pretrained = True 161 | self.max_stride = 16 162 | self.out_stride = 16 163 | 164 | # --------- Neck --------- 165 | self.neck = 'dilated_encoder' 166 | self.neck_dilations = [4, 8, 12, 16] 167 | 
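        # (Added note) "DC5" = dilated C5: res5_dilation=True above halves the
        # output stride from 32 to 16, so this variant doubles the encoder
        # dilations and adds an extra 16x16 anchor below to cover the finer grid.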
self.neck_expand_ratio = 0.25 168 | self.neck_act = 'relu' 169 | self.neck_norm = 'BN' 170 | 171 | # --------- Head --------- 172 | self.anchor_size = [[16, 16], 173 | [32, 32], 174 | [64, 64], 175 | [128, 128], 176 | [256, 256], 177 | [512, 512]], 178 | 179 | # --------- Label Assignment --------- 180 | self.matcher = 'yolof_matcher' 181 | self.matcher_hpy = {'topk_candidates': 8, 182 | 'iou_thresh': 0.1, 183 | 'ignore_thresh': 0.7, 184 | } 185 | 186 | class Yolof_R101_DC5_1x_Config(Yolof_R50_DC5_1x_Config): 187 | def __init__(self) -> None: 188 | super().__init__() 189 | ## Backbone 190 | # --------- Backbone --------- 191 | self.backbone = "resnet101" 192 | 193 | # -------------- 3x scheduler -------------- 194 | class Yolof_R18_C5_3x_Config(Yolof_R18_C5_1x_Config): 195 | def __init__(self) -> None: 196 | super().__init__() 197 | # --------- Train epoch --------- 198 | self.max_epoch = 36 # 3x 199 | self.lr_epoch = [24, 33] # 3x 200 | self.eval_epoch = 2 201 | 202 | # --------- Data process --------- 203 | ## input size 204 | self.train_min_size = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] # short edge of image 205 | self.train_max_size = 1333 206 | self.test_min_size = [800] 207 | self.test_max_size = 1333 208 | 209 | class Yolof_R50_C5_3x_Config(Yolof_R18_C5_3x_Config): 210 | def __init__(self) -> None: 211 | super().__init__() 212 | # --------- Backbone --------- 213 | self.backbone = "resnet50" 214 | 215 | class Yolof_R101_C5_3x_Config(Yolof_R18_C5_3x_Config): 216 | def __init__(self) -> None: 217 | super().__init__() 218 | # --------- Backbone --------- 219 | self.backbone = "resnet101" 220 | 221 | class Yolof_R50_DC5_3x_Config(Yolof_R50_DC5_1x_Config): 222 | def __init__(self) -> None: 223 | super().__init__() 224 | # --------- Train epoch --------- 225 | self.max_epoch = 36 # 3x 226 | self.lr_epoch = [24, 33] # 3x 227 | self.eval_epoch = 2 228 | 229 | # --------- Data process --------- 230 | ## input size 231 | self.train_min_size = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] # short edge of image 232 | self.train_max_size = 1333 233 | self.test_min_size = [800] 234 | self.test_max_size = 1333 235 | 236 | class Yolof_R101_DC5_3x_Config(Yolof_R50_DC5_3x_Config): 237 | def __init__(self) -> None: 238 | super().__init__() 239 | # --------- Backbone --------- 240 | self.backbone = "resnet101" 241 | -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | COCO dataset which returns image_id for evaluation. 
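The 80 detection classes are re-indexed to a contiguous 0..79 range via
coco_indexs, which maps back to the original, non-contiguous COCO category
ids (added note).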
4 | 5 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 6 | """ 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.utils.data 11 | import torchvision 12 | 13 | try: 14 | from .transforms import build_transform 15 | except: 16 | from transforms import build_transform 17 | 18 | 19 | # coco_labels = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') 20 | coco_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') 21 | coco_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] 22 | 23 | 24 | class CocoDetection(torchvision.datasets.CocoDetection): 25 | def __init__(self, img_folder, ann_file, transforms): 26 | super(CocoDetection, self).__init__(img_folder, ann_file) 27 | self.coco_labels = coco_labels # 80 coco labels for detection task 28 | self.coco_indexs = coco_indexs # all original coco label index 29 | self._transforms = transforms 30 | 31 | def prepare(self, image, target): 32 | w, h = image.size 33 | # load an image 34 | image_id = target["image_id"] 35 | image_id = torch.tensor([image_id]) 36 | 37 | # load an annotation 38 | anno = target["annotations"] 39 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 40 | 41 | # bbox target 42 | boxes = [obj["bbox"] for obj in anno] 43 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 44 | boxes[:, 2:] += boxes[:, :2] 45 | boxes[:, 0::2].clamp_(min=0, max=w) 46 | boxes[:, 1::2].clamp_(min=0, max=h) 47 | 48 | # class target 49 | classes = 
[self.coco_indexs.index(obj["category_id"]) for obj in anno] 50 | classes = torch.tensor(classes, dtype=torch.int64) 51 | 52 | # filter invalid bbox 53 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 54 | boxes = boxes[keep] 55 | classes = classes[keep] 56 | 57 | target = {} 58 | target["boxes"] = boxes 59 | target["labels"] = classes 60 | target["image_id"] = image_id 61 | 62 | # for conversion to coco api 63 | area = torch.tensor([obj["area"] for obj in anno]) 64 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 65 | target["area"] = area[keep] 66 | target["iscrowd"] = iscrowd[keep] 67 | 68 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 69 | target["size"] = torch.as_tensor([int(h), int(w)]) 70 | 71 | return image, target 72 | 73 | def __getitem__(self, idx): 74 | img, target = super(CocoDetection, self).__getitem__(idx) 75 | image_id = self.ids[idx] 76 | target = {'image_id': image_id, 'annotations': target} 77 | img, target = self.prepare(img, target) 78 | if self._transforms is not None: 79 | img, target = self._transforms(img, target) 80 | 81 | return img, target 82 | 83 | 84 | def build_coco(args, transform=None, is_train=False): 85 | root = Path(args.root) 86 | assert root.exists(), f'provided COCO path {root} does not exist' 87 | PATHS = { 88 | "train": (root / "train2017", root / "annotations" / 'instances_train2017.json'), 89 | "val": (root / "val2017", root / "annotations" / 'instances_val2017.json'), 90 | } 91 | 92 | image_set = "train" if is_train else "val" 93 | img_folder, ann_file = PATHS[image_set] 94 | 95 | # build dataset 96 | dataset = CocoDetection(img_folder, ann_file, transform) 97 | 98 | return dataset 99 | 100 | 101 | if __name__ == "__main__": 102 | import argparse 103 | import cv2 104 | import numpy as np 105 | 106 | parser = argparse.ArgumentParser(description='COCO-Dataset') 107 | 108 | # opt 109 | parser.add_argument('--root', default='D:/python_work/dataset/COCO/', 110 | help='data root') 111 | parser.add_argument('--is_train', action="store_true", default=False, 112 | help='use the train split.') 113 | args = parser.parse_args() 114 | 115 | np.random.seed(0) 116 | class_colors = [(np.random.randint(255), 117 | np.random.randint(255), 118 | np.random.randint(255)) for _ in range(80)] 119 | 120 | # config 121 | class BaseConfig(object): 122 | def __init__(self): 123 | # --------- Data process --------- 124 | ## input size 125 | self.train_min_size = [512] # short edge of image 126 | self.train_max_size = 736 127 | self.test_min_size = [512] 128 | self.test_max_size = 736 129 | ## Pixel mean & std 130 | self.pixel_mean = [0.485, 0.456, 0.406] 131 | self.pixel_std = [0.229, 0.224, 0.225] 132 | ## Transforms 133 | self.box_format = 'xyxy' 134 | self.normalize_coords = False 135 | self.detr_style = False 136 | self.trans_config = [ 137 | {'name': 'RandomHFlip'}, 138 | {'name': 'RandomResize'}, 139 | {'name': 'RandomShift', 'max_shift': 32}, 140 | ] 141 | 142 | cfg = BaseConfig() 143 | # build dataset 144 | transform = build_transform(cfg, is_train=True) 145 | dataset = build_coco(args, transform, is_train=False) 146 | 147 | for index, (image, target) in enumerate(dataset): 148 | print("{} / {}".format(index, len(dataset))) 149 | # to numpy 150 | image = image.permute(1, 2, 0).numpy() 151 | # denormalize 152 | image = (image * cfg.pixel_std + cfg.pixel_mean) * 255 153 | image = image.astype(np.uint8)[..., (2, 1, 0)].copy() 154 | orig_h, orig_w = image.shape[:2] 155 | 156 | tgt_bboxes = 
target["boxes"] 157 | tgt_labels = target["labels"] 158 | for box, label in zip(tgt_bboxes, tgt_labels): 159 | if cfg.normalize_coords: 160 | box[..., [0, 2]] *= orig_w 161 | box[..., [1, 3]] *= orig_h 162 | if cfg.box_format == 'xywh': 163 | box_x1y1 = box[..., :2] - box[..., 2:] * 0.5 164 | box_x2y2 = box[..., :2] + box[..., 2:] * 0.5 165 | box = torch.cat([box_x1y1, box_x2y2], dim=-1) 166 | # get box target 167 | x1, y1, x2, y2 = box.long() 168 | # get class label 169 | cls_name = coco_labels[label.item()] 170 | color = class_colors[label.item()] 171 | # draw bbox 172 | image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2) 173 | # put the test on the bbox 174 | cv2.putText(image, cls_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA) 175 | 176 | cv2.imshow("data", image) 177 | cv2.waitKey(0) 178 | 179 | -------------------------------------------------------------------------------- /models/detectors/fcos/fcos.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # --------------- Model components --------------- 5 | from ...backbone import build_backbone 6 | from ...neck import build_neck 7 | from ...head import build_head 8 | 9 | # --------------- External components --------------- 10 | from utils.misc import multiclass_nms 11 | 12 | 13 | # ------------------------ Fully Convolutional One-Stage Detector ------------------------ 14 | class FCOS(nn.Module): 15 | def __init__(self, 16 | cfg, 17 | num_classes :int = 80, 18 | conf_thresh :float = 0.05, 19 | nms_thresh :float = 0.6, 20 | topk :int = 1000, 21 | ca_nms :bool = False): 22 | super(FCOS, self).__init__() 23 | # ---------------------- Basic Parameters ---------------------- 24 | self.cfg = cfg 25 | self.topk = topk 26 | self.num_classes = num_classes 27 | self.conf_thresh = conf_thresh 28 | self.nms_thresh = nms_thresh 29 | self.ca_nms = ca_nms 30 | 31 | # ---------------------- Network Parameters ---------------------- 32 | ## Backbone 33 | self.backbone, feat_dims = build_backbone(cfg) 34 | 35 | ## Neck 36 | self.fpn = build_neck(cfg, feat_dims, cfg.head_dim) 37 | 38 | ## Heads 39 | self.head = build_head(cfg, cfg.head_dim, cfg.head_dim) 40 | 41 | def post_process(self, cls_preds, ctn_preds, box_preds): 42 | """ 43 | Input: 44 | cls_preds: List(Tensor) [[B, H x W, C], ...] 45 | ctn_preds: List(Tensor) [[B, H x W, 1], ...] 46 | box_preds: List(Tensor) [[B, H x W, 4], ...] 47 | """ 48 | all_scores = [] 49 | all_labels = [] 50 | all_bboxes = [] 51 | 52 | for cls_pred_i, ctn_pred_i, box_pred_i in zip(cls_preds, ctn_preds, box_preds): 53 | cls_pred_i = cls_pred_i[0] 54 | ctn_pred_i = ctn_pred_i[0] 55 | box_pred_i = box_pred_i[0] 56 | 57 | # (H x W x C,) 58 | scores_i = torch.sqrt(cls_pred_i.sigmoid() * ctn_pred_i.sigmoid()).flatten() 59 | 60 | # Keep top k top scoring indices only. 
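# Note: scores_i is the flattened (H*W*C,) score vector, so a single top-k index
# jointly encodes a location and a class: `idx % num_classes` recovers the class id
# and `idx // num_classes` recovers the location index used to gather the box below.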
61 | num_topk = min(self.topk, box_pred_i.size(0)) 62 | 63 | # torch.sort is actually faster than .topk (at least on GPUs) 64 | predicted_prob, topk_idxs = scores_i.sort(descending=True) 65 | topk_scores = predicted_prob[:num_topk] 66 | topk_idxs = topk_idxs[:num_topk] 67 | 68 | # filter out the proposals with low confidence score 69 | keep_idxs = topk_scores > self.conf_thresh 70 | topk_idxs = topk_idxs[keep_idxs] 71 | 72 | # final scores 73 | scores = topk_scores[keep_idxs] 74 | # final labels 75 | labels = topk_idxs % self.num_classes 76 | # final bboxes 77 | anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor') 78 | bboxes = box_pred_i[anchor_idxs] 79 | 80 | all_scores.append(scores) 81 | all_labels.append(labels) 82 | all_bboxes.append(bboxes) 83 | 84 | scores = torch.cat(all_scores) 85 | labels = torch.cat(all_labels) 86 | bboxes = torch.cat(all_bboxes) 87 | 88 | # to cpu & numpy 89 | scores = scores.cpu().numpy() 90 | labels = labels.cpu().numpy() 91 | bboxes = bboxes.cpu().numpy() 92 | 93 | # nms 94 | scores, labels, bboxes = multiclass_nms( 95 | scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms) 96 | 97 | return bboxes, scores, labels 98 | 99 | def forward(self, src, src_mask=None): 100 | # ---------------- Backbone ---------------- 101 | pyramid_feats = self.backbone(src) 102 | 103 | # ---------------- Neck ---------------- 104 | pyramid_feats = self.fpn(pyramid_feats) 105 | 106 | # ---------------- Heads ---------------- 107 | outputs = self.head(pyramid_feats, src_mask) 108 | 109 | if not self.training: 110 | # ---------------- PostProcess ---------------- 111 | cls_pred = outputs["pred_cls"] 112 | ctn_pred = outputs["pred_ctn"] 113 | box_pred = outputs["pred_box"] 114 | bboxes, scores, labels = self.post_process(cls_pred, ctn_pred, box_pred) 115 | # normalize bbox 116 | bboxes[..., 0::2] /= src.shape[-1] 117 | bboxes[..., 1::2] /= src.shape[-2] 118 | bboxes = bboxes.clip(0., 1.) 119 | 120 | outputs = { 121 | 'scores': scores, 122 | 'labels': labels, 123 | 'bboxes': bboxes 124 | } 125 | 126 | return outputs 127 | 128 | # ------------------------ Real-time FCOS ------------------------ 129 | class FcosRT(nn.Module): 130 | def __init__(self, 131 | cfg, 132 | num_classes :int = 80, 133 | conf_thresh :float = 0.05, 134 | nms_thresh :float = 0.6, 135 | topk :int = 1000, 136 | ca_nms :bool = False): 137 | super(FcosRT, self).__init__() 138 | # ---------------------- Basic Parameters ---------------------- 139 | self.cfg = cfg 140 | self.topk = topk 141 | self.num_classes = num_classes 142 | self.conf_thresh = conf_thresh 143 | self.nms_thresh = nms_thresh 144 | self.ca_nms = ca_nms 145 | 146 | # ---------------------- Network Parameters ---------------------- 147 | ## Backbone 148 | self.backbone, feat_dims = build_backbone(cfg) 149 | 150 | ## Neck 151 | self.fpn = build_neck(cfg, feat_dims, cfg.head_dim) 152 | 153 | ## Heads 154 | self.head = build_head(cfg, cfg.head_dim, cfg.head_dim) 155 | 156 | def post_process(self, cls_preds, box_preds): 157 | """ 158 | Input: 159 | cls_preds: List(Tensor) [[B, H x W, C], ...] 160 | box_preds: List(Tensor) [[B, H x W, 4], ...] 161 | """ 162 | all_scores = [] 163 | all_labels = [] 164 | all_bboxes = [] 165 | 166 | for cls_pred_i, box_pred_i in zip(cls_preds, box_preds): 167 | cls_pred_i = cls_pred_i[0] 168 | box_pred_i = box_pred_i[0] 169 | 170 | # (H x W x C,) 171 | scores_i = cls_pred_i.sigmoid().flatten() 172 | 173 | # Keep top k top scoring indices only. 
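# Note: unlike the FCOS head above, the real-time head has no centerness branch,
# so the ranking score here is the plain sigmoid classification score; the
# top-k decoding below is otherwise identical.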
174 | num_topk = min(self.topk, box_pred_i.size(0)) 175 | 176 | # torch.sort is actually faster than .topk (at least on GPUs) 177 | predicted_prob, topk_idxs = scores_i.sort(descending=True) 178 | topk_scores = predicted_prob[:num_topk] 179 | topk_idxs = topk_idxs[:num_topk] 180 | 181 | # filter out the proposals with low confidence score 182 | keep_idxs = topk_scores > self.conf_thresh 183 | topk_idxs = topk_idxs[keep_idxs] 184 | 185 | # final scores 186 | scores = topk_scores[keep_idxs] 187 | # final labels 188 | labels = topk_idxs % self.num_classes 189 | # final bboxes 190 | anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor') 191 | bboxes = box_pred_i[anchor_idxs] 192 | 193 | all_scores.append(scores) 194 | all_labels.append(labels) 195 | all_bboxes.append(bboxes) 196 | 197 | scores = torch.cat(all_scores) 198 | labels = torch.cat(all_labels) 199 | bboxes = torch.cat(all_bboxes) 200 | 201 | # to cpu & numpy 202 | scores = scores.cpu().numpy() 203 | labels = labels.cpu().numpy() 204 | bboxes = bboxes.cpu().numpy() 205 | 206 | # nms 207 | scores, labels, bboxes = multiclass_nms( 208 | scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms) 209 | 210 | return bboxes, scores, labels 211 | 212 | def forward(self, src, src_mask=None): 213 | # ---------------- Backbone ---------------- 214 | pyramid_feats = self.backbone(src) 215 | 216 | # ---------------- Neck ---------------- 217 | pyramid_feats = self.fpn(pyramid_feats) 218 | 219 | # ---------------- Heads ---------------- 220 | outputs = self.head(pyramid_feats, src_mask) 221 | 222 | if not self.training: 223 | # ---------------- PostProcess ---------------- 224 | cls_pred = outputs["pred_cls"] 225 | box_pred = outputs["pred_box"] 226 | bboxes, scores, labels = self.post_process(cls_pred, box_pred) 227 | # normalize bbox 228 | bboxes[..., 0::2] /= src.shape[-1] 229 | bboxes[..., 1::2] /= src.shape[-2] 230 | bboxes = bboxes.clip(0., 1.) 231 | 232 | outputs = { 233 | 'scores': scores, 234 | 'labels': labels, 235 | 'bboxes': bboxes 236 | } 237 | 238 | return outputs 239 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import random 4 | import argparse 5 | import numpy as np 6 | from copy import deepcopy 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | 12 | from utils import distributed_utils 13 | from utils.misc import compute_flops, collate_fn 14 | from utils.optimizer import build_optimizer 15 | from utils.lr_scheduler import build_wp_lr_scheduler, build_lr_scheduler 16 | 17 | from config import build_config 18 | from evaluator import build_evluator 19 | from datasets import build_dataset, build_dataloader, build_transform 20 | 21 | from models.detectors import build_model 22 | from engine import train_one_epoch 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser('General 2D Object Detection', add_help=False) 27 | # Random seed 28 | parser.add_argument('--seed', default=42, type=int) 29 | # GPU 30 | parser.add_argument('--cuda', action='store_true', default=False, 31 | help='use cuda.') 32 | # Batch size 33 | parser.add_argument('-bs', '--batch_size', default=16, type=int, 34 | help='total batch size on all GPUs.') 35 | # Model 36 | parser.add_argument('-m', '--model', default='yolof_r18_c5_1x', 37 | help='build object detector') 38 | parser.add_argument('-r', '--resume', default=None, type=str, 39 | help='resume training from the given checkpoint') 40 | # Dataset 41 | parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/', 42 | help='data root') 43 | parser.add_argument('-d', '--dataset', default='coco', 44 | help='coco, voc, widerface, crowdhuman') 45 | parser.add_argument('--vis_tgt', action="store_true", default=False, 46 | help="visualize input data.") 47 | # Dataloader 48 | parser.add_argument('--num_workers', default=2, type=int, 49 | help='Number of workers used in dataloading') 50 | # Epoch 51 | parser.add_argument('--save_folder', default='weights/', type=str, 52 | help='path to save weight') 53 | parser.add_argument('--eval_first', action="store_true", default=False, 54 | help="evaluate the model before training.") 55 | # DDP train 56 | parser.add_argument('-dist', '--distributed', action='store_true', default=False, 57 | help='distributed training') 58 | parser.add_argument('--dist_url', default='env://', 59 | help='url used to set up distributed training') 60 | parser.add_argument('--world_size', default=1, type=int, 61 | help='number of distributed processes') 62 | parser.add_argument('--sybn', action='store_true', default=False, 63 | help='use SyncBatchNorm.') 64 | # Debug setting 65 | parser.add_argument('--debug', action='store_true', default=False, 66 | help='debug codes.') 67 | 68 | return parser.parse_args() 69 | 70 | 71 | def fix_random_seed(args): 72 | seed = args.seed + distributed_utils.get_rank() 73 | torch.manual_seed(seed) 74 | np.random.seed(seed) 75 | random.seed(seed) 76 | 77 | 78 | def main(): 79 | args = parse_args() 80 | print("Setting Arguments.. 
: ", args) 81 | print("----------------------------------------------------------") 82 | 83 | # path to save model 84 | path_to_save = os.path.join(args.save_folder, args.dataset, args.model) 85 | os.makedirs(path_to_save, exist_ok=True) 86 | 87 | # ---------------------------- Build DDP ---------------------------- 88 | local_rank = local_process_rank = -1 89 | if args.distributed: 90 | distributed_utils.init_distributed_mode(args) 91 | print("git:\n {}\n".format(distributed_utils.get_sha())) 92 | try: 93 | # Multiple Machines & Multiple GPUs (world size > 8) 94 | local_rank = torch.distributed.get_rank() 95 | local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0')) 96 | except: 97 | # Single Machine & Multiple GPUs (world size <= 8) 98 | local_rank = local_process_rank = torch.distributed.get_rank() 99 | world_size = distributed_utils.get_world_size() 100 | per_gpu_batch = args.batch_size // world_size 101 | print("LOCAL RANK: ", local_rank) 102 | print("LOCAL_PROCESS_RANK: ", local_process_rank) 103 | print('WORLD SIZE: {}'.format(world_size)) 104 | 105 | # ---------------------------- Build CUDA ---------------------------- 106 | if args.cuda and torch.cuda.is_available(): 107 | print('use cuda') 108 | device = torch.device("cuda") 109 | else: 110 | device = torch.device("cpu") 111 | 112 | # ---------------------------- Fix random seed ---------------------------- 113 | fix_random_seed(args) 114 | 115 | # ---------------------------- Build config ---------------------------- 116 | cfg = build_config(args) 117 | 118 | # ---------------------------- Build Dataset ---------------------------- 119 | transforms = build_transform(cfg, is_train=True) 120 | dataset = build_dataset(args, cfg, transforms, is_train=True) 121 | 122 | # ---------------------------- Build Dataloader ---------------------------- 123 | train_loader = build_dataloader(args, dataset, per_gpu_batch, collate_fn, is_train=True) 124 | 125 | # ---------------------------- Build model ---------------------------- 126 | ## Build model 127 | model, criterion = build_model(args, cfg, is_val=True) 128 | model.to(device) 129 | model_without_ddp = model 130 | ## Calculate Params & GFLOPs 131 | if distributed_utils.is_main_process(): 132 | model_copy = deepcopy(model_without_ddp) 133 | model_copy.trainable = False 134 | model_copy.eval() 135 | compute_flops(model=model_copy, 136 | min_size=cfg.test_min_size, 137 | max_size=cfg.test_max_size, 138 | device=device) 139 | del model_copy 140 | if args.distributed: 141 | dist.barrier() 142 | 143 | # ---------------------------- Build Optimizer ---------------------------- 144 | cfg.grad_accumulate = max(cfg.batch_size_base // args.batch_size, 1) 145 | cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate 146 | optimizer, start_epoch = build_optimizer(cfg, model_without_ddp, args.resume) 147 | 148 | # ---------------------------- Build LR Scheduler ---------------------------- 149 | wp_lr_scheduler = build_wp_lr_scheduler(cfg) 150 | lr_scheduler = build_lr_scheduler(cfg, optimizer, args.resume) 151 | 152 | # ---------------------------- Build DDP model ---------------------------- 153 | if args.distributed: 154 | model = DDP(model, device_ids=[args.gpu]) 155 | model_without_ddp = model.module 156 | 157 | # ---------------------------- Build Evaluator ---------------------------- 158 | evaluator = build_evluator(args, cfg, device) 159 | 160 | # ----------------------- Eval before training ----------------------- 161 | if args.eval_first and 
distributed_utils.is_main_process(): 162 | evaluator.evaluate(model_without_ddp) 163 | return 164 | 165 | # ----------------------- Training ----------------------- 166 | print("Start training") 167 | best_map = -1. 168 | for epoch in range(start_epoch, cfg.max_epoch): 169 | if args.distributed: 170 | train_loader.batch_sampler.sampler.set_epoch(epoch) 171 | 172 | # Train one epoch 173 | train_one_epoch(cfg, 174 | model, 175 | criterion, 176 | train_loader, 177 | optimizer, 178 | device, 179 | epoch, 180 | args.vis_tgt, 181 | wp_lr_scheduler, 182 | debug=args.debug) 183 | 184 | # LR Scheduler 185 | lr_scheduler.step() 186 | 187 | # Evaluate 188 | if distributed_utils.is_main_process(): 189 | model_eval = model_without_ddp 190 | to_save = False 191 | if (epoch % cfg.eval_epoch) == 0 or (epoch == cfg.max_epoch - 1): 192 | if evaluator is None: 193 | to_save = True 194 | else: 195 | evaluator.evaluate(model_eval) 196 | # Save model 197 | if evaluator.map >= best_map: 198 | best_map = evaluator.map 199 | to_save = True 200 | 201 | if to_save: 202 | # save model 203 | print('Saving state, epoch:', epoch) 204 | torch.save({'model': model_eval.state_dict(), 205 | 'optimizer': optimizer.state_dict(), 206 | 'lr_scheduler': lr_scheduler.state_dict(), 207 | 'mAP': round(best_map*100, 1), 208 | 'epoch': epoch, 209 | 'args': args}, 210 | os.path.join(path_to_save, '{}_best.pth'.format(args.model))) 211 | if args.distributed: 212 | dist.barrier() 213 | 214 | if args.debug: 215 | print("For debug mode, we only train the model with 1 epoch.") 216 | exit(0) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 221 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import time 4 | import numpy as np 5 | import imageio 6 | import argparse 7 | from PIL import Image 8 | 9 | import torch 10 | 11 | # load transform 12 | from datasets import coco_labels, build_transform 13 | 14 | # load some utils 15 | from utils.misc import load_weight 16 | from utils.vis_tools import visualize 17 | 18 | from config import build_config 19 | from models.detectors import build_model 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='General Object Detection Demo') 24 | # Basic 25 | parser.add_argument('--mode', default='image', 26 | type=str, help='Use the data from image, video or camera') 27 | parser.add_argument('--cuda', action='store_true', default=False, 28 | help='Use cuda') 29 | parser.add_argument('--path_to_img', default='./dataset/demo/images/', 30 | type=str, help='The path to image files') 31 | parser.add_argument('--path_to_vid', default='dataset/demo/videos/', 32 | type=str, help='The path to video files') 33 | parser.add_argument('--path_to_save', default='det_results/demos/', 34 | type=str, help='The path to save the detection results') 35 | parser.add_argument('-vt', '--visual_threshold', default=0.3, type=float, 36 | help='Final confidence threshold') 37 | parser.add_argument('--show', action='store_true', default=False, 38 | help='show visualization') 39 | parser.add_argument('--gif', action='store_true', default=False, 40 | help='generate gif.') 41 | # Model 42 | parser.add_argument('-m', '--model', default='fcos_r18_1x', type=str, 43 | help='build detector') 44 | parser.add_argument('-nc', '--num_classes', default=80, type=int, 45 | help='number of classes.') 46 | parser.add_argument('--weight', default=None, 47 | 
type=str, help='Trained state_dict file path to open') 48 | parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float, 49 | help='confidence threshold') 50 | parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float, 51 | help='NMS threshold') 52 | parser.add_argument('--topk', default=100, type=int, 53 | help='topk candidates for testing') 54 | parser.add_argument("--deploy", action="store_true", default=False, 55 | help="deploy mode or not") 56 | parser.add_argument('--fuse_conv_bn', action='store_true', default=False, 57 | help='fuse Conv & BN') 58 | 59 | return parser.parse_args() 60 | 61 | 62 | def detect(args, model, device, transform, class_names, class_colors): 63 | # path to save 64 | save_path = os.path.join(args.path_to_save, args.mode) 65 | os.makedirs(save_path, exist_ok=True) 66 | 67 | # ------------------------- Camera ---------------------------- 68 | if args.mode == 'camera': 69 | print('use camera !!!') 70 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 71 | save_size = (640, 480) 72 | cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time())) 73 | save_video_name = os.path.join(save_path, cur_time+'.avi') 74 | fps = 15.0 75 | out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size) 76 | print(save_video_name) 77 | image_list = [] 78 | 79 | cap = cv2.VideoCapture(0, cv2.CAP_DSHOW) 80 | while True: 81 | ret, frame = cap.read() 82 | if ret: 83 | if cv2.waitKey(1) == ord('q'): 84 | break 85 | orig_h, orig_w, _ = frame.shape 86 | 87 | # to PIL 88 | image = Image.fromarray(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)) 89 | 90 | # prepare 91 | x = transform(image)[0] 92 | x = x.unsqueeze(0).to(device) 93 | 94 | # Inference 95 | t0 = time.time() 96 | bboxes, scores, labels = model(x) 97 | print("Infer. time: {}".format(time.time() - t0, "s")) 98 | 99 | # Rescale bboxes 100 | bboxes[..., 0::2] *= orig_w 101 | bboxes[..., 1::2] *= orig_h 102 | 103 | # vis detection 104 | frame_vis = visualize(frame, bboxes, scores, labels, args.visual_threshold, class_colors, class_names) 105 | frame_resized = cv2.resize(frame_vis, save_size) 106 | out.write(frame_resized) 107 | 108 | if args.gif: 109 | gif_resized = cv2.resize(frame, (640, 480)) 110 | gif_resized_rgb = gif_resized[..., (2, 1, 0)] 111 | image_list.append(gif_resized_rgb) 112 | 113 | if args.show: 114 | cv2.imshow('detection', frame_resized) 115 | cv2.waitKey(1) 116 | else: 117 | break 118 | cap.release() 119 | out.release() 120 | cv2.destroyAllWindows() 121 | 122 | # generate GIF 123 | if args.gif: 124 | save_gif_path = os.path.join(save_path, 'gif_files') 125 | os.makedirs(save_gif_path, exist_ok=True) 126 | save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time)) 127 | print('generating GIF ...') 128 | imageio.mimsave(save_gif_name, image_list, fps=fps) 129 | print('GIF done: {}'.format(save_gif_name)) 130 | 131 | # ------------------------- Video --------------------------- 132 | elif args.mode == 'video': 133 | video = cv2.VideoCapture(args.path_to_vid) 134 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 135 | save_size = (640, 480) 136 | cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time())) 137 | save_video_name = os.path.join(save_path, cur_time+'.avi') 138 | fps = 15.0 139 | out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size) 140 | print(save_video_name) 141 | image_list = [] 142 | 143 | while(True): 144 | ret, frame = video.read() 145 | 146 | if ret: 147 | # ------------------------- Detection --------------------------- 148 | orig_h, orig_w, _ = 
frame.shape 149 | 150 | # to PIL 151 | image = Image.fromarray(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)) 152 | 153 | # prepare 154 | x = transform(image)[0] 155 | x = x.unsqueeze(0).to(device) 156 | 157 | # Inference 158 | t0 = time.time() 159 | bboxes, scores, labels = model(x) 160 | print("Infer. time: {}".format(time.time() - t0, "s")) 161 | 162 | # Rescale bboxes 163 | bboxes[..., 0::2] *= orig_w 164 | bboxes[..., 1::2] *= orig_h 165 | 166 | # vis detection 167 | frame_vis = visualize(frame, bboxes, scores, labels, args.visual_threshold, class_colors, class_names) 168 | frame_resized = cv2.resize(frame_vis, save_size) 169 | out.write(frame_resized) 170 | 171 | if args.gif: 172 | gif_resized = cv2.resize(frame, (640, 480)) 173 | gif_resized_rgb = gif_resized[..., (2, 1, 0)] 174 | image_list.append(gif_resized_rgb) 175 | 176 | if args.show: 177 | cv2.imshow('detection', frame_resized) 178 | cv2.waitKey(1) 179 | else: 180 | break 181 | video.release() 182 | out.release() 183 | cv2.destroyAllWindows() 184 | 185 | # generate GIF 186 | if args.gif: 187 | save_gif_path = os.path.join(save_path, 'gif_files') 188 | os.makedirs(save_gif_path, exist_ok=True) 189 | save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time)) 190 | print('generating GIF ...') 191 | imageio.mimsave(save_gif_name, image_list, fps=fps) 192 | print('GIF done: {}'.format(save_gif_name)) 193 | 194 | # ------------------------- Image ---------------------------- 195 | elif args.mode == 'image': 196 | for i, img_id in enumerate(os.listdir(args.path_to_img)): 197 | cv2_image = cv2.imread((args.path_to_img + '/' + img_id), cv2.IMREAD_COLOR) 198 | orig_h, orig_w, _ = cv2_image.shape 199 | 200 | # to PIL 201 | image = Image.fromarray(cv2.cvtColor(cv2_image,cv2.COLOR_BGR2RGB)) 202 | 203 | # prepare 204 | x = transform(image)[0] 205 | x = x.unsqueeze(0).to(device) 206 | 207 | # Inference 208 | t0 = time.time() 209 | bboxes, scores, labels = model(x) 210 | print("Infer. 
time: {}".format(time.time() - t0, "s")) 211 | 212 | # Rescale bboxes 213 | bboxes[..., 0::2] *= orig_w 214 | bboxes[..., 1::2] *= orig_h 215 | 216 | # vis detection 217 | img_processed = visualize(cv2_image, bboxes, scores, labels, args.visual_threshold, class_colors, class_names) 218 | cv2.imwrite(os.path.join(save_path, str(i).zfill(6)+'.jpg'), img_processed) 219 | if args.show: 220 | cv2.imshow('detection', img_processed) 221 | cv2.waitKey(0) 222 | 223 | 224 | def run(): 225 | args = parse_args() 226 | # cuda 227 | if args.cuda: 228 | print('use cuda') 229 | device = torch.device("cuda") 230 | else: 231 | device = torch.device("cpu") 232 | 233 | # Dataset & Model Config 234 | cfg = build_config(args) 235 | 236 | # Transform 237 | transform = build_transform(cfg, is_train=False) 238 | 239 | np.random.seed(0) 240 | class_colors = [(np.random.randint(255), 241 | np.random.randint(255), 242 | np.random.randint(255)) 243 | for _ in range(args.num_classes)] 244 | 245 | # Model 246 | model = build_model(args, cfg, device, args.num_classes, False) 247 | model = load_weight(model, args.weight, args.fuse_conv_bn) 248 | model.to(device).eval() 249 | 250 | print("================= DETECT =================") 251 | # run 252 | detect(args, model, device, transform, coco_labels, class_colors) 253 | 254 | 255 | if __name__ == '__main__': 256 | run() 257 | -------------------------------------------------------------------------------- /models/basic/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | import warnings 4 | from typing import List 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from ..basic.mlp import FFN, MLP 11 | from ..basic.conv import LayerNorm2D, BasicConv 12 | 13 | 14 | # ----------------- Basic Ops ----------------- 15 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 16 | """Copy from timm""" 17 | with torch.no_grad(): 18 | """Copy from timm""" 19 | def norm_cdf(x): 20 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 21 | 22 | if (mean < a - 2 * std) or (mean > b + 2 * std): 23 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 24 | "The distribution of values may be incorrect.", 25 | stacklevel=2) 26 | 27 | l = norm_cdf((a - mean) / std) 28 | u = norm_cdf((b - mean) / std) 29 | 30 | tensor.uniform_(2 * l - 1, 2 * u - 1) 31 | tensor.erfinv_() 32 | 33 | tensor.mul_(std * math.sqrt(2.)) 34 | tensor.add_(mean) 35 | 36 | tensor.clamp_(min=a, max=b) 37 | 38 | return tensor 39 | 40 | def get_clones(module, N): 41 | if N <= 0: 42 | return None 43 | else: 44 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 45 | 46 | def inverse_sigmoid(x, eps=1e-5): 47 | x = x.clamp(min=0., max=1.) 
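# logit(x) = log(x / (1 - x)); x is clamped to [0, 1] above and both the numerator
# and denominator are clamped by eps below, so inputs at exactly 0 or 1 stay finite.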
48 | return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps)) 49 | 50 | def build_transformer(cfg, num_classes=80, return_intermediate=False): 51 | if cfg['transformer'] == 'plain_detr_transformer': 52 | return PlainDETRTransformer(d_model = cfg['hidden_dim'], 53 | num_heads = cfg['de_num_heads'], 54 | ffn_dim = cfg['de_ffn_dim'], 55 | dropout = cfg['de_dropout'], 56 | act_type = cfg['de_act'], 57 | pre_norm = cfg['de_pre_norm'], 58 | rpe_hidden_dim = cfg['rpe_hidden_dim'], 59 | feature_stride = cfg['out_stride'], 60 | num_layers = cfg['de_num_layers'], 61 | return_intermediate = return_intermediate, 62 | use_checkpoint = cfg['use_checkpoint'], 63 | num_queries_one2one = cfg['num_queries_one2one'], 64 | num_queries_one2many = cfg['num_queries_one2many'], 65 | proposal_feature_levels = cfg['proposal_feature_levels'], 66 | proposal_in_stride = cfg['out_stride'], 67 | proposal_tgt_strides = cfg['proposal_tgt_strides'], 68 | ) 69 | elif cfg['transformer'] == 'rtdetr_transformer': 70 | return RTDETRTransformer(in_dims = cfg['backbone_feat_dims'], 71 | hidden_dim = cfg['hidden_dim'], 72 | strides = cfg['out_stride'], 73 | num_classes = num_classes, 74 | num_queries = cfg['num_queries'], 75 | num_heads = cfg['de_num_heads'], 76 | num_layers = cfg['de_num_layers'], 77 | num_levels = 3, 78 | num_points = cfg['de_num_points'], 79 | ffn_dim = cfg['de_ffn_dim'], 80 | dropout = cfg['de_dropout'], 81 | act_type = cfg['de_act'], 82 | pre_norm = cfg['de_pre_norm'], 83 | return_intermediate = return_intermediate, 84 | num_denoising = cfg['dn_num_denoising'], 85 | label_noise_ratio = cfg['dn_label_noise_ratio'], 86 | box_noise_scale = cfg['dn_box_noise_scale'], 87 | learnt_init_query = cfg['learnt_init_query'], 88 | ) 89 | 90 | 91 | # ----------------- Transformer Encoder ----------------- 92 | class TransformerEncoderLayer(nn.Module): 93 | def __init__(self, 94 | d_model :int = 256, 95 | num_heads :int = 8, 96 | ffn_dim :int = 1024, 97 | dropout :float = 0.1, 98 | act_type :str = "relu", 99 | pre_norm :bool = False, 100 | ): 101 | super().__init__() 102 | # ----------- Basic parameters ----------- 103 | self.d_model = d_model 104 | self.num_heads = num_heads 105 | self.ffn_dim = ffn_dim 106 | self.dropout = dropout 107 | self.act_type = act_type 108 | self.pre_norm = pre_norm 109 | # ----------- Basic parameters ----------- 110 | # Multi-head Self-Attn 111 | self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True) 112 | self.dropout = nn.Dropout(dropout) 113 | self.norm = nn.LayerNorm(d_model) 114 | 115 | # Feedforwaed Network 116 | self.ffn = FFN(d_model, ffn_dim, dropout, act_type) 117 | 118 | def with_pos_embed(self, tensor, pos): 119 | return tensor if pos is None else tensor + pos 120 | 121 | def forward_pre_norm(self, src, pos_embed): 122 | """ 123 | Input: 124 | src: [torch.Tensor] -> [B, N, C] 125 | pos_embed: [torch.Tensor] -> [B, N, C] 126 | Output: 127 | src: [torch.Tensor] -> [B, N, C] 128 | """ 129 | src = self.norm(src) 130 | q = k = self.with_pos_embed(src, pos_embed) 131 | 132 | # -------------- MHSA -------------- 133 | src2 = self.self_attn(q, k, value=src)[0] 134 | src = src + self.dropout(src2) 135 | 136 | # -------------- FFN -------------- 137 | src = self.ffn(src) 138 | 139 | return src 140 | 141 | def forward_post_norm(self, src, pos_embed): 142 | """ 143 | Input: 144 | src: [torch.Tensor] -> [B, N, C] 145 | pos_embed: [torch.Tensor] -> [B, N, C] 146 | Output: 147 | src: [torch.Tensor] -> [B, N, C] 148 | """ 149 | q = k = 
self.with_pos_embed(src, pos_embed) 150 | 151 | # -------------- MHSA -------------- 152 | src2 = self.self_attn(q, k, value=src)[0] 153 | src = src + self.dropout(src2) 154 | src = self.norm(src) 155 | 156 | # -------------- FFN -------------- 157 | src = self.ffn(src) 158 | 159 | return src 160 | 161 | def forward(self, src, pos_embed): 162 | if self.pre_norm: 163 | return self.forward_pre_norm(src, pos_embed) 164 | else: 165 | return self.forward_post_norm(src, pos_embed) 166 | 167 | class TransformerEncoder(nn.Module): 168 | def __init__(self, 169 | d_model :int = 256, 170 | num_heads :int = 8, 171 | num_layers :int = 1, 172 | ffn_dim :int = 1024, 173 | pe_temperature :float = 10000., 174 | dropout :float = 0.1, 175 | act_type :str = "relu", 176 | pre_norm :bool = False, 177 | ): 178 | super().__init__() 179 | # ----------- Basic parameters ----------- 180 | self.d_model = d_model 181 | self.num_heads = num_heads 182 | self.num_layers = num_layers 183 | self.ffn_dim = ffn_dim 184 | self.dropout = dropout 185 | self.act_type = act_type 186 | self.pre_norm = pre_norm 187 | self.pe_temperature = pe_temperature 188 | self.pos_embed = None 189 | # ----------- Basic parameters ----------- 190 | self.encoder_layers = get_clones( 191 | TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type, pre_norm), num_layers) 192 | 193 | def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.): 194 | assert embed_dim % 4 == 0, \ 195 | 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 196 | 197 | # ----------- Check cahed pos_embed ----------- 198 | if self.pos_embed is not None and \ 199 | self.pos_embed.shape[2:] == [h, w]: 200 | return self.pos_embed 201 | 202 | # ----------- Generate grid coords ----------- 203 | grid_w = torch.arange(int(w), dtype=torch.float32) 204 | grid_h = torch.arange(int(h), dtype=torch.float32) 205 | grid_w, grid_h = torch.meshgrid([grid_w, grid_h]) # shape: [H, W] 206 | 207 | pos_dim = embed_dim // 4 208 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 209 | omega = 1. 
/ (temperature**omega) 210 | 211 | out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C] 212 | out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C] 213 | 214 | # shape: [1, N, C] 215 | pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], dim=1)[None, :, :] 216 | pos_embed = pos_embed.to(device) 217 | self.pos_embed = pos_embed 218 | 219 | return pos_embed 220 | 221 | def forward(self, src): 222 | """ 223 | Input: 224 | src: [torch.Tensor] -> [B, C, H, W] 225 | Output: 226 | src: [torch.Tensor] -> [B, C, H, W] 227 | """ 228 | # -------- Transformer encoder -------- 229 | channels, fmp_h, fmp_w = src.shape[1:] 230 | # [B, C, H, W] -> [B, N, C], N=HxW 231 | src_flatten = src.flatten(2).permute(0, 2, 1).contiguous() 232 | memory = src_flatten 233 | 234 | # PosEmbed: [1, N, C] 235 | pos_embed = self.build_2d_sincos_position_embedding( 236 | src.device, fmp_w, fmp_h, channels, self.pe_temperature) 237 | 238 | # Transformer Encoder layer 239 | for encoder in self.encoder_layers: 240 | memory = encoder(memory, pos_embed=pos_embed) 241 | 242 | # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W] 243 | src = memory.permute(0, 2, 1).contiguous() 244 | src = src.view([-1, channels, fmp_h, fmp_w]) 245 | 246 | return src 247 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /models/detectors/fcos/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils.box_ops import get_ious 6 | from utils.misc import sigmoid_focal_loss 7 | from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized 8 | 9 | from .matcher import FcosMatcher, SimOtaMatcher 10 | 11 | 12 | class SetCriterion(nn.Module): 13 | def __init__(self, cfg): 14 | super().__init__() 15 | # ------------- Basic parameters ------------- 16 | self.cfg = cfg 17 | self.num_classes = cfg.num_classes 18 | # ------------- Focal loss ------------- 19 | self.alpha = cfg.focal_loss_alpha 20 | self.gamma = cfg.focal_loss_gamma 21 | # ------------- Loss weight ------------- 22 | # ------------- Matcher & Loss weight ------------- 23 | self.matcher_cfg = cfg.matcher_hpy 24 | if cfg.matcher == 'fcos_matcher': 25 | self.weight_dict = {'loss_cls': cfg.loss_cls_weight, 26 | 'loss_reg': cfg.loss_reg_weight, 27 | 'loss_ctn': cfg.loss_ctn_weight} 28 | self.matcher = FcosMatcher(cfg.num_classes, 29 | self.matcher_cfg['center_sampling_radius'], 30 | self.matcher_cfg['object_sizes_of_interest'], 31 | [1., 1., 1., 1.] 32 | ) 33 | elif cfg.matcher == 'simota': 34 | self.weight_dict = {'loss_cls': cfg.loss_cls_weight, 35 | 'loss_reg': cfg.loss_reg_weight} 36 | self.matcher = SimOtaMatcher(cfg.num_classes, 37 | self.matcher_cfg['soft_center_radius'], 38 | self.matcher_cfg['topk_candidates']) 39 | else: 40 | raise NotImplementedError("Unknown matcher: {}.".format(cfg.matcher)) 41 | 42 | def loss_labels(self, pred_cls, tgt_cls, num_boxes=1.0): 43 | """ 44 | pred_cls: (Tensor) [N, C] 45 | tgt_cls: (Tensor) [N, C] 46 | """ 47 | # cls loss: [V, C] 48 | loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma) 49 | 50 | return loss_cls.sum() / num_boxes 51 | 52 | def loss_labels_qfl(self, pred_cls, target, beta=2.0, num_boxes=1.0): 53 | # Quality FocalLoss 54 | """ 55 | pred_cls: (torch.Tensor): [N, C]。 56 | target: (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N) 57 | """ 58 | label, score = target 59 | pred_sigmoid = pred_cls.sigmoid() 60 | scale_factor = pred_sigmoid 61 | zerolabel = scale_factor.new_zeros(pred_cls.shape) 62 | 63 | ce_loss = F.binary_cross_entropy_with_logits( 64 | pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta) 65 | 66 | bg_class_ind = pred_cls.shape[-1] 67 | pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1) 68 | if pos.shape[0] > 0: 69 | pos_label = label[pos].long() 70 | 71 | scale_factor = score[pos] - pred_sigmoid[pos, pos_label] 72 | 73 | ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits( 74 | pred_cls[pos, pos_label], score[pos], 75 | reduction='none') * scale_factor.abs().pow(beta) 76 | 77 | return ce_loss.sum() / num_boxes 78 | 79 | def loss_bboxes_ltrb(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0): 80 | """ 81 | pred_box: (Tensor) [N, 4] 82 | tgt_box: (Tensor) [N, 4] 83 | """ 84 | pred_delta = torch.cat((-pred_delta[..., :2], pred_delta[..., 2:]), dim=-1) 85 | tgt_delta = torch.cat((-tgt_delta[..., :2], tgt_delta[..., 2:]), dim=-1) 86 | 87 | eps = torch.finfo(torch.float32).eps 88 | 89 | pred_area = (pred_delta[..., 2] - pred_delta[..., 0]).clamp_(min=0) \ 90 | * (pred_delta[..., 3] - pred_delta[..., 1]).clamp_(min=0) 91 | tgt_area = (tgt_delta[..., 2] - tgt_delta[..., 
0]).clamp_(min=0) \ 92 | * (tgt_delta[..., 3] - tgt_delta[..., 1]).clamp_(min=0) 93 | 94 | w_intersect = (torch.min(pred_delta[..., 2], tgt_delta[..., 2]) 95 | - torch.max(pred_delta[..., 0], tgt_delta[..., 0])).clamp_(min=0) 96 | h_intersect = (torch.min(pred_delta[..., 3], tgt_delta[..., 3]) 97 | - torch.max(pred_delta[..., 1], tgt_delta[..., 1])).clamp_(min=0) 98 | 99 | area_intersect = w_intersect * h_intersect 100 | area_union = tgt_area + pred_area - area_intersect 101 | ious = area_intersect / area_union.clamp(min=eps) 102 | 103 | # giou 104 | g_w_intersect = torch.max(pred_delta[..., 2], tgt_delta[..., 2]) \ 105 | - torch.min(pred_delta[..., 0], tgt_delta[..., 0]) 106 | g_h_intersect = torch.max(pred_delta[..., 3], tgt_delta[..., 3]) \ 107 | - torch.min(pred_delta[..., 1], tgt_delta[..., 1]) 108 | ac_uion = g_w_intersect * g_h_intersect 109 | gious = ious - (ac_uion - area_union) / ac_uion.clamp(min=eps) 110 | loss_box = 1 - gious 111 | 112 | if bbox_quality is not None: 113 | loss_box = loss_box * bbox_quality.view(loss_box.size()) 114 | 115 | return loss_box.sum() / num_boxes 116 | 117 | def loss_bboxes_xyxy(self, pred_box, gt_box, num_boxes=1.0): 118 | ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou') 119 | loss_box = 1.0 - ious 120 | 121 | return loss_box.sum() / num_boxes 122 | 123 | def fcos_loss(self, outputs, targets): 124 | """ 125 | outputs['pred_cls']: (Tensor) [B, M, C] 126 | outputs['pred_reg']: (Tensor) [B, M, 4] 127 | outputs['pred_ctn']: (Tensor) [B, M, 1] 128 | outputs['strides']: (List) [8, 16, 32, ...] stride of the model output 129 | targets: (List) [dict{'boxes': [...], 130 | 'labels': [...], 131 | 'orig_size': ...}, ...] 132 | """ 133 | # -------------------- Pre-process -------------------- 134 | device = outputs['pred_cls'][0].device 135 | fpn_strides = outputs['strides'] 136 | anchors = outputs['anchors'] 137 | pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes) 138 | pred_delta = torch.cat(outputs['pred_reg'], dim=1).view(-1, 4) 139 | pred_ctn = torch.cat(outputs['pred_ctn'], dim=1).view(-1, 1) 140 | masks = ~torch.cat(outputs['mask'], dim=1).view(-1) 141 | 142 | # -------------------- Label Assignment -------------------- 143 | gt_classes, gt_deltas, gt_centerness = self.matcher(fpn_strides, anchors, targets) 144 | gt_classes = gt_classes.flatten().to(device) 145 | gt_deltas = gt_deltas.view(-1, 4).to(device) 146 | gt_centerness = gt_centerness.view(-1, 1).to(device) 147 | 148 | foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes) 149 | num_foreground = foreground_idxs.sum() 150 | 151 | if is_dist_avail_and_initialized(): 152 | torch.distributed.all_reduce(num_foreground) 153 | num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item() 154 | 155 | num_foreground_centerness = gt_centerness[foreground_idxs].sum() 156 | if is_dist_avail_and_initialized(): 157 | torch.distributed.all_reduce(num_foreground_centerness) 158 | num_targets = torch.clamp(num_foreground_centerness / get_world_size(), min=1).item() 159 | 160 | # -------------------- classification loss -------------------- 161 | gt_classes_target = torch.zeros_like(pred_cls) 162 | gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1 163 | valid_idxs = (gt_classes >= 0) & masks 164 | loss_labels = self.loss_labels( 165 | pred_cls[valid_idxs], gt_classes_target[valid_idxs], num_foreground) 166 | 167 | # -------------------- regression loss -------------------- 168 | loss_bboxes = self.loss_bboxes_ltrb( 169 | 
pred_delta[foreground_idxs], gt_deltas[foreground_idxs], gt_centerness[foreground_idxs], num_targets) 170 | 171 | # -------------------- centerness loss -------------------- 172 | loss_centerness = F.binary_cross_entropy_with_logits( 173 | pred_ctn[foreground_idxs], gt_centerness[foreground_idxs], reduction='none') 174 | loss_centerness = loss_centerness.sum() / num_foreground 175 | 176 | loss_dict = dict( 177 | loss_cls = loss_labels, 178 | loss_reg = loss_bboxes, 179 | loss_ctn = loss_centerness, 180 | ) 181 | 182 | return loss_dict 183 | 184 | def ota_loss(self, outputs, targets): 185 | """ 186 | outputs['pred_cls']: (Tensor) [B, M, C] 187 | outputs['pred_reg']: (Tensor) [B, M, 4] 188 | outputs['pred_box']: (Tensor) [B, M, 4] 189 | outputs['strides']: (List) [8, 16, 32, ...] stride of the model output 190 | targets: (List) [dict{'boxes': [...], 191 | 'labels': [...], 192 | 'orig_size': ...}, ...] 193 | """ 194 | # -------------------- Pre-process -------------------- 195 | bs = outputs['pred_cls'][0].shape[0] 196 | device = outputs['pred_cls'][0].device 197 | fpn_strides = outputs['strides'] 198 | anchors = outputs['anchors'] 199 | # preds: [B, M, C] 200 | # preds: [B, M, C] 201 | cls_preds = torch.cat(outputs['pred_cls'], dim=1) 202 | box_preds = torch.cat(outputs['pred_box'], dim=1) 203 | masks = ~torch.cat(outputs['mask'], dim=1).view(-1) 204 | 205 | # -------------------- Label Assignment -------------------- 206 | cls_targets = [] 207 | box_targets = [] 208 | assign_metrics = [] 209 | for batch_idx in range(bs): 210 | tgt_labels = targets[batch_idx]["labels"].to(device) # [N,] 211 | tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4] 212 | # refine target 213 | tgt_boxes_wh = tgt_bboxes[..., 2:] - tgt_bboxes[..., :2] 214 | min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0] 215 | keep = (min_tgt_size >= 8) 216 | tgt_bboxes = tgt_bboxes[keep] 217 | tgt_labels = tgt_labels[keep] 218 | # label assignment 219 | assigned_result = self.matcher(fpn_strides=fpn_strides, 220 | anchors=anchors, 221 | pred_cls=cls_preds[batch_idx].detach(), 222 | pred_box=box_preds[batch_idx].detach(), 223 | gt_labels=tgt_labels, 224 | gt_bboxes=tgt_bboxes 225 | ) 226 | cls_targets.append(assigned_result['assigned_labels']) 227 | box_targets.append(assigned_result['assigned_bboxes']) 228 | assign_metrics.append(assigned_result['assign_metrics']) 229 | 230 | # List[B, M, C] -> Tensor[BM, C] 231 | cls_targets = torch.cat(cls_targets, dim=0) 232 | box_targets = torch.cat(box_targets, dim=0) 233 | assign_metrics = torch.cat(assign_metrics, dim=0) 234 | 235 | valid_idxs = (cls_targets >= 0) & masks 236 | foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes) 237 | num_fgs = assign_metrics.sum() 238 | 239 | if is_dist_avail_and_initialized(): 240 | torch.distributed.all_reduce(num_fgs) 241 | num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item() 242 | 243 | # -------------------- classification loss -------------------- 244 | cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs] 245 | qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs]) 246 | loss_labels = self.loss_labels_qfl(cls_preds, qfl_targets, 2.0, num_fgs) 247 | 248 | # -------------------- regression loss -------------------- 249 | box_preds_pos = box_preds.view(-1, 4)[foreground_idxs] 250 | box_targets_pos = box_targets[foreground_idxs] 251 | loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs) 252 | 253 | loss_dict = dict( 254 | loss_cls = loss_labels, 255 | loss_reg = 
loss_bboxes, 256 | ) 257 | 258 | return loss_dict 259 | 260 | def forward(self, outputs, targets): 261 | """ 262 | outputs['pred_cls']: (Tensor) [B, M, C] 263 | outputs['pred_reg']: (Tensor) [B, M, 4] 264 | outputs['pred_ctn']: (Tensor) [B, M, 1] 265 | outputs['strides']: (List) [8, 16, 32, ...] stride of the model output 266 | targets: (List) [dict{'boxes': [...], 267 | 'labels': [...], 268 | 'orig_size': ...}, ...] 269 | """ 270 | if self.cfg.matcher == "fcos_matcher": 271 | return self.fcos_loss(outputs, targets) 272 | elif self.cfg.matcher == "simota": 273 | return self.ota_loss(outputs, targets) 274 | else: 275 | raise NotImplementedError 276 | 277 | 278 | if __name__ == "__main__": 279 | pass 280 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Transforms and data augmentation for both image + bbox. 4 | """ 5 | import PIL 6 | import random 7 | 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as T 11 | import torchvision.transforms.functional as F 12 | 13 | 14 | # ----------------- Basic transform functions ----------------- 15 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 16 | return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)  # torchvision.ops.misc no longer provides an interpolate wrapper in recent releases; call the PyTorch op directly 17 | 18 | def crop(image, target, region): 19 | cropped_image = F.crop(image, *region) 20 | 21 | target = target.copy() 22 | i, j, h, w = region 23 | 24 | # should we do something wrt the original size? 25 | target["size"] = torch.tensor([h, w]) 26 | 27 | fields = ["labels", "area", "iscrowd"] 28 | 29 | if "boxes" in target: 30 | boxes = target["boxes"] 31 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 32 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 33 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 34 | cropped_boxes = cropped_boxes.clamp(min=0) 35 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 36 | target["boxes"] = cropped_boxes.reshape(-1, 4) 37 | target["area"] = area 38 | fields.append("boxes") 39 | 40 | if "masks" in target: 41 | # FIXME should we update the area here if there are no boxes?
42 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 43 | fields.append("masks") 44 | 45 | # remove elements for which the boxes or masks that have zero area 46 | if "boxes" in target or "masks" in target: 47 | # favor boxes selection when defining which elements to keep 48 | # this is compatible with previous implementation 49 | if "boxes" in target: 50 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 51 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 52 | else: 53 | keep = target['masks'].flatten(1).any(1) 54 | 55 | for field in fields: 56 | target[field] = target[field][keep] 57 | 58 | return cropped_image, target 59 | 60 | def hflip(image, target): 61 | flipped_image = F.hflip(image) 62 | 63 | w, h = image.size 64 | 65 | target = target.copy() 66 | if "boxes" in target: 67 | boxes = target["boxes"] 68 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 69 | target["boxes"] = boxes 70 | 71 | if "masks" in target: 72 | target['masks'] = target['masks'].flip(-1) 73 | 74 | return flipped_image, target 75 | 76 | def resize(image, target, size, max_size=None): 77 | # size can be min_size (scalar) or (w, h) tuple 78 | 79 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 80 | w, h = image_size 81 | if max_size is not None: 82 | min_original_size = float(min((w, h))) 83 | max_original_size = float(max((w, h))) 84 | if max_original_size / min_original_size * size > max_size: 85 | size = int(round(max_size * min_original_size / max_original_size)) 86 | 87 | if (w <= h and w == size) or (h <= w and h == size): 88 | return (h, w) 89 | 90 | if w < h: 91 | ow = size 92 | oh = int(size * h / w) 93 | else: 94 | oh = size 95 | ow = int(size * w / h) 96 | 97 | return (oh, ow) 98 | 99 | def get_size(image_size, size, max_size=None): 100 | if isinstance(size, (list, tuple)): 101 | return size[::-1] 102 | else: 103 | return get_size_with_aspect_ratio(image_size, size, max_size) 104 | 105 | size = get_size(image.size, size, max_size) 106 | rescaled_image = F.resize(image, size) 107 | 108 | if target is None: 109 | return rescaled_image, None 110 | 111 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 112 | ratio_width, ratio_height = ratios 113 | 114 | target = target.copy() 115 | if "boxes" in target: 116 | boxes = target["boxes"] 117 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 118 | target["boxes"] = scaled_boxes 119 | 120 | if "area" in target: 121 | area = target["area"] 122 | scaled_area = area * (ratio_width * ratio_height) 123 | target["area"] = scaled_area 124 | 125 | h, w = size 126 | target["size"] = torch.tensor([h, w]) 127 | 128 | if "masks" in target: 129 | target['masks'] = interpolate( 130 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 131 | 132 | return rescaled_image, target 133 | 134 | def pad(image, target, padding): 135 | # assumes that we only pad on the bottom right corners 136 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 137 | if target is None: 138 | return padded_image, None 139 | target = target.copy() 140 | # should we do something wrt the original size? 
141 | target["size"] = torch.tensor(padded_image.size[::-1]) 142 | if "masks" in target: 143 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 144 | return padded_image, target 145 | 146 | 147 | # ----------------- Basic transform ----------------- 148 | class RandomCrop(object): 149 | def __init__(self, size): 150 | self.size = size 151 | 152 | def __call__(self, img, target=None): 153 | region = T.RandomCrop.get_params(img, self.size) 154 | return crop(img, target, region) 155 | 156 | class RandomSizeCrop(object): 157 | def __init__(self, min_size: int, max_size: int): 158 | self.min_size = min_size 159 | self.max_size = max_size 160 | 161 | def __call__(self, img: PIL.Image.Image, target: dict = None): 162 | w = random.randint(self.min_size, min(img.width, self.max_size)) 163 | h = random.randint(self.min_size, min(img.height, self.max_size)) 164 | region = T.RandomCrop.get_params(img, [h, w]) 165 | return crop(img, target, region) 166 | 167 | class RandomHorizontalFlip(object): 168 | def __init__(self, p=0.5): 169 | self.p = p 170 | 171 | def __call__(self, img, target=None): 172 | if random.random() < self.p: 173 | return hflip(img, target) 174 | return img, target 175 | 176 | class RandomResize(object): 177 | def __init__(self, sizes, max_size=None): 178 | assert isinstance(sizes, (list, tuple)) 179 | self.sizes = sizes 180 | self.max_size = max_size 181 | 182 | def __call__(self, img, target=None): 183 | size = random.choice(self.sizes) 184 | return resize(img, target, size, self.max_size) 185 | 186 | class RandomShift(object): 187 | def __init__(self, p=0.5, max_shift=32): 188 | self.p = p 189 | self.max_shift = max_shift 190 | 191 | def __call__(self, image, target=None): 192 | if random.random() < self.p: 193 | img_h, img_w = image.height, image.width 194 | shift_x = random.randint(-self.max_shift, self.max_shift) 195 | shift_y = random.randint(-self.max_shift, self.max_shift) 196 | shifted_image = F.affine(image, translate=[shift_x, shift_y], angle=0, scale=1.0, shear=0) 197 | 198 | target = target.copy() 199 | target["boxes"][..., [0, 2]] += shift_x 200 | target["boxes"][..., [1, 3]] += shift_y 201 | target["boxes"][..., [0, 2]] = target["boxes"][..., [0, 2]].clip(0, img_w) 202 | target["boxes"][..., [1, 3]] = target["boxes"][..., [1, 3]].clip(0, img_h) 203 | 204 | return shifted_image, target 205 | 206 | return image, target 207 | 208 | class RandomSelect(object): 209 | """ 210 | Randomly selects between transforms1 and transforms2, 211 | with probability p for transforms1 and (1 - p) for transforms2 212 | """ 213 | def __init__(self, transforms1, transforms2, p=0.5): 214 | self.transforms1 = transforms1 215 | self.transforms2 = transforms2 216 | self.p = p 217 | 218 | def __call__(self, img, target=None): 219 | if random.random() < self.p: 220 | return self.transforms1(img, target) 221 | return self.transforms2(img, target) 222 | 223 | class ToTensor(object): 224 | def __call__(self, img, target=None): 225 | return F.to_tensor(img), target 226 | 227 | class Normalize(object): 228 | def __init__(self, mean, std, normalize_coords=False): 229 | self.mean = mean 230 | self.std = std 231 | self.normalize_coords = normalize_coords 232 | 233 | def __call__(self, image, target=None): 234 | image = F.normalize(image, mean=self.mean, std=self.std) 235 | if target is None: 236 | return image, None 237 | if self.normalize_coords: 238 | target = target.copy() 239 | h, w = image.shape[-2:] 240 | if "boxes" in target: 241 | boxes = target["boxes"] 
242 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 243 | target["boxes"] = boxes 244 | return image, target 245 | 246 | class RefineBBox(object): 247 | def __init__(self, min_box_size=1): 248 | self.min_box_size = min_box_size 249 | 250 | def __call__(self, img, target): 251 | boxes = target["boxes"].clone() 252 | labels = target["labels"].clone() 253 | 254 | tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2] 255 | min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0] 256 | 257 | keep = (min_tgt_size >= self.min_box_size) 258 | 259 | target["boxes"] = boxes[keep] 260 | target["labels"] = labels[keep] 261 | 262 | return img, target 263 | 264 | class ConvertBoxFormat(object): 265 | def __init__(self, box_format="xyxy"): 266 | self.box_format = box_format 267 | 268 | def __call__(self, image, target=None): 269 | # convert box format 270 | if self.box_format == "xyxy" or target is None: 271 | pass 272 | elif self.box_format == "xywh": 273 | target = target.copy() 274 | if "boxes" in target: 275 | boxes_xyxy = target["boxes"] 276 | boxes_xywh = torch.zeros_like(boxes_xyxy) 277 | boxes_xywh[..., :2] = (boxes_xyxy[..., :2] + boxes_xyxy[..., 2:]) * 0.5 # cxcy 278 | boxes_xywh[..., 2:] = boxes_xyxy[..., 2:] - boxes_xyxy[..., :2] # bwbh 279 | target["boxes"] = boxes_xywh 280 | else: 281 | raise NotImplementedError("Unknown box format: {}".format(self.box_format)) 282 | 283 | return image, target 284 | 285 | class Compose(object): 286 | def __init__(self, transforms): 287 | self.transforms = transforms 288 | 289 | def __call__(self, image, target=None): 290 | for t in self.transforms: 291 | image, target = t(image, target) 292 | return image, target 293 | 294 | def __repr__(self): 295 | format_string = self.__class__.__name__ + "(" 296 | for t in self.transforms: 297 | format_string += "\n" 298 | format_string += " {0}".format(t) 299 | format_string += "\n)" 300 | return format_string 301 | 302 | 303 | # build transforms 304 | def build_transform(cfg, is_train=False): 305 | # ---------------- Transform for Training ---------------- 306 | if is_train: 307 | transforms = [] 308 | trans_config = cfg.trans_config 309 | # build transform 310 | if not cfg.detr_style: 311 | for t in trans_config: 312 | if t['name'] == 'RandomHFlip': 313 | transforms.append(RandomHorizontalFlip()) 314 | if t['name'] == 'RandomResize': 315 | transforms.append(RandomResize(cfg.train_min_size, max_size=cfg.train_max_size)) 316 | if t['name'] == 'RandomSizeCrop': 317 | transforms.append(RandomSizeCrop(t['min_crop_size'], max_size=t['max_crop_size'])) 318 | if t['name'] == 'RandomShift': 319 | transforms.append(RandomShift(max_shift=t['max_shift'])) 320 | if t['name'] == 'RefineBBox': 321 | transforms.append(RefineBBox(min_box_size=t['min_box_size'])) 322 | transforms.extend([ 323 | ToTensor(), 324 | Normalize(cfg.pixel_mean, cfg.pixel_std, cfg.normalize_coords), 325 | ConvertBoxFormat(cfg.box_format) 326 | ]) 327 | # build transform for DETR-style detector 328 | else: 329 | transforms = [ 330 | RandomHorizontalFlip(), 331 | RandomSelect( 332 | RandomResize(cfg.train_min_size, max_size=cfg.train_max_size), 333 | Compose([ 334 | RandomResize(cfg.train_min_size2), 335 | RandomSizeCrop(*cfg.random_crop_size), 336 | RandomResize(cfg.train_min_size, max_size=cfg.train_max_size), 337 | ]) 338 | ), 339 | ToTensor(), 340 | Normalize(cfg.pixel_mean, cfg.pixel_std, cfg.normalize_coords), 341 | ConvertBoxFormat(cfg.box_format) 342 | ] 343 | 344 | # ---------------- Transform for Evaluating ---------------- 345 | else: 346 | 
transforms = [ 347 | RandomResize(cfg.test_min_size, max_size=cfg.test_max_size), 348 | ToTensor(), 349 | Normalize(cfg.pixel_mean, cfg.pixel_std, cfg.normalize_coords), 350 | ConvertBoxFormat(cfg.box_format) 351 | ] 352 | 353 | return Compose(transforms) 354 | -------------------------------------------------------------------------------- /models/head/fcos_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..basic.conv import BasicConv 5 | 6 | 7 | class Scale(nn.Module): 8 | """ 9 | Multiply the output regression range by a learnable constant value 10 | """ 11 | def __init__(self, init_value=1.0): 12 | """ 13 | init_value : initial value for the scalar 14 | """ 15 | super().__init__() 16 | self.scale = nn.Parameter( 17 | torch.tensor(init_value, dtype=torch.float32), 18 | requires_grad=True 19 | ) 20 | 21 | def forward(self, x): 22 | """ 23 | input -> scale * input 24 | """ 25 | return x * self.scale 26 | 27 | class FcosHead(nn.Module): 28 | def __init__(self, cfg, in_dim, out_dim,): 29 | super().__init__() 30 | self.fmp_size = None 31 | # ------------------ Basic parameters ------------------- 32 | self.cfg = cfg 33 | self.in_dim = in_dim 34 | self.stride = cfg.out_stride 35 | self.num_classes = cfg.num_classes 36 | self.num_cls_head = cfg.num_cls_head 37 | self.num_reg_head = cfg.num_reg_head 38 | self.act_type = cfg.head_act 39 | self.norm_type = cfg.head_norm 40 | 41 | # ------------------ Network parameters ------------------- 42 | ## cls head 43 | cls_heads = [] 44 | self.cls_head_dim = out_dim 45 | for i in range(self.num_cls_head): 46 | if i == 0: 47 | cls_heads.append( 48 | BasicConv(in_dim, self.cls_head_dim, 49 | kernel_size=3, padding=1, stride=1, 50 | act_type=self.act_type, norm_type=self.norm_type) 51 | ) 52 | else: 53 | cls_heads.append( 54 | BasicConv(self.cls_head_dim, self.cls_head_dim, 55 | kernel_size=3, padding=1, stride=1, 56 | act_type=self.act_type, norm_type=self.norm_type) 57 | ) 58 | 59 | ## reg head 60 | reg_heads = [] 61 | self.reg_head_dim = out_dim 62 | for i in range(self.num_reg_head): 63 | if i == 0: 64 | reg_heads.append( 65 | BasicConv(in_dim, self.reg_head_dim, 66 | kernel_size=3, padding=1, stride=1, 67 | act_type=self.act_type, norm_type=self.norm_type) 68 | ) 69 | else: 70 | reg_heads.append( 71 | BasicConv(self.reg_head_dim, self.reg_head_dim, 72 | kernel_size=3, padding=1, stride=1, 73 | act_type=self.act_type, norm_type=self.norm_type) 74 | ) 75 | self.cls_heads = nn.Sequential(*cls_heads) 76 | self.reg_heads = nn.Sequential(*reg_heads) 77 | 78 | ## pred layers 79 | self.cls_pred = nn.Conv2d(self.cls_head_dim, cfg.num_classes, kernel_size=3, padding=1) 80 | self.reg_pred = nn.Conv2d(self.reg_head_dim, 4, kernel_size=3, padding=1) 81 | self.ctn_pred = nn.Conv2d(self.reg_head_dim, 1, kernel_size=3, padding=1) 82 | 83 | ## scale layers 84 | self.scales = nn.ModuleList( 85 | Scale() for _ in range(len(self.stride)) 86 | ) 87 | 88 | # init bias 89 | self._init_layers() 90 | 91 | def _init_layers(self): 92 | for module in [self.cls_heads, self.reg_heads, self.cls_pred, self.reg_pred, self.ctn_pred]: 93 | for layer in module.modules(): 94 | if isinstance(layer, nn.Conv2d): 95 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01) 96 | if layer.bias is not None: 97 | torch.nn.init.constant_(layer.bias, 0) 98 | if isinstance(layer, nn.GroupNorm): 99 | torch.nn.init.constant_(layer.weight, 1) 100 | if layer.bias is not None: 101 | 
torch.nn.init.constant_(layer.bias, 0) 102 | # init the bias of cls pred 103 | init_prob = 0.01 104 | bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob)) 105 | torch.nn.init.constant_(self.cls_pred.bias, bias_value) 106 | 107 | def get_anchors(self, level, fmp_size): 108 | """ 109 | fmp_size: (List) [H, W] 110 | """ 111 | # generate grid cells 112 | fmp_h, fmp_w = fmp_size 113 | anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)]) 114 | # [H, W, 2] -> [HW, 2] 115 | anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5 116 | anchors *= self.stride[level] 117 | 118 | return anchors 119 | 120 | def decode_boxes(self, pred_deltas, anchors): 121 | """ 122 | pred_deltas: (List[Tensor]) [B, M, 4] or [M, 4] (l, t, r, b) 123 | anchors: (List[Tensor]) [1, M, 2] or [M, 2] 124 | """ 125 | # x1 = x_anchor - l, x2 = x_anchor + r 126 | # y1 = y_anchor - t, y2 = y_anchor + b 127 | pred_x1y1 = anchors - pred_deltas[..., :2] 128 | pred_x2y2 = anchors + pred_deltas[..., 2:] 129 | pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1) 130 | 131 | return pred_box 132 | 133 | def forward(self, pyramid_feats, mask=None): 134 | all_masks = [] 135 | all_anchors = [] 136 | all_cls_preds = [] 137 | all_reg_preds = [] 138 | all_box_preds = [] 139 | all_ctn_preds = [] 140 | for level, feat in enumerate(pyramid_feats): 141 | # ------------------- Decoupled head ------------------- 142 | cls_feat = self.cls_heads(feat) 143 | reg_feat = self.reg_heads(feat) 144 | 145 | # ------------------- Generate anchor box ------------------- 146 | B, _, H, W = cls_feat.size() 147 | fmp_size = [H, W] 148 | anchors = self.get_anchors(level, fmp_size) # [M, 4] 149 | anchors = anchors.to(cls_feat.device) 150 | 151 | # ------------------- Predict ------------------- 152 | cls_pred = self.cls_pred(cls_feat) 153 | reg_pred = self.reg_pred(reg_feat) 154 | ctn_pred = self.ctn_pred(reg_feat) 155 | 156 | # ------------------- Process preds ------------------- 157 | ## [B, C, H, W] -> [B, H, W, C] -> [B, M, C] 158 | cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes) 159 | ctn_pred = ctn_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1) 160 | reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4) 161 | reg_pred = nn.functional.relu(self.scales[level](reg_pred)) * self.stride[level] 162 | ## Decode bbox 163 | box_pred = self.decode_boxes(reg_pred, anchors) 164 | ## Adjust mask 165 | if mask is not None: 166 | # [B, H, W] 167 | mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0] 168 | # [B, H, W] -> [B, M] 169 | mask_i = mask_i.flatten(1) 170 | all_masks.append(mask_i) 171 | 172 | all_anchors.append(anchors) 173 | all_cls_preds.append(cls_pred) 174 | all_reg_preds.append(reg_pred) 175 | all_box_preds.append(box_pred) 176 | all_ctn_preds.append(ctn_pred) 177 | 178 | outputs = {"pred_cls": all_cls_preds, # List [B, M, C] 179 | "pred_reg": all_reg_preds, # List [B, M, 4] 180 | "pred_box": all_box_preds, # List [B, M, 4] 181 | "pred_ctn": all_ctn_preds, # List [B, M, 1] 182 | "anchors": all_anchors, # List [B, M, 2] 183 | "strides": self.stride, 184 | "mask": all_masks} # List [B, M,] 185 | 186 | return outputs 187 | 188 | class FcosRTHead(nn.Module): 189 | def __init__(self, cfg, in_dim, out_dim,): 190 | super().__init__() 191 | self.fmp_size = None 192 | # ------------------ Basic parameters ------------------- 193 | self.cfg = cfg 194 | self.in_dim = in_dim 195 | self.stride = cfg.out_stride 196 | 
self.num_classes = cfg.num_classes 197 | self.num_cls_head = cfg.num_cls_head 198 | self.num_reg_head = cfg.num_reg_head 199 | self.act_type = cfg.head_act 200 | self.norm_type = cfg.head_norm 201 | 202 | # ------------------ Network parameters ------------------- 203 | ## cls head 204 | cls_heads = [] 205 | self.cls_head_dim = out_dim 206 | for i in range(self.num_cls_head): 207 | if i == 0: 208 | cls_heads.append( 209 | BasicConv(in_dim, self.cls_head_dim, 210 | kernel_size=3, padding=1, stride=1, 211 | act_type=self.act_type, norm_type=self.norm_type) 212 | ) 213 | else: 214 | cls_heads.append( 215 | BasicConv(self.cls_head_dim, self.cls_head_dim, 216 | kernel_size=3, padding=1, stride=1, 217 | act_type=self.act_type, norm_type=self.norm_type) 218 | ) 219 | 220 | ## reg head 221 | reg_heads = [] 222 | self.reg_head_dim = out_dim 223 | for i in range(self.num_reg_head): 224 | if i == 0: 225 | reg_heads.append( 226 | BasicConv(in_dim, self.reg_head_dim, 227 | kernel_size=3, padding=1, stride=1, 228 | act_type=self.act_type, norm_type=self.norm_type) 229 | ) 230 | else: 231 | reg_heads.append( 232 | BasicConv(self.reg_head_dim, self.reg_head_dim, 233 | kernel_size=3, padding=1, stride=1, 234 | act_type=self.act_type, norm_type=self.norm_type) 235 | ) 236 | self.cls_heads = nn.Sequential(*cls_heads) 237 | self.reg_heads = nn.Sequential(*reg_heads) 238 | 239 | ## pred layers 240 | self.cls_pred = nn.Conv2d(self.cls_head_dim, cfg.num_classes, kernel_size=3, padding=1) 241 | self.reg_pred = nn.Conv2d(self.reg_head_dim, 4, kernel_size=3, padding=1) 242 | 243 | # init bias 244 | self._init_layers() 245 | 246 | def _init_layers(self): 247 | for module in [self.cls_heads, self.reg_heads, self.cls_pred, self.reg_pred]: 248 | for layer in module.modules(): 249 | if isinstance(layer, nn.Conv2d): 250 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01) 251 | if layer.bias is not None: 252 | torch.nn.init.constant_(layer.bias, 0) 253 | if isinstance(layer, nn.GroupNorm): 254 | torch.nn.init.constant_(layer.weight, 1) 255 | if layer.bias is not None: 256 | torch.nn.init.constant_(layer.bias, 0) 257 | # init the bias of cls pred 258 | init_prob = 0.01 259 | bias_value = -torch.log(torch.tensor((1. 
- init_prob) / init_prob)) 260 | torch.nn.init.constant_(self.cls_pred.bias, bias_value) 261 | 262 | def get_anchors(self, level, fmp_size): 263 | """ 264 | fmp_size: (List) [H, W] 265 | """ 266 | # generate grid cells 267 | fmp_h, fmp_w = fmp_size 268 | anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)]) 269 | # [H, W, 2] -> [HW, 2] 270 | anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5 271 | anchors *= self.stride[level] 272 | 273 | return anchors 274 | 275 | def decode_boxes(self, pred_deltas, anchors, stride): 276 | """ 277 | pred_deltas: (List[Tensor]) [B, M, 4] or [M, 4] (dx, dy, dw, dh) 278 | anchors: (List[Tensor]) [1, M, 2] or [M, 2] 279 | """ 280 | pred_cxcy = anchors + pred_deltas[..., :2] * stride 281 | pred_bwbh = pred_deltas[..., 2:].exp() * stride 282 | 283 | pred_x1y1 = pred_cxcy - 0.5 * pred_bwbh 284 | pred_x2y2 = pred_cxcy + 0.5 * pred_bwbh 285 | 286 | pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1) 287 | 288 | return pred_box 289 | 290 | def forward(self, pyramid_feats, mask=None): 291 | all_masks = [] 292 | all_anchors = [] 293 | all_cls_preds = [] 294 | all_reg_preds = [] 295 | all_box_preds = [] 296 | for level, feat in enumerate(pyramid_feats): 297 | # ------------------- Decoupled head ------------------- 298 | cls_feat = self.cls_heads(feat) 299 | reg_feat = self.reg_heads(feat) 300 | 301 | # ------------------- Generate anchor box ------------------- 302 | B, _, H, W = cls_feat.size() 303 | fmp_size = [H, W] 304 | anchors = self.get_anchors(level, fmp_size) # [M, 4] 305 | anchors = anchors.to(cls_feat.device) 306 | 307 | # ------------------- Predict ------------------- 308 | cls_pred = self.cls_pred(cls_feat) 309 | reg_pred = self.reg_pred(reg_feat) 310 | 311 | # ------------------- Process preds ------------------- 312 | ## [B, C, H, W] -> [B, H, W, C] -> [B, M, C] 313 | cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes) 314 | reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4) 315 | box_pred = self.decode_boxes(reg_pred, anchors, self.stride[level]) 316 | ## Adjust mask 317 | if mask is not None: 318 | # [B, H, W] 319 | mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0] 320 | # [B, H, W] -> [B, M] 321 | mask_i = mask_i.flatten(1) 322 | all_masks.append(mask_i) 323 | 324 | all_anchors.append(anchors) 325 | all_cls_preds.append(cls_pred) 326 | all_reg_preds.append(reg_pred) 327 | all_box_preds.append(box_pred) 328 | 329 | outputs = {"pred_cls": all_cls_preds, # List [B, M, C] 330 | "pred_reg": all_reg_preds, # List [B, M, 4] 331 | "pred_box": all_box_preds, # List [B, M, 4] 332 | "anchors": all_anchors, # List [B, M, 2] 333 | "strides": self.stride, 334 | "mask": all_masks} # List [B, M,] 335 | 336 | return outputs 337 | --------------------------------------------------------------------------------
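For quick reference, the standalone sketch below (not part of the repository; the helper names `make_anchors` and `decode_ltrb` are illustrative only) mirrors how `FcosHead.get_anchors` and `FcosHead.decode_boxes` above generate per-level anchor points and turn the non-negative (l, t, r, b) predictions into xyxy boxes:

```Python
import torch

def make_anchors(fmp_h, fmp_w, stride):
    # grid-cell centers of an H x W feature map, mapped back to image coordinates
    ys, xs = torch.meshgrid(torch.arange(fmp_h), torch.arange(fmp_w), indexing="ij")
    anchors = torch.stack([xs, ys], dim=-1).float().view(-1, 2) + 0.5  # [H*W, 2]
    return anchors * stride

def decode_ltrb(anchors, deltas):
    # x1y1 = center - (l, t), x2y2 = center + (r, b)
    return torch.cat([anchors - deltas[..., :2],
                      anchors + deltas[..., 2:]], dim=-1)              # [H*W, 4]

anchors = make_anchors(2, 3, stride=8)   # 6 anchor points for a 2x3 feature map
deltas  = torch.rand(6, 4) * 16.0        # non-negative (l, t, r, b), as after the ReLU in the head
boxes   = decode_ltrb(anchors, deltas)
assert boxes.shape == (6, 4) and (boxes[:, 2:] >= boxes[:, :2]).all()
```

`FcosRTHead` differs only in the decoding step: it predicts (dx, dy, dw, dh) offsets relative to the same anchor points, adding the scaled (dx, dy) to the center and exponentiating (dw, dh) for the box size, instead of distances to the four box sides.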