├── cal_rescall
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── script.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   └── rrc_evaluation_funcs.cpython-37.pyc
│   └── script.py
├── models
│   ├── dcn
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── deform_conv.cpython-37.pyc
│   │   │   │   └── deform_pool.cpython-37.pyc
│   │   │   ├── deform_pool.py
│   │   │   └── deform_conv.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── deform_conv.cpython-37.pyc
│   │   │   │   └── deform_pool.cpython-37.pyc
│   │   │   ├── deform_conv.py
│   │   │   └── deform_pool.py
│   │   ├── make.sh
│   │   ├── setup.py
│   │   ├── __init__.py
│   │   └── src
│   │       ├── deform_pool_cuda.cpp
│   │       └── deform_pool_cuda_kernel.cu
│   ├── __pycache__
│   │   ├── DBNet.cpython-37.pyc
│   │   └── __init__.cpython-37.pyc
│   ├── head
│   │   ├── __pycache__
│   │   │   ├── DB_Head.cpython-37.pyc
│   │   │   ├── FPN_Head.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── seg_detector.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── DB_Head.py
│   │   ├── FPN_Head.py
│   │   └── seg_detector.py
│   ├── backbone
│   │   ├── __pycache__
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   └── __init__.cpython-37.pyc
│   │   ├── __init__.py
│   │   └── resnet.py
│   ├── __init__.py
│   └── DBNet.py
├── show
│   ├── 1.jpg
│   └── 2.jpg
├── loss
│   ├── __pycache__
│   │   ├── loss.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── l1_loss.cpython-37.pyc
│   │   ├── dice_loss.cpython-37.pyc
│   │   └── balance_cross_entropy_loss.cpython-37.pyc
│   ├── __init__.py
│   ├── l1_loss.py
│   ├── loss.py
│   ├── balance_cross_entropy_loss.py
│   └── dice_loss.py
├── utils
│   ├── __pycache__
│   │   ├── tools.cpython-37.pyc
│   │   ├── Logger.cpython-37.pyc
│   │   ├── metrics.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── fuse_model.cpython-37.pyc
│   │   ├── model_eval.cpython-37.pyc
│   │   ├── DB_postprocesss.cpython-37.pyc
│   │   └── set_optimizer.cpython-37.pyc
│   ├── __init__.py
│   ├── set_optimizer.py
│   ├── fuse_model.py
│   ├── metrics.py
│   ├── Logger.py
│   ├── tools.py
│   ├── model_eval.py
│   └── DB_postprocesss.py
├── dataloader
│   ├── __pycache__
│   │   ├── MakeSegMap.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dataload.cpython-37.pyc
│   │   ├── MakeBorderMap.cpython-37.pyc
│   │   └── random_thansform.cpython-37.pyc
│   ├── __init__.py
│   ├── MakeSegMap.py
│   ├── dataload.py
│   ├── MakeBorderMap.py
│   └── random_thansform.py
├── pruned
│   ├── __init__.py
│   ├── prune_inference.py
│   ├── get_pruned_model.py
│   ├── train_fintune.py
│   └── prune.py
├── requirement.txt
├── config.yaml
├── README.md
├── README_en.md
├── inference.py
└── train.py

/cal_rescall/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/dcn/functions/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/dcn/modules/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/dcn/make.sh:
--------------------------------------------------------------------------------
python3 setup.py build_ext --inplace
--------------------------------------------------------------------------------
/show/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BADBADBADBOY/DBnet-lite.pytorch/HEAD/show/1.jpg
--------------------------------------------------------------------------------
/show/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BADBADBADBOY/DBnet-lite.pytorch/HEAD/show/2.jpg
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: __init__.py
@time: 2020/7/4 15:16

"""
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: __init__.py
@time: 2020/7/4 15:16

"""
--------------------------------------------------------------------------------
/pruned/__init__.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: __init__.py
@time: 2020/7/4 15:16

"""
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: __init__.py
@time: 2020/7/4 15:16

"""
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: __init__.py
@time: 2020/7/4 15:16

"""
--------------------------------------------------------------------------------
/models/head/__init__.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: __init__.py
@time: 2020/7/4 15:16

"""
--------------------------------------------------------------------------------
/models/backbone/__init__.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: __init__.py
@time: 2020/6/29 21:39

"""
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
pyyaml
tqdm
tensorboardX
opencv-python==4.1.2.30
anyconfig
munch
scipy
sortedcontainers
shapely
pyclipper
gevent
gevent-websocket
flask
editdistance
scikit-image
imgaug
tabulate
--------------------------------------------------------------------------------
/models/dcn/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='deform_conv',
    ext_modules=[
        CUDAExtension('deform_conv_cuda', [
            'src/deform_conv_cuda.cpp',
            'src/deform_conv_cuda_kernel.cu',
        ]),
        CUDAExtension('deform_pool_cuda', [
            'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
        ]),
    ],
    cmdclass={'build_ext': BuildExtension})
--------------------------------------------------------------------------------
/models/dcn/__init__.py:
--------------------------------------------------------------------------------
from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
                                  DeformConvPack, ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
                                  ModulatedDeformRoIPoolingPack)

__all__ = [
    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
    'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
    'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
    'deform_roi_pooling'
]
--------------------------------------------------------------------------------
/utils/set_optimizer.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: set_optimizer.py
@time: 2020/7/4 15:16

"""

def lr_poly(base_lr, epoch, max_epoch=1200, factor=0.9):
    return base_lr * ((1 - float(epoch) / max_epoch) ** factor)

def adjust_learning_rate_poly(base_lr, optimizer, epoch, max_epoch=1200, factor=0.9):
    lr = lr_poly(base_lr, epoch, max_epoch, factor)
    optimizer.param_groups[0]['lr'] = lr


def adjust_learning_rate(config, optimizer, epoch, gama=0.1):
    if epoch in config['train']['schedule']:
        config['train']['base_lr'] = config['train']['base_lr'] * gama
        for param_group in optimizer.param_groups:
            param_group['lr'] = config['train']['base_lr']
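
A minimal sketch of driving these schedules from a training loop; only `base_lr: 0.002` and `n_epoch: 1200` come from config.yaml, the SGD settings and the stand-in model are illustrative:

```python
import torch
import torch.nn as nn
from utils.set_optimizer import adjust_learning_rate_poly

model = nn.Conv2d(3, 16, 3)  # stand-in for DBNet
optimizer = torch.optim.SGD(model.parameters(), lr=0.002, momentum=0.9)

for epoch in range(1200):
    # lr = base_lr * (1 - epoch / max_epoch) ** 0.9, applied before each epoch
    adjust_learning_rate_poly(0.002, optimizer, epoch, max_epoch=1200, factor=0.9)
    # ... one training epoch ...
```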
--------------------------------------------------------------------------------
/models/DBNet.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: DBNet.py
@time: 2020/7/4 15:16

"""
import torch.nn as nn
from models.head.seg_detector import SegDetector
from models.backbone.resnet import resnet18, resnet50, deformable_resnet50, deformable_resnet18

class DBNet(nn.Module):
    def __init__(self, config, is_train=True):
        super(DBNet, self).__init__()
        self.backbone = globals().get(config['train']['backbone'])(pretrained=config['train']['pretrained'])
        if is_train is False:
            config['train']['adaptive'] = config['test']['adaptive']
        self.decode = SegDetector(headname=config['train']['HeadName'],
                                  in_channels=config['train']['in_channels'],
                                  inner_channels=config['train']['inner_channels'],
                                  k=config['train']['k'],
                                  bias=False, adaptive=config['train']['adaptive'], smooth=False, serial=False)

    def forward(self, x):
        x = self.backbone(x)
        out = self.decode.forward(x)
        return out
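
A hedged sketch of instantiating the network from the parsed config; the input size is illustrative, and in training mode with adaptive=True the head produces the prediction maps consumed by L1BalanceCELoss below:

```python
import torch
import yaml
from models.DBNet import DBNet

with open('config.yaml', 'r') as fid:
    config = yaml.safe_load(fid)

model = DBNet(config, is_train=True).cuda()
preds = model(torch.randn(1, 3, 640, 640).cuda())
```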
--------------------------------------------------------------------------------
/loss/l1_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MaskL1Loss(nn.Module):
    def __init__(self):
        super(MaskL1Loss, self).__init__()

    def forward(self, pred: torch.Tensor, gt, mask):
        mask_sum = mask.sum()
        if mask_sum.item() == 0:
            return mask_sum, dict(l1_loss=mask_sum)
        else:
            loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sum
            return loss, dict(l1_loss=loss)


class BalanceL1Loss(nn.Module):
    def __init__(self, negative_ratio=3.):
        super(BalanceL1Loss, self).__init__()
        self.negative_ratio = negative_ratio

    def forward(self, pred: torch.Tensor, gt, mask):
        '''
        Args:
            pred: (N, 1, H, W).
            gt: (N, H, W).
            mask: (N, H, W).
        '''
        loss = torch.abs(pred[:, 0] - gt)
        positive = loss * mask
        negative = loss * (1 - mask)
        positive_count = int(mask.sum())
        negative_count = min(
            int((1 - mask).sum()),
            int(positive_count * self.negative_ratio))
        negative_loss, _ = torch.topk(negative.view(-1), negative_count)
        negative_loss = negative_loss.sum() / max(negative_count, 1)
        positive_loss = positive.sum() / max(positive_count, 1)
        return positive_loss + negative_loss,\
            dict(l1_loss=positive_loss, neg_l1_loss=negative_loss)
--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: loss.py
@time: 2020/7/4 15:16

"""
import torch.nn as nn
from loss.dice_loss import DiceLoss, dice_loss
from loss.l1_loss import MaskL1Loss
from loss.balance_cross_entropy_loss import BalanceCrossEntropyLoss

class L1BalanceCELoss(nn.Module):
    '''
    Balanced CrossEntropy Loss on `binary`,
    MaskL1Loss on `thresh`,
    DiceLoss on `thresh_binary`.
    Note: The meaning of inputs can be figured out in `SegDetectorLossBuilder`.
    '''

    def __init__(self, eps=1e-6, l1_scale=10, bce_scale=1):
        super(L1BalanceCELoss, self).__init__()
        # self.dice_loss = DiceLoss(eps=eps)
        self.dice_loss = dice_loss
        self.l1_loss = MaskL1Loss()
        self.bce_loss = BalanceCrossEntropyLoss()
        self.l1_scale = l1_scale
        self.bce_scale = bce_scale

    def forward(self, pred, batch):
        bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])
        metrics = dict(bce_loss=bce_loss)
        if 'thresh' in pred:
            l1_loss, l1_metric = self.l1_loss(pred['thresh'], batch['thresh_map'], batch['thresh_mask'])
            dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], batch['mask'])
            metrics['thresh_loss'] = dice_loss
            loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale
            metrics.update(**l1_metric)
        else:
            loss = bce_loss
        return loss, metrics
--------------------------------------------------------------------------------
/utils/fuse_model.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: fuse_model.py
@time: 2020/7/4 15:16

"""
import torch
import torch.nn as nn
import time
import sys
import numpy as np
import torchvision
import torch.nn.functional as F


class DummyModule(nn.Module):
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x

def fuse(conv, bn):
    # ******************* conv parameters ********************
    w = conv.weight

    # ******************** BN parameters *********************
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight
    beta = bn.bias

    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)

    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * gamma + beta

    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           dilation=conv.dilation,
                           groups=conv.groups,
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv

def fuse_module(m):
    children = list(m.named_children())
    c = None
    cn = None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d) and c is not None:
            bc = fuse(c, child)
            m._modules[cn] = bc
            m._modules[name] = DummyModule()
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        else:
            fuse_module(child)
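
A small before/after sanity check for the Conv+BN fusion above, using a torchvision ResNet as a stand-in; fuse_module mutates the model in place, so it is applied to a copy (tolerance and input size are arbitrary):

```python
import copy
import torch
import torchvision
from utils.fuse_model import fuse_module

model = torchvision.models.resnet18(pretrained=False).eval()  # fusion uses running stats, so eval mode
fused = copy.deepcopy(model)
fuse_module(fused)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(model(x), fused(x), atol=1e-4)
```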
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
# Adapted from score written by wkentaro
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py

import numpy as np

class runningScore(object):

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        mask = (label_true >= 0) & (label_true < n_class)

        if np.sum((label_pred[mask] < 0)) > 0:
            print(label_pred[label_pred < 0])
        hist = np.bincount(
            n_class * label_true[mask].astype(int) +
            label_pred[mask], minlength=n_class**2).reshape(n_class, n_class)
        return hist

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)

    def get_scores(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IU
            - fwavacc
        """
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
        acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
        acc_cls = np.nanmean(acc_cls)
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001)
        mean_iu = np.nanmean(iu)
        freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {'Overall Acc': acc,
                'Mean Acc': acc_cls,
                'FreqW Acc': fwavacc,
                'Mean IoU': mean_iu,}, cls_iu

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
--------------------------------------------------------------------------------
/loss/balance_cross_entropy_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class BalanceCrossEntropyLoss(nn.Module):
    '''
    Balanced cross entropy loss.
    Shape:
        - Input: :math:`(N, 1, H, W)`
        - GT: :math:`(N, 1, H, W)`, same shape as the input
        - Mask: :math:`(N, H, W)`, same spatial shape as the input
        - Output: scalar.

    Examples::

        >>> m = nn.Sigmoid()
        >>> loss = nn.BCELoss()
        >>> input = torch.randn(3, requires_grad=True)
        >>> target = torch.empty(3).random_(2)
        >>> output = loss(m(input), target)
        >>> output.backward()
    '''

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: torch.Tensor,
                return_origin=False):
        '''
        Args:
            pred: shape :math:`(N, 1, H, W)`, the prediction of network
            gt: shape :math:`(N, 1, H, W)`, the target
            mask: shape :math:`(N, H, W)`, the mask indicates positive regions
        '''
        positive = (gt * mask).byte()
        negative = ((1 - gt) * mask).byte()
        positive_count = int(positive.float().sum())
        negative_count = min(int(negative.float().sum()),
                             int(positive_count * self.negative_ratio))
        loss = nn.functional.binary_cross_entropy(
            pred, gt, reduction='none')[:, 0, :, :]
        positive_loss = loss * positive.float()
        negative_loss = loss * negative.float()
        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

        balance_loss = (positive_loss.sum() + negative_loss.sum()) /\
            (positive_count + negative_count + self.eps)

        if return_origin:
            return balance_loss, loss
        return balance_loss
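
A quick smoke test of the hard-negative balancing above; with negative_ratio=3.0 at most three negatives per positive enter the loss. Shapes follow the docstring, and the tensors are random stand-ins:

```python
import torch
from loss.balance_cross_entropy_loss import BalanceCrossEntropyLoss

criterion = BalanceCrossEntropyLoss(negative_ratio=3.0)
pred = torch.sigmoid(torch.randn(2, 1, 64, 64))  # probabilities in (0, 1)
gt = (torch.rand(2, 1, 64, 64) > 0.9).float()    # sparse positive text mask
mask = torch.ones(2, 64, 64)                     # every pixel contributes
loss = criterion(pred, gt, mask)
```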
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
train:
  gpu_id: '0'
  backbone: 'resnet18'
  pretrained: True
  HeadName: 'DB'
  in_channels: [64, 128, 256, 512] # [256, 512, 1024, 2048] for resnet50
  inner_channels: 256
  k: 50
  adaptive: True
  start_val_epoch: 1000
  n_epoch: 1200
  batch_size: 16
  use_sr: True
  sr_lr: 0.00001
  base_lr: 0.002
  num_workers: 0
  show_step: 5
  print_format: 'linux' # linux or windows
  restore: True
  resume: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
  checkpoints: './checkpoints'
  is_icdar2015: True
  is_transform: True
  is_show: False
  train_img_format: '.jpg'
  val_img_format: '.jpg'
  train_img_dir: '/home/aistudio/work/data/icdar/train_img/'
  train_gt_dir: '/home/aistudio/work/data/icdar/train_gt/'
  val_img_dir: '/home/aistudio/work/data/icdar/test_img/'
  val_gt_dir: '/home/aistudio/work/data/icdar/test_gt/'
  radom_angle: [-10, 10]
  output_path: './outputs_val'
  decay_method: 'e_decay' # e_decay: exponential decay, s_decay: decay at the epochs in schedule
  schedule: [500, 800, 1000]
  gama: 0.1
test:
  gpu_id: '0'
  pretrained: False
  merge_conv_bn: False
  adaptive: False
  short_side: 736
  thresh: 0.5
  box_thresh: 0.6
  unclip_ratio: 2
  min_size: 3
  max_candidates: 1000
  is_poly: False
  is_icdar2015: True
  test_img_format: '.jpg'
  test_img_dir: '/home/aistudio/work/data/icdar/test_img/'
  test_gt_dir: '/home/aistudio/work/data/icdar/test_gt/'
  checkpoints: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
  out_dir: './outputs_test'

pruned:
  gpu_id: '0'
  scale: [73, 77, 81, 85]
  base_num: 8
  cut_percent: 0.8
  pruned_checkpoints: './pruned/checkpoint/pruned_dict.pth.tar'
  checkpoints_dict: './pruned/checkpoint/pruned_dict.dict'
  save_checkpoints: './pruned/checkpoint'
  checkpoints: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
  finetune_lr: 0.0005
  resume: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
  restore: True
  n_epoch: 100
  start_val_epoch: 40
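
Note that in_channels must match the backbone: resnet18/deformable_resnet18 produce [64, 128, 256, 512] feature channels, while resnet50/deformable_resnet50 produce [256, 512, 1024, 2048]. A hedged sketch of loading the file and checking this pairing (the check itself is illustrative):

```python
import yaml

with open('config.yaml', 'r') as fid:
    config = yaml.safe_load(fid)

channels = {'resnet18': [64, 128, 256, 512],
            'deformable_resnet18': [64, 128, 256, 512],
            'resnet50': [256, 512, 1024, 2048],
            'deformable_resnet50': [256, 512, 1024, 2048]}
assert config['train']['in_channels'] == channels[config['train']['backbone']]
```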
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[English](README_en.md) | 简体中文

# DBNet-lite-pytorch


## Future updates of this project will happen in [pytorchOCR](https://github.com/BADBADBADBOY/pytorchOCR), where I have consolidated my previous projects

## Setup

```
pip install -r requirement.txt
cd models/dcn/
sh make.sh
```

***

## Data format for horizontal or slanted text

Follow the icdar2015 format: x1,y1,x2,y2,x3,y3,x4,y4,label
```

image
│   1.jpg
│   2.jpg
│   ...
label
│   gt_1.txt
│   gt_2.txt
|   ...
```
***

## Data format for curved text

Data format: x1,y1,x2,y2,x3,y3,x4,y4 ...xn,yn,label

A region consists of n points; the number of points may vary

```
image
│   1.jpg
│   2.jpg
│   ...
label
│   gt_1.txt
│   gt_2.txt
|   ...
```

***


## Training

Configure the train parameters in config.yaml in the root directory, e.g. the image paths. If your images and gt files share the same names, you can set is_icdar2015=False.
If you don't want to run validation, set start_val_epoch greater than n_epoch; with validation enabled, the model with the highest hmean is kept as the best checkpoint.

```
python3 train.py
```
***


## Testing

For testing, configure the test section of config.yaml; set is_poly=True for curved text and False for all other text.

```
python3 inference.py
```
***
## Model compression via channel pruning

### Training
1. Run sparsity training first: set use_sr to True in config.yaml and choose sr_lr. The larger this value, the more channels get pruned; note that too large a value may prevent convergence.

```
python3 train.py
```
2. Prune the model.
Set the parameters of the pruned section in config.yaml and run
```
python3 ./pruned/prune.py
```
3. Finetune the pruned model.
Accuracy recovers quickly here; training for 50-100 epochs is usually enough, adjust based on your own experiments
```
python3 ./pruned/train_fintune.py
```

### Testing

```
python3 ./pruned/prune_inference.py
```

***


## Results on icdar2015

|Method| head|extra data|prune ratio|model size(M)|precision(%)| recall(%) | hmean(%)|model_file|
| - | - | - | - | - | - |- | - |- |
| Resnet18|FPN|no|0|62.6|86.11| 76.45| 80.99|[baiduyun](https://pan.baidu.com/s/1wmbGMoluWlZ97LCqOnwjOg) (extract code: p0bk)|
| Resnet18|DB|no|0.8|20.1|85.55| 76.40| 80.72||
***
## Result images on icdar2015

(see show/1.jpg and show/2.jpg)

***

#### ToDoList
- [x] Convert the author's code for easier reading and debugging
- [x] Show some training results
- [ ] Add a lightweight backbone to compress the model
- [ ] Compress the DB model via channel pruning with nearly unchanged accuracy
- [ ] Further improve the pruned model via knowledge distillation




# References

1. https://github.com/whai362/PSENet
2. https://github.com/MhLiao/DB
3. https://github.com/Jzz24/pytorch_quantization
--------------------------------------------------------------------------------
/models/dcn/functions/deform_pool.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Function

from .. import deform_pool_cuda


class DeformRoIPoolingFunction(Function):

    @staticmethod
    def forward(ctx,
                data,
                rois,
                offset,
                spatial_scale,
                out_size,
                out_channels,
                no_trans,
                group_size=1,
                part_size=None,
                sample_per_part=4,
                trans_std=.0):
        ctx.spatial_scale = spatial_scale
        ctx.out_size = out_size
        ctx.out_channels = out_channels
        ctx.no_trans = no_trans
        ctx.group_size = group_size
        ctx.part_size = out_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        assert 0.0 <= ctx.trans_std <= 1.0
        if not data.is_cuda:
            raise NotImplementedError

        n = rois.shape[0]
        output = data.new_empty(n, out_channels, out_size, out_size)
        output_count = data.new_empty(n, out_channels, out_size, out_size)
        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
            data, rois, offset, output, output_count, ctx.no_trans,
            ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
            ctx.part_size, ctx.sample_per_part, ctx.trans_std)

        if data.requires_grad or rois.requires_grad or offset.requires_grad:
            ctx.save_for_backward(data, rois, offset)
        ctx.output_count = output_count

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError

        data, rois, offset = ctx.saved_tensors
        output_count = ctx.output_count
        grad_input = torch.zeros_like(data)
        grad_rois = None
        grad_offset = torch.zeros_like(offset)

        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
            grad_output, data, rois, offset, output_count, grad_input,
            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
            ctx.trans_std)
        return (grad_input, grad_rois, grad_offset, None, None, None, None,
                None, None, None, None)


deform_roi_pooling = DeformRoIPoolingFunction.apply
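
A rough sketch of calling the raw Function; it requires the CUDA extension built by models/dcn/make.sh and a GPU (the CPU path raises NotImplementedError). The ROI layout (batch_idx, x1, y1, x2, y2) and the empty offset tensor for no_trans=True are assumptions based on the usual deformable-PSROI interface, not verified against this kernel:

```python
import torch
from models.dcn.functions.deform_pool import deform_roi_pooling

feat = torch.randn(1, 256, 32, 32, device='cuda')
rois = torch.tensor([[0., 4., 4., 120., 120.]], device='cuda')  # assumed (batch_idx, x1, y1, x2, y2)
offset = feat.new_empty(0)  # assumed: no offsets needed when no_trans is set
# positional args: data, rois, offset, spatial_scale, out_size, out_channels, no_trans
out = deform_roi_pooling(feat, rois, offset, 1.0 / 4, 7, 256, True)
```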
--------------------------------------------------------------------------------
/utils/Logger.py:
--------------------------------------------------------------------------------
#-*- coding:utf-8 _*-
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import

class Logger(object):
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                self.file = open(fpath, 'r')
                name = self.file.readline()
                self.names = name.rstrip().split('\t')
                self.numbers = {}
                for _, name in enumerate(self.names):
                    self.numbers[name] = []

                for numbers in self.file:
                    numbers = numbers.rstrip().split('\t')
                    for i in range(0, len(numbers)):
                        self.numbers[self.names[i]].append(numbers[i])
                self.file.close()
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        if self.resume:
            pass
        self.numbers = {}
        self.names = names
        for _, name in enumerate(self.names):
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def set_split(self, names):
        if self.resume:
            pass
        self.numbers = {}
        self.names = names
        for _, name in enumerate(self.names):
            self.file.write(name)
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def close(self):
        if self.file is not None:
            self.file.close()

# if __name__ == '__main__':
#     import numpy as np
#     logger = Logger('test.txt')
#     logger.set_names(['Train loss', 'Valid loss', 'Test loss'])
#
#     length = 100
#     t = np.arange(length)
#     train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
#     valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
#     test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
#
#     for i in range(0, length):
#         logger.append([train_loss[i], valid_loss[i], test_loss[i]])
#     logger.set_split(['----------', '------------', '--------'])
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
English | [简体中文](README.md)

# DBNet-lite-pytorch

## Setup

```
pip install -r requirement.txt
cd models/dcn/
sh make.sh
```

***

## Data format for horizontal or slanted text

Follow the icdar2015 dataset format: x1,y1,x2,y2,x3,y3,x4,y4,label
```

image
│   1.jpg
│   2.jpg
│   ...
label
│   gt_1.txt
│   gt_2.txt
|   ...
```
***

## Data format for curved text

Dataset format: x1,y1,x2,y2,x3,y3,x4,y4 ...xn,yn,label

The number of points n may vary; the points are arranged clockwise or counterclockwise

```
image
│   1.jpg
│   2.jpg
│   ...
label
│   gt_1.txt
│   gt_2.txt
|   ...
```

***


## Train

Configure config.yaml in the root directory, then run

```
python3 train.py
```
***


## Test

Set is_poly=True in config.yaml for curved text; for all other text set is_poly=False

```
python3 inference.py
```
***
## Channel pruning for model compression

### Training section
1. Sparse training is performed first: modify config.yaml to set use_sr to True, and set sr_lr. The larger this value, the more channels get pruned;
note that training may not converge if it is set too large.

```
python3 train.py
```
2. Prune the model.
Set the parameters of the pruned section in config.yaml and run
```
python3 ./pruned/prune.py
```
3. Finetune the pruned model.
Accuracy recovers quickly here; training for 50-100 epochs is generally enough, adjust based on your own experiments
```
python3 ./pruned/train_fintune.py
```

### Test section

```
python3 ./pruned/prune_inference.py
```

## Performance on icdar2015

|Method| head|extra data|prune ratio|model size(M)|precision(%)| recall(%) | hmean(%)|model_file|
| - | - | - | - | - | - |- | - |- |
| Resnet18|FPN|no|0|62.6|86.11| 76.45| 80.99|[baiduyun](https://pan.baidu.com/s/1wmbGMoluWlZ97LCqOnwjOg) (extract code: p0bk)|
| Resnet18|DB|no|0.8|20.1|85.55| 76.40| 80.72||
| mobilev3|DB|no|0|2.5|85.55| 76.40| 74.71||

***
## Some results

(see show/1.jpg and show/2.jpg)

***

#### ToDoList
- [x] Transform the DB code from MhLiao/DB for easier reading and debugging
- [x] Add some performance numbers
- [ ] Add a light backbone
- [ ] Prune the big model by channel clipping
- [ ] Model distillation




# References

1. https://github.com/whai362/PSENet
2. https://github.com/MhLiao/DB
3. https://github.com/Jzz24/pytorch_quantization
--------------------------------------------------------------------------------
/models/head/DB_Head.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: DB_Head.py
@time: 2020/7/4 15:16

"""
import torch
import torch.nn as nn

class ConvBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)  # Reduce channels
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class DB_Head(nn.Module):
    def __init__(self, in_channels, inner_channels, bias=False):
        super(DB_Head, self).__init__()

        self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')

        self.in5 = ConvBnRelu(in_channels[-1], inner_channels, 1, stride=1, padding=0, bias=bias)
        self.in4 = ConvBnRelu(in_channels[-2], inner_channels, 1, stride=1, padding=0, bias=bias)
        self.in3 = ConvBnRelu(in_channels[-3], inner_channels, 1, stride=1, padding=0, bias=bias)
        self.in2 = ConvBnRelu(in_channels[-4], inner_channels, 1, stride=1, padding=0, bias=bias)

        self.out5 = nn.Sequential(
            ConvBnRelu(inner_channels, inner_channels // 4, 3, stride=1, padding=1, bias=bias),
            nn.Upsample(scale_factor=8, mode='nearest'))
        self.out4 = nn.Sequential(
            ConvBnRelu(inner_channels, inner_channels // 4, 3, stride=1, padding=1, bias=bias),
            nn.Upsample(scale_factor=4, mode='nearest'))
        self.out3 = nn.Sequential(
            ConvBnRelu(inner_channels, inner_channels // 4, 3, stride=1, padding=1, bias=bias),
            nn.Upsample(scale_factor=2, mode='nearest'))
        self.out2 = ConvBnRelu(inner_channels, inner_channels // 4, 3, stride=1, padding=1, bias=bias)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.)
                m.bias.data.fill_(1e-4)

    def forward(self, x):

        c2, c3, c4, c5 = x
        in5 = self.in5(c5)
        in4 = self.in4(c4)
        in3 = self.in3(c3)
        in2 = self.in2(c2)

        out4 = self.up5(in5) + in4  # 1/16
        out3 = self.up4(out4) + in3  # 1/8
        out2 = self.up3(out3) + in2  # 1/4

        p5 = self.out5(in5)
        p4 = self.out4(out4)
        p3 = self.out3(out3)
        p2 = self.out2(out2)
        fuse = torch.cat((p5, p4, p3, p2), 1)
        return fuse
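
An illustrative shape check for the head with resnet18 channels at a 640x640 input; the random tensors stand in for backbone features c2..c5 at strides 4/8/16/32:

```python
import torch
from models.head.DB_Head import DB_Head

head = DB_Head(in_channels=[64, 128, 256, 512], inner_channels=256)
c2 = torch.randn(1, 64, 160, 160)
c3 = torch.randn(1, 128, 80, 80)
c4 = torch.randn(1, 256, 40, 40)
c5 = torch.randn(1, 512, 20, 20)
fuse = head((c2, c3, c4, c5))
print(fuse.shape)  # torch.Size([1, 256, 160, 160]) -- 4 x (inner_channels // 4) at stride 4
```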
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: tools.py
@time: 2020/7/4 15:16

"""

import copy
import math
import cv2
import os
import torch
import numpy as np
from tabulate import tabulate

def judgePoint(point1, point2, center_point):
    if point1[0] <= 0 and point2[0] > 0:
        return True
    if point1[0] == 0 and point2[0] == 0:
        return point1[1] < point2[1]
    # reconstructed line: the cross product of (point1 - center) and
    # (point2 - center); the original statement was lost in this dump
    det = (point1[0] - center_point[0]) * (point2[1] - center_point[1]) - \
          (point2[0] - center_point[0]) * (point1[1] - center_point[1])
    if det > 0:
        return True
    if det < 0:
        return False
    d1 = (point1[0] - center_point[0]) * (point1[0] - center_point[0]) + (point1[1] - center_point[1]) * (point1[1] - center_point[1])
    d2 = (point2[0] - center_point[0]) * (point2[0] - center_point[0]) + (point2[1] - center_point[1]) * (point2[1] - center_point[1])
    return d1 < d2

# ... the middle of utils/tools.py (including resize_image, which
# model_eval.py relies on) is missing from this dump; only the tail of a
# binary-score helper and save_checkpoint survive below ...

    pred_binary[pred_binary > thresh] = 1
    pred_binary = pred_binary.astype(np.int32)
    gt_binary = gt_binarys.data.cpu().numpy() * training_masks
    gt_binary = gt_binary.astype(np.int32)
    running_metric_binary.update(gt_binary, pred_binary)
    score_binary, _ = running_metric_binary.get_scores()
    return score_binary


def save_checkpoint(state, checkpoint='checkpoints', filename='DB.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
--------------------------------------------------------------------------------
/models/head/FPN_Head.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: FPN_Head.py
@time: 2020/7/4 15:16

"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)  # Reduce channels
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class FPN_Head(nn.Module):
    def __init__(self, in_channels, inner_channels):
        super(FPN_Head, self).__init__()
        # Top layer
        self.toplayer = ConvBnRelu(in_channels[-1], inner_channels, kernel_size=1, stride=1, padding=0)  # Reduce channels
        # Smooth layers
        self.smooth1 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1)
        self.smooth2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1)
        self.smooth3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1)
        # Lateral layers
        self.latlayer1 = ConvBnRelu(in_channels[-2], inner_channels, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = ConvBnRelu(in_channels[-3], inner_channels, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = ConvBnRelu(in_channels[-4], inner_channels, kernel_size=1, stride=1, padding=0)
        # Out map
        self.conv_out = nn.Conv2d(inner_channels * 4, inner_channels, kernel_size=3, stride=1, padding=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.)
                m.bias.data.fill_(1e-4)

    def _upsample(self, x, y, scale=1):
        _, _, H, W = y.size()
        # return F.upsample(x, size=(H // scale, W // scale), mode='nearest')
        return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear', align_corners=True)

    def _upsample_add(self, x, y):
        _, _, H, W = y.size()
        # return F.upsample(x, size=(H, W), mode='nearest') + y
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y

    def forward(self, x):
        c2, c3, c4, c5 = x
        ##
        p5 = self.toplayer(c5)
        c4 = self.latlayer1(c4)
        p4 = self._upsample_add(p5, c4)
        p4 = self.smooth1(p4)
        c3 = self.latlayer2(c3)
        p3 = self._upsample_add(p4, c3)
        p3 = self.smooth2(p3)
        c2 = self.latlayer3(c2)
        p2 = self._upsample_add(p3, c2)
        p2 = self.smooth3(p2)
        ##
        p3 = self._upsample(p3, p2)
        p4 = self._upsample(p4, p2)
        p5 = self._upsample(p5, p2)

        out = torch.cat((p2, p3, p4, p5), 1)
        out = self.conv_out(out)
        return out
--------------------------------------------------------------------------------
/dataloader/MakeSegMap.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: MakeSegMap.py
@time: 2020/7/4 15:16

"""

import cv2
import pyclipper
from shapely.geometry import Polygon
import numpy as np

class MakeSegDetectionData():
    r'''
    Making binary mask from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.
    '''
    def __init__(self, min_text_size=8, shrink_ratio=0.4, is_training=True):
        self.min_text_size = min_text_size
        self.shrink_ratio = shrink_ratio
        self.is_training = is_training

    def process(self, img, polys, dontcare):
        '''
        required keys:
            image, polygons, ignore_tags, filename
        adding keys:
            mask
        '''
        h, w = img.shape[:2]
        if self.is_training:
            polys, dontcare = self.validate_polygons(
                polys, dontcare, h, w)
        gt = np.zeros((1, h, w), dtype=np.float32)
        mask = np.ones((h, w), dtype=np.float32)
        for i in range(len(polys)):
            polygon = polys[i]
            height = max(polygon[:, 1]) - min(polygon[:, 1])
            width = max(polygon[:, 0]) - min(polygon[:, 0])
            if dontcare[i] or min(height, width) < self.min_text_size:
                cv2.fillPoly(mask, polygon.astype(
                    np.int32)[np.newaxis, :, :], 0)
                dontcare[i] = True
            else:
                polygon_shape = Polygon(polygon)
                distance = polygon_shape.area * \
                    (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
                subject = [tuple(l) for l in polys[i]]
                padding = pyclipper.PyclipperOffset()
                padding.AddPath(subject, pyclipper.JT_ROUND,
                                pyclipper.ET_CLOSEDPOLYGON)
                shrinked = padding.Execute(-distance)
                if shrinked == []:
                    cv2.fillPoly(mask, polygon.astype(
                        np.int32)[np.newaxis, :, :], 0)
                    dontcare[i] = True
                    continue
                shrinked = np.array(shrinked[0]).reshape(-1, 2)
                cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
        return img, gt, mask

    def validate_polygons(self, polygons, ignore_tags, h, w):
        '''
        polygons (numpy.array, required): of shape (num_instances, num_points, 2)
        '''
        if len(polygons) == 0:
            return polygons, ignore_tags
        assert len(polygons) == len(ignore_tags)
        for polygon in polygons:
            polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
            polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)

        for i in range(len(polygons)):
            area = self.polygon_area(polygons[i])
            if abs(area) < 1:
                ignore_tags[i] = True
            if area > 0:
                polygons[i] = polygons[i][::-1, :]
        return polygons, ignore_tags

    def polygon_area(self, polygon):
        edge = 0
        for i in range(polygon.shape[0]):
            next_index = (i + 1) % polygon.shape[0]
            # shoelace-style signed area term: (x2 - x1) * (y2 + y1)
            edge += (polygon[next_index, 0] - polygon[i, 0]) * (polygon[next_index, 1] + polygon[i, 1])

        return edge / 2.
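
The shrink offset above follows the DB paper: D = A(1 - r^2) / L, with polygon area A, perimeter L and shrink ratio r (0.4 here). A small worked check using the same shapely/pyclipper calls (the box size is illustrative):

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

poly = np.array([[0, 0], [100, 0], [100, 40], [0, 40]])  # 100x40 text box
shape = Polygon(poly)
r = 0.4
distance = shape.area * (1 - r ** 2) / shape.length      # 4000 * 0.84 / 280 = 12
pco = pyclipper.PyclipperOffset()
pco.AddPath([tuple(p) for p in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunk = np.array(pco.Execute(-distance)[0])             # smaller polygon rasterized into gt
```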
--------------------------------------------------------------------------------
/utils/model_eval.py:
--------------------------------------------------------------------------------
"""
#!-*- coding=utf-8 -*-
@author: BADBADBADBADBOY
@contact: 2441124901@qq.com
@software: PyCharm Community Edition
@file: model_eval.py
@time: 2020/7/4 15:16

"""
import sys
import cv2
import torch
import math
import os
import glob
import argparse
import pyclipper
import torchvision.transforms as transforms
import numpy as np
from shapely.geometry import Polygon
from tqdm import tqdm
import time
from cal_rescall.script import cal_recall_precison_f1
from .DB_postprocesss import DBPostProcess
import copy
import yaml
from utils.tools import *
from PIL import Image

def val(model, config):
    model.eval()
    files = glob.glob(os.path.join(config['train']['val_img_dir'], '*' + config['train']['val_img_format']))
    if not os.path.exists(config['train']['output_path']):
        os.mkdir(config['train']['output_path'])

    if not os.path.exists(os.path.join(config['train']['output_path'], 'img_text')):
        os.mkdir(os.path.join(config['train']['output_path'], 'img_text'))

    if not os.path.exists(os.path.join(config['train']['output_path'], 'img_result')):
        os.mkdir(os.path.join(config['train']['output_path'], 'img_result'))

    bar = tqdm(total=len(files))

    params = {'thresh': config['test']['thresh'],
              'box_thresh': config['test']['box_thresh'],
              'max_candidates': config['test']['max_candidates'],
              'is_poly': config['test']['is_poly'],
              'unclip_ratio': config['test']['unclip_ratio'],
              'min_size': config['test']['min_size']
              }

    dbprocess = DBPostProcess(params)
    total_frame = 0.0
    total_time = 0.0
    for file in files:

        bar.update(1)
        img = cv2.imread(file)
        img_ori = img.copy()
        img_name = file.split('/')[-1].split('.')[0]
        img = resize_image(img, config['test']['short_side'])
        img = Image.fromarray(img)
        img = img.convert('RGB')
        img = transforms.ToTensor()(img)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img).unsqueeze(0).cuda()

        torch.cuda.synchronize()
        start = time.time()

        with torch.no_grad():
            out = model(img)

        scale = (img_ori.shape[1] * 1.0 / out.shape[3], img_ori.shape[0] * 1.0 / out.shape[2])
        bbox_batch, score_batch = dbprocess(out.cpu().numpy(), [scale])

        torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        sys.stdout.flush()

        for bbox in bbox_batch[0]:
            img_ori = cv2.drawContours(img_ori.copy(), [bbox.reshape(-1, 2).astype(np.int32)], -1, (0, 255, 0), 1)

        if config['test']['is_icdar2015']:
            text_file = 'res_' + img_name + '.txt'
        else:
            text_file = img_name + '.txt'

        with open(os.path.join(config['train']['output_path'], 'img_text', text_file), 'w+', encoding='utf-8') as fid:
            for bbox in bbox_batch[0]:
                if len(bbox) == 0:
                    continue
                bbox = bbox.reshape(-1).tolist()
                bbox = [str(x) for x in bbox]
                bbox = ','.join(bbox)
                fid.write(bbox + '\n')

        cv2.imwrite(os.path.join(config['train']['output_path'], 'img_result', img_name + '.jpg'), img_ori)
    bar.close()
    print('fps: %.2f' % (total_frame / total_time))
    result_dict = cal_recall_precison_f1(config['train']['val_gt_dir'], os.path.join(config['train']['output_path'], 'img_text'))
    return result_dict
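
An illustrative call of the validation routine; it assumes a model built as in models/DBNet.py and a parsed config whose val_* paths point at an ICDAR-style dataset:

```python
from utils.model_eval import val

result_dict = val(model, config)  # writes boxes and drawn images under output_path
print(result_dict)                # precision / recall / hmean from cal_recall_precison_f1
```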
-------------------------------------------------------------------------------- /dataloader/dataload.py: --------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: dataload.py
7 | @time: 2020/7/4 15:16
8 | 
9 | """
10 | import numpy as np
11 | from PIL import Image
12 | from torch.utils import data
13 | import glob
14 | import cv2
15 | import random
16 | import os
17 | import torchvision.transforms as transforms
18 | import torch
19 | from .random_thansform import Random_Augment
20 | from .MakeSegMap import MakeSegDetectionData
21 | from .MakeBorderMap import MakeBorderMap
22 | 
23 | def get_img(img_path):
24 | img = cv2.imread(img_path)
25 | return img
26 | 
27 | def get_bboxes(gt_path,config):
28 | with open(gt_path,'r',encoding='utf-8') as fid:
29 | lines = fid.readlines()
30 | polys = []
31 | tags = []
32 | for line in lines:
33 | line = line.replace('\ufeff','').replace('\xef\xbb\xbf','')
34 | gt = line.split(',')
35 | if "#" in gt[-1]:
36 | tags.append(True)
37 | else:
38 | tags.append(False)
39 | if(config['train']['is_icdar2015']):
40 | box = [int(gt[i]) for i in range(8)]
41 | else:
42 | box = [int(gt[i]) for i in range(len(gt)-1)]
43 | polys.append(box)
44 | return np.array(polys), tags
45 | 
46 | class DataLoader(data.Dataset):
47 | def __init__(self, config):
48 | self.config = config
49 | self.ra = Random_Augment()
50 | self.ms = MakeSegDetectionData()
51 | self.mb = MakeBorderMap()
52 | img_paths = glob.glob(os.path.join(config['train']['train_img_dir'],'*'+config['train']['train_img_format']))
53 | gt_paths = []
54 | for img_path in img_paths:
55 | im_name = img_path.split('/')[-1].split('.')[0]
56 | if(config['train']['is_icdar2015']):
57 | gt_file_name = 'gt_'+im_name+'.txt'
58 | else:
59 | gt_file_name = im_name + '.txt'
60 | gt_paths.append(os.path.join(config['train']['train_gt_dir'],gt_file_name))
61 | self.img_paths = img_paths
62 | self.gt_paths = gt_paths
63 | 
64 | def __len__(self):
65 | return len(self.img_paths)
66 | 
67 | def __getitem__(self, index):
68 | img_path = self.img_paths[index]
69 | gt_path = self.gt_paths[index]
70 | 
71 | img = get_img(img_path)
72 | polys, dontcare = get_bboxes(gt_path,self.config)
73 | 
74 | if self.config['train']['is_transform']:
75 | img, polys = self.ra.random_scale(img, polys, 640)
76 | img, polys = self.ra.random_rotate(img, polys,self.config['train']['radom_angle'])
77 | img, polys = self.ra.random_flip(img, polys)
78 | img, polys, dontcare = self.ra.random_crop_db(img, polys, dontcare)
79 | 
80 | img, gt, gt_mask = self.ms.process(img, polys, dontcare)
81 | img, thresh_map, thresh_mask = self.mb.process(img, polys, dontcare)
82 | 
83 | if self.config['train']['is_show']:
84 | cv2.imwrite('img.jpg',img)
85 | cv2.imwrite('gt.jpg',gt[0]*255)
86 | cv2.imwrite('gt_mask.jpg',gt_mask[0]*255)
87 | cv2.imwrite('thresh_map.jpg',thresh_map*255)
88 | cv2.imwrite('thresh_mask.jpg',thresh_mask*255)
89 | 
90 | if self.config['train']['is_transform']:
91 | img = Image.fromarray(img)
92 | img = img.convert('RGB')
93 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img)
94 | else:
95 | img = Image.fromarray(img)
96 | img = img.convert('RGB')
97 | 
98 | img = transforms.ToTensor()(img)
99 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
100 | 
101 | gt = torch.from_numpy(gt).float()
102 | gt_mask = torch.from_numpy(gt_mask).float()
103 | thresh_map = torch.from_numpy(thresh_map).float()
104 | thresh_mask = torch.from_numpy(thresh_mask).float()
105 | 
106 | return img, gt, gt_mask, thresh_map, thresh_mask
107 | 
108 | 
109 | 
110 | 
111 | 
112 | 
113 | 
114 | 
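#
# Usage sketch (this mirrors how train.py consumes the dataset; the batch size
# below is illustrative, the real values come from config['train']):
#
#   import torch
#   import yaml
#   from dataloader.dataload import DataLoader  # the Dataset defined above
#
#   config = yaml.load(open('config.yaml', 'r', encoding='utf-8'), Loader=yaml.FullLoader)
#   dataset = DataLoader(config)
#   train_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)
#   img, gt, gt_mask, thresh_map, thresh_mask = next(iter(train_loader))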
-------------------------------------------------------------------------------- /models/dcn/src/deform_pool_cuda.cpp: --------------------------------------------------------------------------------
1 | // modify from
2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
3 | 
4 | // based on
5 | // author: Charles Shang
6 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
7 | 
8 | #include <torch/extension.h>
9 | 
10 | #include <cmath>
11 | #include <vector>
12 | 
13 | void DeformablePSROIPoolForward(
14 | const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
15 | at::Tensor out, at::Tensor top_count, const int batch, const int channels,
16 | const int height, const int width, const int num_bbox,
17 | const int channels_trans, const int no_trans, const float spatial_scale,
18 | const int output_dim, const int group_size, const int pooled_size,
19 | const int part_size, const int sample_per_part, const float trans_std);
20 | 
21 | void DeformablePSROIPoolBackwardAcc(
22 | const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
23 | const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
24 | at::Tensor trans_grad, const int batch, const int channels,
25 | const int height, const int width, const int num_bbox,
26 | const int channels_trans, const int no_trans, const float spatial_scale,
27 | const int output_dim, const int group_size, const int pooled_size,
28 | const int part_size, const int sample_per_part, const float trans_std);
29 | 
30 | void deform_psroi_pooling_cuda_forward(
31 | at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
32 | at::Tensor top_count, const int no_trans, const float spatial_scale,
33 | const int output_dim, const int group_size, const int pooled_size,
34 | const int part_size, const int sample_per_part, const float trans_std) {
35 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
36 | 
37 | const int batch = input.size(0);
38 | const int channels = input.size(1);
39 | const int height = input.size(2);
40 | const int width = input.size(3);
41 | const int channels_trans = no_trans ?
2 : trans.size(1); 42 | 43 | const int num_bbox = bbox.size(0); 44 | if (num_bbox != out.size(0)) 45 | AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", 46 | out.size(0), num_bbox); 47 | 48 | DeformablePSROIPoolForward( 49 | input, bbox, trans, out, top_count, batch, channels, height, width, 50 | num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size, 51 | pooled_size, part_size, sample_per_part, trans_std); 52 | } 53 | 54 | void deform_psroi_pooling_cuda_backward( 55 | at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, 56 | at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, 57 | const int no_trans, const float spatial_scale, const int output_dim, 58 | const int group_size, const int pooled_size, const int part_size, 59 | const int sample_per_part, const float trans_std) { 60 | AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); 61 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); 62 | 63 | const int batch = input.size(0); 64 | const int channels = input.size(1); 65 | const int height = input.size(2); 66 | const int width = input.size(3); 67 | const int channels_trans = no_trans ? 2 : trans.size(1); 68 | 69 | const int num_bbox = bbox.size(0); 70 | if (num_bbox != out_grad.size(0)) 71 | AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", 72 | out_grad.size(0), num_bbox); 73 | 74 | DeformablePSROIPoolBackwardAcc( 75 | out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch, 76 | channels, height, width, num_bbox, channels_trans, no_trans, 77 | spatial_scale, output_dim, group_size, pooled_size, part_size, 78 | sample_per_part, trans_std); 79 | } 80 | 81 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 82 | m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward, 83 | "deform psroi pooling forward(CUDA)"); 84 | m.def("deform_psroi_pooling_cuda_backward", 85 | &deform_psroi_pooling_cuda_backward, 86 | "deform psroi pooling backward(CUDA)"); 87 | } -------------------------------------------------------------------------------- /models/head/seg_detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | #!-*- coding=utf-8 -*- 3 | @author: BADBADBADBADBOY 4 | @contact: 2441124901@qq.com 5 | @software: PyCharm Community Edition 6 | @file: seg_detector.py 7 | @time: 2020/7/4 15:16 8 | 9 | """ 10 | 11 | from collections import OrderedDict 12 | import torch 13 | import torch.nn as nn 14 | from models.head.FPN_Head import FPN_Head 15 | from models.head.DB_Head import DB_Head 16 | 17 | BatchNorm2d = nn.BatchNorm2d 18 | 19 | class SegDetector(nn.Module): 20 | def __init__(self,headname='DB', 21 | in_channels=[64, 128, 256, 512], 22 | inner_channels=256, k=10, 23 | bias=False, adaptive=False, smooth=False, serial=False, 24 | *args, **kwargs): 25 | ''' 26 | bias: Whether conv layers have bias or not. 27 | adaptive: Whether to use adaptive threshold training or not. 28 | smooth: If true, use bilinear instead of deconv. 29 | serial: If true, thresh prediction will combine segmentation result as input. 
30 | '''
31 | super(SegDetector, self).__init__()
32 | self.k = k
33 | self.serial = serial
34 | 
35 | if(headname == 'FPN'):
36 | self.head = FPN_Head(in_channels, inner_channels)
37 | else:
38 | self.head = DB_Head(in_channels, inner_channels)
39 | 
40 | self.binarize = nn.Sequential(
41 | nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
42 | BatchNorm2d(inner_channels // 4),
43 | nn.ReLU(inplace=True),
44 | nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
45 | BatchNorm2d(inner_channels // 4),
46 | nn.ReLU(inplace=True),
47 | nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2),
48 | nn.Sigmoid())
49 | self.binarize.apply(self.weights_init)
50 | 
51 | self.adaptive = adaptive
52 | if adaptive:
53 | self.thresh = self._init_thresh(
54 | inner_channels, serial=serial, smooth=smooth, bias=bias)
55 | self.thresh.apply(self.weights_init)
56 | 
57 | def weights_init(self, m):
58 | classname = m.__class__.__name__
59 | if classname.find('Conv') != -1:
60 | nn.init.kaiming_normal_(m.weight.data)
61 | elif classname.find('BatchNorm') != -1:
62 | m.weight.data.fill_(1.)
63 | m.bias.data.fill_(1e-4)
64 | 
65 | def _init_thresh(self, inner_channels,
66 | serial=False, smooth=False, bias=False):
67 | in_channels = inner_channels
68 | if serial:
69 | in_channels += 1
70 | self.thresh = nn.Sequential(
71 | nn.Conv2d(in_channels, inner_channels //
72 | 4, 3, padding=1, bias=bias),
73 | BatchNorm2d(inner_channels // 4),
74 | nn.ReLU(inplace=True),
75 | self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
76 | BatchNorm2d(inner_channels // 4),
77 | nn.ReLU(inplace=True),
78 | self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
79 | nn.Sigmoid())
80 | return self.thresh
81 | 
82 | def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
83 | if smooth:
84 | inter_out_channels = out_channels
85 | if out_channels == 1:
86 | inter_out_channels = in_channels
87 | module_list = [
88 | nn.Upsample(scale_factor=2, mode='nearest'),
89 | nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
90 | if out_channels == 1:
91 | module_list.append(
92 | nn.Conv2d(in_channels, out_channels,
93 | kernel_size=1, stride=1, padding=1, bias=True))
94 | 
95 | return nn.Sequential(*module_list)  # unpack the list; nn.Sequential takes modules, not a list
96 | else:
97 | return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
98 | 
99 | def forward(self, features, gt=None, masks=None, training=False):
100 | 
101 | fuse = self.head(features)
102 | binary = self.binarize(fuse)
103 | 
104 | if self.training:
105 | result = OrderedDict(binary=binary)
106 | else:
107 | return binary
108 | if self.adaptive and self.training:
109 | if self.serial:
110 | fuse = torch.cat(
111 | (fuse, nn.functional.interpolate(
112 | binary, fuse.shape[2:])), 1)
113 | thresh = self.thresh(fuse)
114 | thresh_binary = self.step_function(binary, thresh)
115 | result.update(thresh=thresh, thresh_binary=thresh_binary)
116 | return result
117 | 
118 | def step_function(self, x, y):
119 | return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
120 | 
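#
# The step function above is DB's differentiable binarization: for the
# probability map x and the learned threshold map y it computes
# 1 / (1 + exp(-k * (x - y))), a steep sigmoid (k=10 by default) that
# approximates hard binarization while keeping gradients. A standalone check
# with illustrative values:
#
#   import torch
#   k = 10
#   x = torch.tensor([0.2, 0.5, 0.9])  # probability map samples
#   y = torch.tensor([0.5, 0.5, 0.5])  # threshold map samples
#   torch.reciprocal(1 + torch.exp(-k * (x - y)))  # ~[0.05, 0.50, 0.98]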
-------------------------------------------------------------------------------- /inference.py: --------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: inference.py
7 | @time: 2020/7/4 15:16
8 | 
9 | """
10 | 
11 | import sys
12 | sys.path.append('/home/aistudio/external-libraries')
13 | import cv2
14 | import torch
15 | import math
16 | import os
17 | import glob
18 | import time
19 | import copy
20 | import yaml
21 | import argparse
22 | import pyclipper
23 | import numpy as np
24 | from tqdm import tqdm
25 | from models.DBNet import DBNet
26 | import torchvision.transforms as transforms
27 | from shapely.geometry import Polygon
28 | from cal_rescall.script import cal_recall_precison_f1
29 | from utils.DB_postprocesss import *
30 | from utils.tools import *
31 | from utils.fuse_model import fuse_module
32 | 
33 | def test_net(config):
34 | os.environ["CUDA_VISIBLE_DEVICES"] = config['test']['gpu_id']
35 | 
36 | config['train']['pretrained'] = config['test']['pretrained']
37 | files = glob.glob(os.path.join(config['test']['test_img_dir'],'*'+config['test']['test_img_format']))
38 | model = DBNet(config).cuda()
39 | 
40 | model_dict = torch.load(config['test']['checkpoints'])['state_dict']
41 | state = model.state_dict()
42 | for key in state.keys():
43 | if key in model_dict.keys():
44 | state[key] = model_dict[key]
45 | model.load_state_dict(state)
46 | 
47 | if(config['test']['merge_conv_bn']):
48 | fuse_module(model)
49 | print('merge conv bn ok!!!')
50 | 
51 | if not (os.path.exists(config['test']['out_dir'])):
52 | os.mkdir(config['test']['out_dir'])
53 | 
54 | if not (os.path.exists(os.path.join(config['test']['out_dir'],'img_text'))):
55 | os.mkdir(os.path.join(config['test']['out_dir'],'img_text'))
56 | 
57 | if not (os.path.exists(os.path.join(config['test']['out_dir'],'img_result'))):
58 | os.mkdir(os.path.join(config['test']['out_dir'],'img_result'))
59 | 
60 | bar = tqdm(total=len(files))
61 | params = {'thresh':config['test']['thresh'],
62 | 'box_thresh':config['test']['box_thresh'],
63 | 'max_candidates':config['test']['max_candidates'],
64 | 'is_poly':config['test']['is_poly'],
65 | 'unclip_ratio':config['test']['unclip_ratio'],
66 | 'min_size':config['test']['min_size']
67 | }
68 | 
69 | dbprocess = DBPostProcess(params)
70 | 
71 | total_frame = 0.0
72 | total_time = 0.0
73 | 
74 | for file in files:
75 | model.eval()
76 | bar.update(1)
77 | img = cv2.imread(file)
78 | img_ori = img.copy()
79 | img_name = file.split('/')[-1].split('.')[0]
80 | img = resize_image(img,config['test']['short_side'])
81 | 
82 | img = transforms.ToTensor()(img)
83 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img).unsqueeze(0).cuda()
84 | 
85 | torch.cuda.synchronize()
86 | start = time.time()
87 | 
88 | with torch.no_grad():
89 | out = model(img)
90 | 
91 | scale = (img_ori.shape[1] * 1.0 / out.shape[3], img_ori.shape[0] * 1.0 / out.shape[2])
92 | bbox_batch,score_batch = dbprocess(out.cpu().numpy(),[scale])
93 | 
94 | torch.cuda.synchronize()
95 | end = time.time()
96 | total_frame += 1
97 | total_time += (end - start)
98 | sys.stdout.flush()
99 | 
100 | for bbox in bbox_batch[0]:
101 | bbox = bbox.reshape(-1, 2).astype(np.int32)
102 | # bbox = sort_coord(bbox)
103 | img_ori = cv2.drawContours(img_ori.copy(), [bbox], -1, (0, 255, 0), 1)
104 | 
105 | if config['test']['is_icdar2015']:
106 | text_file = 'res_' + img_name + '.txt'
107 | else:
108 | text_file = img_name + '.txt'
109 | 
110 | with open(os.path.join(config['test']['out_dir'],'img_text',text_file),'w+',encoding='utf-8') as fid:
111 | for bbox in bbox_batch[0]:
112 | if(len(bbox)==0):
113 | continue
114 | bbox = bbox.reshape(-1, 2).astype(np.int32)
115 | # bbox = sort_coord(bbox)
116 | bbox = bbox.reshape(-1).tolist()
117 | bbox = [str(x) for x in bbox]
118 | bbox = ','.join(bbox)
119 | 
fid.write(bbox+'\n') 120 | 121 | cv2.imwrite(os.path.join(config['test']['out_dir'],'img_result',img_name+'.jpg'),img_ori) 122 | bar.close() 123 | print('fps: %.2f'%(total_frame / total_time)) 124 | result_dict = cal_recall_precison_f1(config['test']['test_gt_dir'], os.path.join(config['test']['out_dir'], 'img_text')) 125 | return result_dict 126 | 127 | if __name__ == '__main__': 128 | stream = open('config.yaml', 'r', encoding='utf-8') 129 | config = yaml.load(stream,Loader=yaml.FullLoader) 130 | result_dict = test_net(config) 131 | print(result_dict) -------------------------------------------------------------------------------- /pruned/prune_inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: test.py 5 | @time: 2020/04/28 6 | """ 7 | 8 | import sys 9 | 10 | sys.path.append('/home/aistudio/external-libraries') 11 | sys.path.append('./') 12 | import cv2 13 | import torch 14 | import math 15 | import os 16 | import glob 17 | import time 18 | import copy 19 | import yaml 20 | import argparse 21 | import pyclipper 22 | import numpy as np 23 | from tqdm import tqdm 24 | from models.DBNet import DBNet 25 | import torchvision.transforms as transforms 26 | from shapely.geometry import Polygon 27 | from cal_rescall.script import cal_recall_precison_f1 28 | from utils.DB_postprocesss import * 29 | from utils.tools import * 30 | from utils.fuse_model import fuse_module 31 | from pruned.get_pruned_model import load_prune_model 32 | 33 | 34 | def test_net(config): 35 | os.environ["CUDA_VISIBLE_DEVICES"] = config['test']['gpu_id'] 36 | 37 | config['train']['pretrained'] = config['test']['pretrained'] 38 | files = glob.glob(os.path.join(config['test']['test_img_dir'], '*' + config['test']['test_img_format'])) 39 | model = DBNet(config) 40 | model = load_prune_model(model, config['pruned']['checkpoints_dict']).cuda() 41 | 42 | model_dict = torch.load(config['pruned']['checkpoints'])['state_dict'] 43 | state = model.state_dict() 44 | for key in state.keys(): 45 | if key in model_dict.keys(): 46 | state[key] = model_dict[key] 47 | model.load_state_dict(state) 48 | 49 | if (config['test']['merge_conv_bn']): 50 | fuse_module(model) 51 | print('merge conv bn ok!!!') 52 | 53 | if not (os.path.exists(config['test']['out_dir'])): 54 | os.mkdir(config['test']['out_dir']) 55 | 56 | if not (os.path.exists(os.path.join(config['test']['out_dir'], 'img_text'))): 57 | os.mkdir(os.path.join(config['test']['out_dir'], 'img_text')) 58 | 59 | if not (os.path.exists(os.path.join(config['test']['out_dir'], 'img_result'))): 60 | os.mkdir(os.path.join(config['test']['out_dir'], 'img_result')) 61 | 62 | bar = tqdm(total=len(files)) 63 | params = {'thresh': config['test']['thresh'], 64 | 'box_thresh': config['test']['box_thresh'], 65 | 'max_candidates': config['test']['max_candidates'], 66 | 'is_poly': config['test']['is_poly'], 67 | 'unclip_ratio': config['test']['unclip_ratio'], 68 | 'min_size': config['test']['min_size'] 69 | } 70 | 71 | dbprocess = DBPostProcess(params) 72 | 73 | total_frame = 0.0 74 | total_time = 0.0 75 | 76 | for file in files: 77 | model.eval() 78 | bar.update(1) 79 | img = cv2.imread(file) 80 | img_ori = img.copy() 81 | img_name = file.split('/')[-1].split('.')[0] 82 | img = resize_image(img, config['test']['short_side']) 83 | 84 | img = transforms.ToTensor()(img) 85 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img).unsqueeze(0).cuda() 86 | 87 | 
torch.cuda.synchronize()
88 | start = time.time()
89 | 
90 | with torch.no_grad():
91 | out = model(img)
92 | 
93 | scale = (img_ori.shape[1] * 1.0 / out.shape[3], img_ori.shape[0] * 1.0 / out.shape[2])
94 | bbox_batch, score_batch = dbprocess(out.cpu().numpy(), [scale])
95 | 
96 | torch.cuda.synchronize()
97 | end = time.time()
98 | total_frame += 1
99 | total_time += (end - start)
100 | sys.stdout.flush()
101 | 
102 | for bbox in bbox_batch[0]:
103 | bbox = bbox.reshape(-1, 2).astype(np.int32)
104 | # bbox = sort_coord(bbox)
105 | img_ori = cv2.drawContours(img_ori.copy(), [bbox], -1, (0, 255, 0), 1)
106 | 
107 | if config['test']['is_icdar2015']:
108 | text_file = 'res_' + img_name + '.txt'
109 | else:
110 | text_file = img_name + '.txt'
111 | 
112 | with open(os.path.join(config['test']['out_dir'], 'img_text', text_file), 'w+', encoding='utf-8') as fid:
113 | for bbox in bbox_batch[0]:
114 | if (len(bbox) == 0):
115 | continue
116 | bbox = bbox.reshape(-1, 2).astype(np.int32)
117 | # bbox = sort_coord(bbox)
118 | bbox = bbox.reshape(-1).tolist()
119 | bbox = [str(x) for x in bbox]
120 | bbox = ','.join(bbox)
121 | fid.write(bbox + '\n')
122 | 
123 | cv2.imwrite(os.path.join(config['test']['out_dir'], 'img_result', img_name + '.jpg'), img_ori)
124 | bar.close()
125 | print('fps: %.2f' % (total_frame / total_time))
126 | result_dict = cal_recall_precison_f1(config['test']['test_gt_dir'],
127 | os.path.join(config['test']['out_dir'], 'img_text'))
128 | return result_dict
129 | 
130 | 
131 | if __name__ == '__main__':
132 | stream = open('./config.yaml', 'r', encoding='utf-8')
133 | config = yaml.load(stream, Loader=yaml.FullLoader)
134 | result_dict = test_net(config)
135 | print(result_dict)
-------------------------------------------------------------------------------- /dataloader/MakeBorderMap.py: --------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: MakeBorderMap.py
7 | @time: 2020/7/4 15:16
8 | 
9 | """
10 | import warnings
11 | import numpy as np
12 | import cv2
13 | from shapely.geometry import Polygon
14 | import pyclipper
15 | import pdb
16 | class MakeBorderMap():
17 | def __init__(self,shrink_ratio=0.4,thresh_min=0.3,thresh_max=0.7):
18 | super(MakeBorderMap,self).__init__()
19 | self.shrink_ratio = shrink_ratio
20 | self.thresh_min = thresh_min
21 | self.thresh_max = thresh_max
22 | def process(self, img,polys,dontcare):
23 | thresh_map = np.zeros(img.shape[:2], dtype=np.float32)
24 | thresh_mask = np.zeros(img.shape[:2], dtype=np.float32)
25 | 
26 | for i in range(len(polys)):
27 | if dontcare[i]:
28 | continue
29 | self.draw_border_map(polys[i], thresh_map, mask=thresh_mask)
30 | thresh_map = thresh_map * (self.thresh_max - self.thresh_min) + self.thresh_min
31 | return img,thresh_map,thresh_mask
32 | 
33 | def draw_border_map(self, polygon, canvas, mask):
34 | polygon = np.array(polygon)
35 | assert polygon.ndim == 2
36 | assert polygon.shape[1] == 2
37 | 
38 | polygon_shape = Polygon(polygon)
39 | distance = polygon_shape.area * \
40 | (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
41 | subject = [tuple(l) for l in polygon]
42 | padding = pyclipper.PyclipperOffset()
43 | padding.AddPath(subject, pyclipper.JT_ROUND,
44 | pyclipper.ET_CLOSEDPOLYGON)
45 | padded_polygon = np.array(padding.Execute(distance)[0])
46 | cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
47 | 
48 | xmin = padded_polygon[:,
0].min() 49 | xmax = padded_polygon[:, 0].max() 50 | ymin = padded_polygon[:, 1].min() 51 | ymax = padded_polygon[:, 1].max() 52 | width = xmax - xmin + 1 53 | height = ymax - ymin + 1 54 | 55 | polygon[:, 0] = polygon[:, 0] - xmin 56 | polygon[:, 1] = polygon[:, 1] - ymin 57 | 58 | xs = np.broadcast_to( 59 | np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)) 60 | ys = np.broadcast_to( 61 | np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)) 62 | 63 | distance_map = np.zeros( 64 | (polygon.shape[0], height, width), dtype=np.float32) 65 | for i in range(polygon.shape[0]): 66 | j = (i + 1) % polygon.shape[0] 67 | absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) 68 | distance_map[i] = np.clip(absolute_distance / distance, 0, 1) 69 | distance_map = distance_map.min(axis=0) 70 | 71 | xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) 72 | xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) 73 | ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) 74 | ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) 75 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( 76 | 1 - distance_map[ 77 | ymin_valid-ymin:ymax_valid-ymax+height, 78 | xmin_valid-xmin:xmax_valid-xmax+width], 79 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1]) 80 | 81 | def distance(self, xs, ys, point_1, point_2): 82 | ''' 83 | compute the distance from point to a line 84 | ys: coordinates in the first axis 85 | xs: coordinates in the second axis 86 | point_1, point_2: (x, y), the end of the line 87 | ''' 88 | height, width = xs.shape[:2] 89 | square_distance_1 = np.square( 90 | xs - point_1[0]) + np.square(ys - point_1[1]) 91 | square_distance_2 = np.square( 92 | xs - point_2[0]) + np.square(ys - point_2[1]) 93 | square_distance = np.square( 94 | point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1]) 95 | 96 | cosin = (square_distance - square_distance_1 - square_distance_2) / \ 97 | (2 * np.sqrt(square_distance_1 * square_distance_2)) 98 | square_sin = 1 - np.square(cosin) 99 | square_sin = np.nan_to_num(square_sin) 100 | result = np.sqrt(square_distance_1 * square_distance_2 * 101 | square_sin / square_distance) 102 | 103 | result[cosin < 0] = np.sqrt(np.fmin( 104 | square_distance_1, square_distance_2))[cosin < 0] 105 | # self.extend_line(point_1, point_2, result) 106 | return result 107 | 108 | def extend_line(self, point_1, point_2, result): 109 | ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))), 110 | int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio)))) 111 | cv2.line(result, tuple(ex_point_1), tuple(point_1), 112 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 113 | ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))), 114 | int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio)))) 115 | cv2.line(result, tuple(ex_point_2), tuple(point_2), 116 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 117 | return ex_point_1, ex_point_2 -------------------------------------------------------------------------------- /models/dcn/modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.utils import _pair 6 | 7 | from ..functions.deform_conv import deform_conv, modulated_deform_conv 8 | 9 | 10 | class DeformConv(nn.Module): 11 | 12 | def __init__(self, 13 | in_channels, 14 | out_channels, 
15 | kernel_size, 16 | stride=1, 17 | padding=0, 18 | dilation=1, 19 | groups=1, 20 | deformable_groups=1, 21 | bias=False): 22 | super(DeformConv, self).__init__() 23 | 24 | assert not bias 25 | assert in_channels % groups == 0, \ 26 | 'in_channels {} cannot be divisible by groups {}'.format( 27 | in_channels, groups) 28 | assert out_channels % groups == 0, \ 29 | 'out_channels {} cannot be divisible by groups {}'.format( 30 | out_channels, groups) 31 | 32 | self.in_channels = in_channels 33 | self.out_channels = out_channels 34 | self.kernel_size = _pair(kernel_size) 35 | self.stride = _pair(stride) 36 | self.padding = _pair(padding) 37 | self.dilation = _pair(dilation) 38 | self.groups = groups 39 | self.deformable_groups = deformable_groups 40 | 41 | self.weight = nn.Parameter( 42 | torch.Tensor(out_channels, in_channels // self.groups, 43 | *self.kernel_size)) 44 | 45 | self.reset_parameters() 46 | 47 | def reset_parameters(self): 48 | n = self.in_channels 49 | for k in self.kernel_size: 50 | n *= k 51 | stdv = 1. / math.sqrt(n) 52 | self.weight.data.uniform_(-stdv, stdv) 53 | 54 | def forward(self, x, offset): 55 | return deform_conv(x, offset, self.weight, self.stride, self.padding, 56 | self.dilation, self.groups, self.deformable_groups) 57 | 58 | 59 | class DeformConvPack(DeformConv): 60 | 61 | def __init__(self, *args, **kwargs): 62 | super(DeformConvPack, self).__init__(*args, **kwargs) 63 | 64 | self.conv_offset = nn.Conv2d( 65 | self.in_channels, 66 | self.deformable_groups * 2 * self.kernel_size[0] * 67 | self.kernel_size[1], 68 | kernel_size=self.kernel_size, 69 | stride=_pair(self.stride), 70 | padding=_pair(self.padding), 71 | bias=True) 72 | self.init_offset() 73 | 74 | def init_offset(self): 75 | self.conv_offset.weight.data.zero_() 76 | self.conv_offset.bias.data.zero_() 77 | 78 | def forward(self, x): 79 | offset = self.conv_offset(x) 80 | return deform_conv(x, offset, self.weight, self.stride, self.padding, 81 | self.dilation, self.groups, self.deformable_groups) 82 | 83 | 84 | class ModulatedDeformConv(nn.Module): 85 | 86 | def __init__(self, 87 | in_channels, 88 | out_channels, 89 | kernel_size, 90 | stride=1, 91 | padding=0, 92 | dilation=1, 93 | groups=1, 94 | deformable_groups=1, 95 | bias=True): 96 | super(ModulatedDeformConv, self).__init__() 97 | self.in_channels = in_channels 98 | self.out_channels = out_channels 99 | self.kernel_size = _pair(kernel_size) 100 | self.stride = stride 101 | self.padding = padding 102 | self.dilation = dilation 103 | self.groups = groups 104 | self.deformable_groups = deformable_groups 105 | self.with_bias = bias 106 | 107 | self.weight = nn.Parameter( 108 | torch.Tensor(out_channels, in_channels // groups, 109 | *self.kernel_size)) 110 | if bias: 111 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 112 | else: 113 | self.register_parameter('bias', None) 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | n = self.in_channels 118 | for k in self.kernel_size: 119 | n *= k 120 | stdv = 1. 
/ math.sqrt(n) 121 | self.weight.data.uniform_(-stdv, stdv) 122 | if self.bias is not None: 123 | self.bias.data.zero_() 124 | 125 | def forward(self, x, offset, mask): 126 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 127 | self.stride, self.padding, self.dilation, 128 | self.groups, self.deformable_groups) 129 | 130 | 131 | class ModulatedDeformConvPack(ModulatedDeformConv): 132 | 133 | def __init__(self, *args, **kwargs): 134 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) 135 | 136 | self.conv_offset_mask = nn.Conv2d( 137 | self.in_channels, 138 | self.deformable_groups * 3 * self.kernel_size[0] * 139 | self.kernel_size[1], 140 | kernel_size=self.kernel_size, 141 | stride=_pair(self.stride), 142 | padding=_pair(self.padding), 143 | bias=True) 144 | self.init_offset() 145 | 146 | def init_offset(self): 147 | self.conv_offset_mask.weight.data.zero_() 148 | self.conv_offset_mask.bias.data.zero_() 149 | 150 | def forward(self, x): 151 | out = self.conv_offset_mask(x) 152 | o1, o2, mask = torch.chunk(out, 3, dim=1) 153 | offset = torch.cat((o1, o2), dim=1) 154 | mask = torch.sigmoid(mask) 155 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 156 | self.stride, self.padding, self.dilation, 157 | self.groups, self.deformable_groups) 158 | -------------------------------------------------------------------------------- /pruned/get_pruned_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | #!-*- coding=utf-8 -*- 3 | @author: BADBADBADBADBOY 4 | @contact: 2441124901@qq.com 5 | @software: PyCharm Community Edition 6 | @file: tt.py 7 | @time: 2020/6/20 10:51 8 | 9 | """ 10 | 11 | import models 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | 17 | def get_new_model(model, new_model, prued_mask, bn_index): 18 | merge1_index = [3, 12, 18] 19 | merge2_index = [25, 28, 34] 20 | merge3_index = [41, 44, 50] 21 | merge4_index = [57, 60, 66] 22 | 23 | index_0 = [] 24 | for item in merge1_index: 25 | index_0.append(bn_index.index(item)) 26 | mask1 = prued_mask[index_0[0]] | prued_mask[index_0[1]] | prued_mask[index_0[2]] 27 | 28 | index_1 = [] 29 | for item in merge2_index: 30 | index_1.append(bn_index.index(item)) 31 | mask2 = prued_mask[index_1[0]] | prued_mask[index_1[1]] | prued_mask[index_1[2]] 32 | 33 | index_2 = [] 34 | for item in merge3_index: 35 | index_2.append(bn_index.index(item)) 36 | mask3 = prued_mask[index_2[0]] | prued_mask[index_2[1]] | prued_mask[index_2[2]] 37 | 38 | index_3 = [] 39 | for item in merge4_index: 40 | index_3.append(bn_index.index(item)) 41 | mask4 = prued_mask[index_3[0]] | prued_mask[index_3[1]] | prued_mask[index_3[2]] 42 | 43 | for index in index_0: 44 | prued_mask[index] = mask1 45 | 46 | for index in index_1: 47 | prued_mask[index] = mask2 48 | 49 | for index in index_2: 50 | prued_mask[index] = mask3 51 | 52 | for index in index_3: 53 | prued_mask[index] = mask4 54 | 55 | ############################################################## 56 | index_bn = 0 57 | index_conv = 0 58 | 59 | bn_mask = [] 60 | conv_in_mask = [] 61 | conv_out_mask = [] 62 | tag = 0 63 | for m in new_model.modules(): 64 | if (tag > 69): 65 | break 66 | if (isinstance(m, nn.BatchNorm2d)): 67 | m.num_features = prued_mask[index_bn].sum() 68 | bn_mask.append(prued_mask[index_bn]) 69 | index_bn += 1 70 | elif (isinstance(m, nn.Conv2d)): 71 | if (index_conv == 0): 72 | m.in_channels = 3 73 | conv_in_mask.append(torch.ones(3)) 74 | else: 75 | m.in_channels = 
prued_mask[index_conv - 1].sum() 76 | conv_in_mask.append(prued_mask[index_conv - 1]) 77 | m.out_channels = int(prued_mask[index_conv].sum()) 78 | conv_out_mask.append(prued_mask[index_conv]) 79 | index_conv += 1 80 | tag += 1 81 | 82 | conv_change_index = [27, 43, 59] # 83 | change_conv_bn_index = [18, 34, 50] # 84 | tag = 0 85 | for m in new_model.modules(): 86 | if (tag > 69): 87 | break 88 | if (isinstance(m, nn.Conv2d)): 89 | if (tag in conv_change_index): 90 | index = conv_change_index.index(tag) 91 | index = change_conv_bn_index[index] 92 | index = bn_index.index(index) 93 | mask = prued_mask[index] 94 | conv_in_mask[index + 3] = mask 95 | m.in_channels = mask.sum() 96 | tag += 1 97 | 98 | ############################################################# 99 | bn_i = 0 100 | conv_i = 0 101 | scale_i = 0 102 | scale_mask = [mask4, mask3, mask2, mask1] 103 | # scale = [70,86,90,94] # FPN 104 | scale = [73, 77, 81, 85] # DB 105 | for [m0, m1] in zip(model.modules(), new_model.modules()): 106 | if (scale_i > 69): 107 | if isinstance(m0, nn.Conv2d): 108 | if (scale_i in scale): 109 | index = scale.index(scale_i) 110 | m1.in_channels = scale_mask[index].sum() 111 | idx0 = np.squeeze(np.argwhere(np.asarray(scale_mask[index].cpu().numpy()))) 112 | idx1 = np.squeeze(np.argwhere(np.asarray(torch.ones(256).cpu().numpy()))) 113 | if idx0.size == 1: 114 | idx0 = np.resize(idx0, (1,)) 115 | if idx1.size == 1: 116 | idx1 = np.resize(idx1, (1,)) 117 | w = m0.weight.data[:, idx0, :, :].clone() 118 | m1.weight.data = w[idx1, :, :, :].clone() 119 | if m1.bias is not None: 120 | m1.bias.data = m0.bias.data[idx1].clone() 121 | 122 | else: 123 | m1.weight.data = m0.weight.data.clone() 124 | if m1.bias is not None: 125 | m1.bias.data = m0.bias.data.clone() 126 | 127 | elif isinstance(m0, nn.BatchNorm2d): 128 | m1.weight.data = m0.weight.data.clone() 129 | if m1.bias is not None: 130 | m1.bias.data = m0.bias.data.clone() 131 | m1.running_mean = m0.running_mean.clone() 132 | m1.running_var = m0.running_var.clone() 133 | 134 | else: 135 | if isinstance(m0, nn.BatchNorm2d): 136 | idx1 = np.squeeze(np.argwhere(np.asarray(bn_mask[bn_i].cpu().numpy()))) 137 | if idx1.size == 1: 138 | idx1 = np.resize(idx1, (1,)) 139 | m1.weight.data = m0.weight.data[idx1].clone() 140 | if m1.bias is not None: 141 | m1.bias.data = m0.bias.data[idx1].clone() 142 | m1.running_mean = m0.running_mean[idx1].clone() 143 | m1.running_var = m0.running_var[idx1].clone() 144 | bn_i += 1 145 | elif isinstance(m0, nn.Conv2d): 146 | if (isinstance(conv_in_mask[conv_i], list)): 147 | idx0 = np.squeeze(np.argwhere(np.asarray(torch.cat(conv_in_mask[conv_i], 0).cpu().numpy()))) 148 | else: 149 | idx0 = np.squeeze(np.argwhere(np.asarray(conv_in_mask[conv_i].cpu().numpy()))) 150 | idx1 = np.squeeze(np.argwhere(np.asarray(conv_out_mask[conv_i].cpu().numpy()))) 151 | if idx0.size == 1: 152 | idx0 = np.resize(idx0, (1,)) 153 | if idx1.size == 1: 154 | idx1 = np.resize(idx1, (1,)) 155 | w = m0.weight.data[:, idx0, :, :].clone() 156 | m1.weight.data = w[idx1, :, :, :].clone() 157 | if m1.bias is not None: 158 | m1.bias.data = m0.bias.data[idx1].clone() 159 | conv_i += 1 160 | 161 | scale_i += 1 162 | 163 | return new_model 164 | 165 | 166 | def load_prune_model(model, pruned_model_dict_path): 167 | _load = torch.load(pruned_model_dict_path) 168 | prued_mask = _load['prued_mask'] 169 | bn_index = _load['bn_index'] 170 | prune_model = get_new_model(model, model, prued_mask, bn_index) 171 | return prune_model 172 | 173 | 174 | 175 | 
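#
# Usage sketch (mirrors prune_inference.py: the pruning checkpoint produced by
# the pruning step stores `prued_mask` and `bn_index`, and load_prune_model
# rebuilds the slimmed DBNet from them; paths come from config['pruned']):
#
#   import torch
#   import yaml
#   from models.DBNet import DBNet
#   from pruned.get_pruned_model import load_prune_model
#
#   config = yaml.load(open('config.yaml', 'r', encoding='utf-8'), Loader=yaml.FullLoader)
#   model = DBNet(config)
#   model = load_prune_model(model, config['pruned']['checkpoints_dict']).cuda()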
-------------------------------------------------------------------------------- /models/dcn/modules/deform_pool.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ..functions.deform_pool import deform_roi_pooling 4 | 5 | 6 | class DeformRoIPooling(nn.Module): 7 | 8 | def __init__(self, 9 | spatial_scale, 10 | out_size, 11 | out_channels, 12 | no_trans, 13 | group_size=1, 14 | part_size=None, 15 | sample_per_part=4, 16 | trans_std=.0): 17 | super(DeformRoIPooling, self).__init__() 18 | self.spatial_scale = spatial_scale 19 | self.out_size = out_size 20 | self.out_channels = out_channels 21 | self.no_trans = no_trans 22 | self.group_size = group_size 23 | self.part_size = out_size if part_size is None else part_size 24 | self.sample_per_part = sample_per_part 25 | self.trans_std = trans_std 26 | 27 | def forward(self, data, rois, offset): 28 | if self.no_trans: 29 | offset = data.new_empty(0) 30 | return deform_roi_pooling( 31 | data, rois, offset, self.spatial_scale, self.out_size, 32 | self.out_channels, self.no_trans, self.group_size, self.part_size, 33 | self.sample_per_part, self.trans_std) 34 | 35 | 36 | class DeformRoIPoolingPack(DeformRoIPooling): 37 | 38 | def __init__(self, 39 | spatial_scale, 40 | out_size, 41 | out_channels, 42 | no_trans, 43 | group_size=1, 44 | part_size=None, 45 | sample_per_part=4, 46 | trans_std=.0, 47 | num_offset_fcs=3, 48 | deform_fc_channels=1024): 49 | super(DeformRoIPoolingPack, 50 | self).__init__(spatial_scale, out_size, out_channels, no_trans, 51 | group_size, part_size, sample_per_part, trans_std) 52 | 53 | self.num_offset_fcs = num_offset_fcs 54 | self.deform_fc_channels = deform_fc_channels 55 | 56 | if not no_trans: 57 | seq = [] 58 | ic = self.out_size * self.out_size * self.out_channels 59 | for i in range(self.num_offset_fcs): 60 | if i < self.num_offset_fcs - 1: 61 | oc = self.deform_fc_channels 62 | else: 63 | oc = self.out_size * self.out_size * 2 64 | seq.append(nn.Linear(ic, oc)) 65 | ic = oc 66 | if i < self.num_offset_fcs - 1: 67 | seq.append(nn.ReLU(inplace=True)) 68 | self.offset_fc = nn.Sequential(*seq) 69 | self.offset_fc[-1].weight.data.zero_() 70 | self.offset_fc[-1].bias.data.zero_() 71 | 72 | def forward(self, data, rois): 73 | assert data.size(1) == self.out_channels 74 | if self.no_trans: 75 | offset = data.new_empty(0) 76 | return deform_roi_pooling( 77 | data, rois, offset, self.spatial_scale, self.out_size, 78 | self.out_channels, self.no_trans, self.group_size, 79 | self.part_size, self.sample_per_part, self.trans_std) 80 | else: 81 | n = rois.shape[0] 82 | offset = data.new_empty(0) 83 | x = deform_roi_pooling(data, rois, offset, self.spatial_scale, 84 | self.out_size, self.out_channels, True, 85 | self.group_size, self.part_size, 86 | self.sample_per_part, self.trans_std) 87 | offset = self.offset_fc(x.view(n, -1)) 88 | offset = offset.view(n, 2, self.out_size, self.out_size) 89 | return deform_roi_pooling( 90 | data, rois, offset, self.spatial_scale, self.out_size, 91 | self.out_channels, self.no_trans, self.group_size, 92 | self.part_size, self.sample_per_part, self.trans_std) 93 | 94 | 95 | class ModulatedDeformRoIPoolingPack(DeformRoIPooling): 96 | 97 | def __init__(self, 98 | spatial_scale, 99 | out_size, 100 | out_channels, 101 | no_trans, 102 | group_size=1, 103 | part_size=None, 104 | sample_per_part=4, 105 | trans_std=.0, 106 | num_offset_fcs=3, 107 | num_mask_fcs=2, 108 | deform_fc_channels=1024): 109 | 
super(ModulatedDeformRoIPoolingPack, self).__init__( 110 | spatial_scale, out_size, out_channels, no_trans, group_size, 111 | part_size, sample_per_part, trans_std) 112 | 113 | self.num_offset_fcs = num_offset_fcs 114 | self.num_mask_fcs = num_mask_fcs 115 | self.deform_fc_channels = deform_fc_channels 116 | 117 | if not no_trans: 118 | offset_fc_seq = [] 119 | ic = self.out_size * self.out_size * self.out_channels 120 | for i in range(self.num_offset_fcs): 121 | if i < self.num_offset_fcs - 1: 122 | oc = self.deform_fc_channels 123 | else: 124 | oc = self.out_size * self.out_size * 2 125 | offset_fc_seq.append(nn.Linear(ic, oc)) 126 | ic = oc 127 | if i < self.num_offset_fcs - 1: 128 | offset_fc_seq.append(nn.ReLU(inplace=True)) 129 | self.offset_fc = nn.Sequential(*offset_fc_seq) 130 | self.offset_fc[-1].weight.data.zero_() 131 | self.offset_fc[-1].bias.data.zero_() 132 | 133 | mask_fc_seq = [] 134 | ic = self.out_size * self.out_size * self.out_channels 135 | for i in range(self.num_mask_fcs): 136 | if i < self.num_mask_fcs - 1: 137 | oc = self.deform_fc_channels 138 | else: 139 | oc = self.out_size * self.out_size 140 | mask_fc_seq.append(nn.Linear(ic, oc)) 141 | ic = oc 142 | if i < self.num_mask_fcs - 1: 143 | mask_fc_seq.append(nn.ReLU(inplace=True)) 144 | else: 145 | mask_fc_seq.append(nn.Sigmoid()) 146 | self.mask_fc = nn.Sequential(*mask_fc_seq) 147 | self.mask_fc[-2].weight.data.zero_() 148 | self.mask_fc[-2].bias.data.zero_() 149 | 150 | def forward(self, data, rois): 151 | assert data.size(1) == self.out_channels 152 | if self.no_trans: 153 | offset = data.new_empty(0) 154 | return deform_roi_pooling( 155 | data, rois, offset, self.spatial_scale, self.out_size, 156 | self.out_channels, self.no_trans, self.group_size, 157 | self.part_size, self.sample_per_part, self.trans_std) 158 | else: 159 | n = rois.shape[0] 160 | offset = data.new_empty(0) 161 | x = deform_roi_pooling(data, rois, offset, self.spatial_scale, 162 | self.out_size, self.out_channels, True, 163 | self.group_size, self.part_size, 164 | self.sample_per_part, self.trans_std) 165 | offset = self.offset_fc(x.view(n, -1)) 166 | offset = offset.view(n, 2, self.out_size, self.out_size) 167 | mask = self.mask_fc(x.view(n, -1)) 168 | mask = mask.view(n, 1, self.out_size, self.out_size) 169 | return deform_roi_pooling( 170 | data, rois, offset, self.spatial_scale, self.out_size, 171 | self.out_channels, self.no_trans, self.group_size, 172 | self.part_size, self.sample_per_part, self.trans_std) * mask 173 | -------------------------------------------------------------------------------- /models/dcn/functions/deform_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from .. 
import deform_conv_cuda 6 | 7 | 8 | class DeformConvFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, 12 | input, 13 | offset, 14 | weight, 15 | stride=1, 16 | padding=0, 17 | dilation=1, 18 | groups=1, 19 | deformable_groups=1, 20 | im2col_step=64): 21 | if input is not None and input.dim() != 4: 22 | raise ValueError( 23 | "Expected 4D tensor as input, got {}D tensor instead.".format( 24 | input.dim())) 25 | ctx.stride = _pair(stride) 26 | ctx.padding = _pair(padding) 27 | ctx.dilation = _pair(dilation) 28 | ctx.groups = groups 29 | ctx.deformable_groups = deformable_groups 30 | ctx.im2col_step = im2col_step 31 | 32 | ctx.save_for_backward(input, offset, weight) 33 | 34 | output = input.new_empty( 35 | DeformConvFunction._output_size(input, weight, ctx.padding, 36 | ctx.dilation, ctx.stride)) 37 | 38 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones 39 | 40 | if not input.is_cuda: 41 | raise NotImplementedError 42 | else: 43 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 44 | assert (input.shape[0] % 45 | cur_im2col_step) == 0, 'im2col step must divide batchsize' 46 | deform_conv_cuda.deform_conv_forward_cuda( 47 | input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1], 48 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], 49 | ctx.padding[1], ctx.padding[0], ctx.dilation[1], 50 | ctx.dilation[0], ctx.groups, ctx.deformable_groups, 51 | cur_im2col_step) 52 | return output 53 | 54 | @staticmethod 55 | def backward(ctx, grad_output): 56 | input, offset, weight = ctx.saved_tensors 57 | 58 | grad_input = grad_offset = grad_weight = None 59 | 60 | if not grad_output.is_cuda: 61 | raise NotImplementedError 62 | else: 63 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 64 | assert (input.shape[0] % 65 | cur_im2col_step) == 0, 'im2col step must divide batchsize' 66 | 67 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 68 | grad_input = torch.zeros_like(input) 69 | grad_offset = torch.zeros_like(offset) 70 | deform_conv_cuda.deform_conv_backward_input_cuda( 71 | input, offset, grad_output, grad_input, 72 | grad_offset, weight, ctx.bufs_[0], weight.size(3), 73 | weight.size(2), ctx.stride[1], ctx.stride[0], 74 | ctx.padding[1], ctx.padding[0], ctx.dilation[1], 75 | ctx.dilation[0], ctx.groups, ctx.deformable_groups, 76 | cur_im2col_step) 77 | 78 | if ctx.needs_input_grad[2]: 79 | grad_weight = torch.zeros_like(weight) 80 | deform_conv_cuda.deform_conv_backward_parameters_cuda( 81 | input, offset, grad_output, 82 | grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), 83 | weight.size(2), ctx.stride[1], ctx.stride[0], 84 | ctx.padding[1], ctx.padding[0], ctx.dilation[1], 85 | ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, 86 | cur_im2col_step) 87 | 88 | return (grad_input, grad_offset, grad_weight, None, None, None, None, 89 | None) 90 | 91 | @staticmethod 92 | def _output_size(input, weight, padding, dilation, stride): 93 | channels = weight.size(0) 94 | output_size = (input.size(0), channels) 95 | for d in range(input.dim() - 2): 96 | in_size = input.size(d + 2) 97 | pad = padding[d] 98 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 99 | stride_ = stride[d] 100 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) 101 | if not all(map(lambda s: s > 0, output_size)): 102 | raise ValueError( 103 | "convolution input is too small (output would be {})".format( 104 | 'x'.join(map(str, output_size)))) 105 | return output_size 106 | 107 | 108 | class ModulatedDeformConvFunction(Function): 109 | 
110 | @staticmethod
111 | def forward(ctx,
112 | input,
113 | offset,
114 | mask,
115 | weight,
116 | bias=None,
117 | stride=1,
118 | padding=0,
119 | dilation=1,
120 | groups=1,
121 | deformable_groups=1):
122 | ctx.stride = stride
123 | ctx.padding = padding
124 | ctx.dilation = dilation
125 | ctx.groups = groups
126 | ctx.deformable_groups = deformable_groups
127 | ctx.with_bias = bias is not None
128 | if not ctx.with_bias:
129 | bias = input.new_empty(1) # fake tensor
130 | if not input.is_cuda:
131 | raise NotImplementedError
132 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \
133 | or input.requires_grad:
134 | ctx.save_for_backward(input, offset, mask, weight, bias)
135 | output = input.new_empty(
136 | ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
137 | ctx._bufs = [input.new_empty(0), input.new_empty(0)]
138 | deform_conv_cuda.modulated_deform_conv_cuda_forward(
139 | input, weight, bias, ctx._bufs[0], offset, mask, output,
140 | ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
141 | ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
142 | ctx.groups, ctx.deformable_groups, ctx.with_bias)
143 | return output
144 | 
145 | @staticmethod
146 | def backward(ctx, grad_output):
147 | if not grad_output.is_cuda:
148 | raise NotImplementedError
149 | input, offset, mask, weight, bias = ctx.saved_tensors
150 | grad_input = torch.zeros_like(input)
151 | grad_offset = torch.zeros_like(offset)
152 | grad_mask = torch.zeros_like(mask)
153 | grad_weight = torch.zeros_like(weight)
154 | grad_bias = torch.zeros_like(bias)
155 | deform_conv_cuda.modulated_deform_conv_cuda_backward(
156 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
157 | grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
158 | grad_output, weight.shape[2], weight.shape[3], ctx.stride,
159 | ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
160 | ctx.groups, ctx.deformable_groups, ctx.with_bias)
161 | if not ctx.with_bias:
162 | grad_bias = None
163 | 
164 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
165 | None, None, None, None, None)
166 | 
167 | @staticmethod
168 | def _infer_shape(ctx, input, weight):
169 | n = input.size(0)
170 | channels_out = weight.size(0)
171 | height, width = input.shape[2:4]
172 | kernel_h, kernel_w = weight.shape[2:4]
173 | height_out = (height + 2 * ctx.padding -
174 | (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
175 | width_out = (width + 2 * ctx.padding -
176 | (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
177 | return n, channels_out, height_out, width_out
178 | 
179 | 
180 | deform_conv = DeformConvFunction.apply
181 | modulated_deform_conv = ModulatedDeformConvFunction.apply
182 | 
-------------------------------------------------------------------------------- /loss/dice_loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import cv2
5 | from scipy import ndimage
6 | 
7 | 
8 | class DiceLoss(nn.Module):
9 | '''
10 | Loss function from https://arxiv.org/abs/1707.03237,
11 | where the IoU computation is introduced in a heatmap manner to measure the
12 | diversity between two heatmaps.
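Concretely, for a prediction P, ground truth G and valid-pixel mask M, the
loss computed in _compute below is
    1 - 2 * sum(P*G*M) / (sum(P*M) + sum(G*M) + eps),
i.e. one minus the (masked) dice coefficient.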
13 | '''
14 | 
15 | def __init__(self, eps=1e-6):
16 | super(DiceLoss, self).__init__()
17 | self.eps = eps
18 | 
19 | def forward(self, pred: torch.Tensor, gt, mask, weights=None):
20 | '''
21 | pred: one or two heatmaps of shape (N, 1, H, W),
22 | the losses of two heatmaps are added together.
23 | gt: (N, 1, H, W)
24 | mask: (N, H, W)
25 | '''
26 | assert pred.dim() == 4, pred.dim()
27 | return self._compute(pred, gt, mask, weights)
28 | 
29 | def _compute(self, pred, gt, mask, weights):
30 | if pred.dim() == 4:
31 | pred = pred[:, 0, :, :]
32 | gt = gt[:, 0, :, :]
33 | assert pred.shape == gt.shape
34 | assert pred.shape == mask.shape
35 | if weights is not None:
36 | assert weights.shape == mask.shape
37 | mask = weights * mask
38 | 
39 | intersection = (pred * gt * mask).sum()
40 | union = (pred * mask).sum() + (gt * mask).sum() + self.eps
41 | loss = 1 - 2.0 * intersection / union
42 | assert loss <= 1
43 | return loss
44 | 
45 | 
46 | def dice_loss(input, target, mask, eps=1e-6):
47 | input = input.contiguous().view(input.size()[0], -1)
48 | target = target.contiguous().view(target.size()[0], -1)
49 | mask = mask.contiguous().view(mask.size()[0], -1)
50 | 
51 | input = input * mask
52 | target = target * mask
53 | 
54 | a = torch.sum(input * target, 1)
55 | b = torch.sum(input * input, 1) + eps
56 | c = torch.sum(target * target, 1) + eps
57 | d = (2 * a) / (b + c)
58 | dice_loss = torch.mean(d)
59 | return 1 - dice_loss
60 | 
61 | 
62 | class LeakyDiceLoss(nn.Module):
63 | '''
64 | Variation from DiceLoss.
65 | The coverage and union are computed separately.
66 | '''
67 | 
68 | def __init__(self, eps=1e-6, coverage_scale=5.0):
69 | super(LeakyDiceLoss, self).__init__()
70 | self.eps = eps
71 | self.coverage_scale = coverage_scale
72 | 
73 | def forward(self, pred, gt, mask):
74 | if pred.dim() == 4:
75 | pred = pred[:, 0, :, :]
76 | gt = gt[:, 0, :, :]
77 | assert pred.shape == gt.shape
78 | assert pred.shape == mask.shape
79 | 
80 | coverage = (pred * mask * gt).sum() / ((gt * mask).sum() + self.eps)
81 | assert coverage <= 1
82 | coverage = 1 - coverage
83 | excede = (pred * mask * gt).sum() / ((pred * mask).sum() + self.eps)
84 | assert excede <= 1
85 | excede = 1 - excede
86 | loss = coverage * self.coverage_scale + excede
87 | return loss, dict(coverage=coverage, excede=excede)
88 | 
89 | 
90 | class InstanceDiceLoss(DiceLoss):
91 | '''
92 | DiceLoss normalized on each instance.
93 | Input:
94 | pred: (N, 1, H, W)
95 | gt: (N, 1, H, W)
96 | mask: (N, H, W)
97 | Note: This class assumes that input tensors are on the GPU,
98 | while CPU computation is required to find union areas.
99 | '''
100 | REDUCTION = ['mean', 'sum', 'none']
101 | 
102 | def __init__(self, threshold=0.3, iou_thresh=0.2, reduction=None,
103 | max_regions=100, eps=1e-6):
104 | nn.Module.__init__(self)
105 | self.threshold = threshold
106 | self.iou_thresh = iou_thresh
107 | self.reduction = reduction
108 | if self.reduction is None:
109 | self.reduction = 'mean'
110 | assert self.reduction in self.REDUCTION
111 | self.max_regions = max_regions
112 | self.eps = eps
113 | 
114 | def label(self, tensor_on_gpu, blur=None):
115 | '''
116 | Args:
117 | tensor_on_gpu: (N, 1, H, W)
118 | blur: Lambda. If given, each instance will be blurred using `blur`.
119 | '''
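# scipy.ndimage.label assigns each connected region a distinct positive
# integer id and returns (labeled_array, num_regions); for example
# ndimage.label([[1, 1, 0], [0, 0, 1]]) -> ([[1, 1, 0], [0, 0, 2]], 2).
# Each id is then turned back into a float32 mask on the original device.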
120 | tensor = tensor_on_gpu.cpu().detach().numpy()
121 | 
122 | instance_maps = []
123 | instance_counts = []
124 | for batch_index in range(tensor_on_gpu.shape[0]):
125 | instance = tensor[batch_index]
126 | if blur is not None:
127 | instance = blur(instance)
128 | label_map, instance_count = ndimage.label(instance[0])
129 | instance_count = min(self.max_regions, instance_count)
130 | instance_map = []
131 | for index in range(1, instance_count):
132 | instance = torch.from_numpy(
133 | label_map == index).to(tensor_on_gpu.device).type(torch.float32)
134 | instance_map.append(instance)
135 | instance_maps.append(instance_map); instance_counts.append(instance_count)  # record the count as well, so the returned counts list is not always empty
136 | return instance_maps, instance_counts
137 | 
138 | def iou(self, pred, gt):
139 | overlap = (pred * gt).sum()
140 | return max(overlap / pred.sum(), overlap / gt.sum())
141 | 
142 | def replace_or_add(self, dest, value):
143 | if dest is None:
144 | return value
145 | if value is None:
146 | return dest
147 | return dest + value
148 | 
149 | def forward(self, pred, gt, mask):
150 | # pred_label_maps: N, P, H, W, where P is the number of regions.
151 | torch.cuda.synchronize()
152 | pred_label_maps, _ = self.label(pred > self.threshold)
153 | gt_label_maps, _ = self.label(gt)
154 | 
155 | losses = []
156 | for batch_index, gt_instance_maps in enumerate(gt_label_maps):
157 | pred_instance_maps = pred_label_maps[batch_index]
158 | if gt_instance_maps is None or pred_instance_maps is None:
159 | continue
160 | 
161 | single_loss = None # loss on a single image in a batch
162 | mask_not_matched = set(range(len(pred_instance_maps)))
163 | for gt_instance_map in gt_instance_maps:
164 | instance_loss = None # loss on a specific gt region
165 | for instance_index, pred_instance_map in enumerate(pred_instance_maps):
166 | if self.iou(pred_instance_map, gt_instance_map) > self.iou_thresh:
167 | match_loss = self._compute(
168 | pred[batch_index][0], gt[batch_index][0],
169 | mask[batch_index] * (pred_instance_map + gt_instance_map > 0).type(torch.float32))
170 | instance_loss = self.replace_or_add(instance_loss, match_loss)
171 | if instance_index in mask_not_matched:
172 | mask_not_matched.remove(instance_index)
173 | if instance_loss is None:
174 | instance_loss = self._compute(
175 | pred[batch_index][0], gt[batch_index][0],
176 | mask[batch_index] * gt_instance_map)
177 | single_loss = self.replace_or_add(single_loss, instance_loss)
178 | 
179 | '''Whether to compute single loss on instances which contain no positive sample.
180 | if single_loss is None: 181 | single_loss = self._compute( 182 | pred[batch_index][0], gt[batch_index][0], 183 | mask[batch_index]) 184 | ''' 185 | 186 | for instance_index in mask_not_matched: 187 | single_loss = self.replace_or_add( 188 | single_loss, 189 | self._compute( 190 | pred[batch_index][0], gt[batch_index][0], 191 | mask[batch_index] * pred_instance_maps[instance_index])) 192 | 193 | if single_loss is not None: 194 | losses.append(single_loss) 195 | 196 | if self.reduction == 'none': 197 | loss = losses 198 | else: 199 | assert self.reduction in ['sum', 'mean'] 200 | count = len(losses) 201 | loss = sum(losses) 202 | if self.reduction == 'mean': 203 | loss = loss / count 204 | return loss 205 | -------------------------------------------------------------------------------- /utils/DB_postprocesss.py: -------------------------------------------------------------------------------- 1 | """ 2 | #!-*- coding=utf-8 -*- 3 | @author: BADBADBADBADBOY 4 | @contact: 2441124901@qq.com 5 | @software: PyCharm Community Edition 6 | @file: DB_postprocesss.py 7 | @time: 2020/7/4 15:16 8 | 9 | """ 10 | import numpy as np 11 | import string 12 | import cv2 13 | from shapely.geometry import Polygon 14 | import pyclipper 15 | 16 | 17 | class DBPostProcess(object): 18 | """ 19 | The post process for Differentiable Binarization (DB). 20 | """ 21 | 22 | def __init__(self, params): 23 | self.thresh = params['thresh'] 24 | self.box_thresh = params['box_thresh'] 25 | self.max_candidates = params['max_candidates'] 26 | self.is_poly = params['is_poly'] 27 | self.unclip_ratio = params['unclip_ratio'] 28 | self.min_size = params['min_size'] 29 | 30 | def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): 31 | ''' 32 | _bitmap: single map with shape (1, H, W), 33 | whose values are binarized as {0, 1} 34 | ''' 35 | 36 | 37 | bitmap = _bitmap 38 | pred = pred 39 | height, width = bitmap.shape 40 | boxes = [] 41 | scores = [] 42 | 43 | contours, _ = cv2.findContours( 44 | (bitmap*255).astype(np.uint8), 45 | cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 46 | 47 | for contour in contours[:self.max_candidates]: 48 | epsilon = 0.001 * cv2.arcLength(contour, True) 49 | approx = cv2.approxPolyDP(contour, epsilon, True) 50 | points = approx.reshape((-1, 2)) 51 | if points.shape[0] < 4: 52 | continue 53 | # _, sside = self.get_mini_boxes(contour) 54 | # if sside < self.min_size: 55 | # continue 56 | score = self.box_score_fast(pred, points.reshape(-1, 2)) 57 | if self.box_thresh > score: 58 | continue 59 | 60 | if points.shape[0] > 2: 61 | box = self.unclip(points, self.unclip_ratio) 62 | if len(box) > 1: 63 | continue 64 | else: 65 | continue 66 | box = box.reshape(-1, 2) 67 | _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) 68 | if sside < self.min_size + 2: 69 | continue 70 | 71 | if not isinstance(dest_width, int): 72 | dest_width = dest_width.item() 73 | dest_height = dest_height.item() 74 | 75 | box[:, 0] = np.clip( 76 | np.round(box[:, 0] / width * dest_width), 0, dest_width) 77 | box[:, 1] = np.clip( 78 | np.round(box[:, 1] / height * dest_height), 0, dest_height) 79 | boxes.append(box.tolist()) 80 | scores.append(score) 81 | return boxes, scores 82 | 83 | def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): 84 | ''' 85 | _bitmap: single map with shape (1, H, W), 86 | whose values are binarized as {0, 1} 87 | ''' 88 | 89 | bitmap = _bitmap 90 | height, width = bitmap.shape 91 | 92 | outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, 
cv2.CHAIN_APPROX_SIMPLE)
93 | if len(outs) == 3: 94 | img, contours, _ = outs[0], outs[1], outs[2]
95 | elif len(outs) == 2: 96 | contours, _ = outs[0], outs[1] 97 |
98 | num_contours = min(len(contours), self.max_candidates)
99 | boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
100 | scores = np.zeros((num_contours, ), dtype=np.float32) 101 |
102 | for index in range(num_contours):
103 | contour = contours[index]
104 | points, sside = self.get_mini_boxes(contour)
105 | if sside < self.min_size: 106 | continue
107 | points = np.array(points)
108 | score = self.box_score_fast(pred, points.reshape(-1, 2))
109 | if self.box_thresh > score: 110 | continue 111 |
112 | box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
113 | box, sside = self.get_mini_boxes(box)
114 | if sside < self.min_size + 2: 115 | continue
116 | box = np.array(box)
117 | if not isinstance(dest_width, int):
118 | dest_width = dest_width.item()
119 | dest_height = dest_height.item() 120 |
121 | box[:, 0] = np.clip( 122 | np.round(box[:, 0] / width * dest_width), 0, dest_width)
123 | box[:, 1] = np.clip( 124 | np.round(box[:, 1] / height * dest_height), 0, dest_height)
125 | boxes[index, :, :] = box.astype(np.int16)
126 | scores[index] = score
127 | return boxes, scores 128 |
129 | def unclip(self, box, unclip_ratio=2):
130 | poly = Polygon(box)
131 | distance = poly.area * unclip_ratio / poly.length
132 | offset = pyclipper.PyclipperOffset()
133 | offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
134 | expanded = np.array(offset.Execute(distance))
135 | return expanded 136 |
137 | def get_mini_boxes(self, contour):
138 | bounding_box = cv2.minAreaRect(contour)
139 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 140 |
141 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3
142 | if points[1][1] > points[0][1]: 143 | index_1 = 0 144 | index_4 = 1
145 | else: 146 | index_1 = 1 147 | index_4 = 0
148 | if points[3][1] > points[2][1]: 149 | index_2 = 2 150 | index_3 = 3
151 | else: 152 | index_2 = 3 153 | index_3 = 2 154 |
155 | box = [ 156 | points[index_1], points[index_2], points[index_3], points[index_4] 157 | ]
158 | return box, min(bounding_box[1]) 159 |
160 | def box_score_fast(self, bitmap, _box):
161 | h, w = bitmap.shape[:2]
162 | box = _box.copy()
163 | xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)  # np.int was removed in NumPy 1.24; use np.int32
164 | xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
165 | ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
166 | ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) 167 |
168 | mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
169 | box[:, 0] = box[:, 0] - xmin
170 | box[:, 1] = box[:, 1] - ymin
171 | cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
172 | return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] 173 |
174 | def __call__(self, pred, ratio_list):
175 | pred = pred[:, 0, :, :]
176 | segmentation = pred > self.thresh 177 |
178 | boxes_batch = []
179 | score_batch = []
180 | for batch_index in range(pred.shape[0]):
181 | height, width = pred.shape[-2:]
182 | if(self.is_poly):
183 | tmp_boxes, tmp_scores = self.polygons_from_bitmap( 184 | pred[batch_index], segmentation[batch_index], width, height) 185 |
186 | boxes = []
187 | score = []
188 | for k in range(len(tmp_boxes)):
189 | if tmp_scores[k] > self.box_thresh:
190 | boxes.append(tmp_boxes[k])
191 | score.append(tmp_scores[k])
192 | if len(boxes) > 0:
193 | ratio_w, ratio_h =
ratio_list[batch_index]
194 | for i in range(len(boxes)):
195 | boxes[i] = np.array(boxes[i])
196 | boxes[i][:, 0] = boxes[i][:, 0] * ratio_w
197 | boxes[i][:, 1] = boxes[i][:, 1] * ratio_h 198 |
199 | boxes_batch.append(boxes)
200 | score_batch.append(score)
201 | else:
202 | tmp_boxes, tmp_scores = self.boxes_from_bitmap( 203 | pred[batch_index], segmentation[batch_index], width, height) 204 |
205 | boxes = []
206 | score = []
207 | for k in range(len(tmp_boxes)):
208 | if tmp_scores[k] > self.box_thresh:
209 | boxes.append(tmp_boxes[k])
210 | score.append(tmp_scores[k])
211 | if len(boxes) > 0:
212 | boxes = np.array(boxes) 213 |
214 | ratio_w, ratio_h = ratio_list[batch_index]
215 | boxes[:, :, 0] = boxes[:, :, 0] * ratio_w
216 | boxes[:, :, 1] = boxes[:, :, 1] * ratio_h 217 |
218 | boxes_batch.append(boxes)
219 | score_batch.append(score)
220 | return boxes_batch, score_batch -------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 _*-
2 | """ 3 | @author:fxw 4 | @file: train.py 5 | @time: 2020/04/28 6 | """
7 | import sys
8 | sys.path.append('/home/aistudio/external-libraries')
9 | import os 10 | import torch 11 | import torch.nn as nn 12 | import yaml 13 | import argparse 14 | import numpy as np
15 | olderr = np.seterr(all='ignore')
16 | from models.DBNet import DBNet
17 | from torch.autograd import Variable
18 | from loss.loss import L1BalanceCELoss
19 | from dataloader.dataload import DataLoader
20 | from utils.Logger import Logger
21 | from utils.metrics import runningScore
22 | from utils.model_eval import val
23 | from utils.tools import *
24 | from utils.set_optimizer import * 25 |
26 | torch.backends.cudnn.benchmark = False
27 | torch.backends.cudnn.deterministic = True 28 |
29 | def updateBN(model,config):
30 | tag = 0
31 | for m in model.modules():
32 | if(tag>69):  # only the first 70 modules are sparsified (matching the cut-off used in pruned/prune.py)
33 | break
34 | if isinstance(m, nn.BatchNorm2d):
35 | if hasattr(m.weight, 'data'):
36 | m.weight.grad.data.add_(config['train']['sr_lr']*torch.sign(m.weight.data)) # L1 sparsity penalty on BN scales (network slimming)
37 | tag+=1 38 |
39 | def set_seed(seed):
40 | import numpy as np
41 | import random
42 | import torch
43 | random.seed(seed)
44 | np.random.seed(seed)
45 | torch.manual_seed(seed)
46 | torch.cuda.manual_seed(seed)
47 | torch.cuda.manual_seed_all(seed) 48 |
49 | GLOBAL_WORKER_ID = None
50 | GLOBAL_SEED = 2020 51 |
52 | def worker_init_fn(worker_id):
53 | global GLOBAL_WORKER_ID
54 | GLOBAL_WORKER_ID = worker_id
55 | set_seed(GLOBAL_SEED + worker_id) 56 |
57 | def train_net(config):
58 | os.environ["CUDA_VISIBLE_DEVICES"] = config['train']['gpu_id']
59 | data_loader = DataLoader(config)
60 | train_loader = torch.utils.data.DataLoader( 61 | data_loader, 62 | batch_size=config['train']['batch_size'], 63 | shuffle=True, 64 | num_workers=config['train']['num_workers'], 65 | worker_init_fn = worker_init_fn, 66 | drop_last=True, 67 | pin_memory=False) 68 |
69 | start_epoch = 0
70 | running_metric_binary = runningScore(2) 71 |
72 | if not (os.path.exists(config['train']['checkpoints'])):
73 | os.mkdir(config['train']['checkpoints'])
74 | checkpoints = os.path.join(config['train']['checkpoints'],"DB_%s_bs_%d_ep_%d" % (config['train']['backbone'], 75 | config['train']['batch_size'], config['train']['n_epoch']))
76 | if not (os.path.exists(checkpoints)):
77 | os.mkdir(checkpoints) 78 | 79 |
80 | model = DBNet(config).cuda()
81 | criterion = L1BalanceCELoss()
82 | optimizer = torch.optim.SGD(model.parameters(),
lr=config['train']['base_lr'], momentum=0.99, weight_decay=5e-4) 83 | 84 | if config['train']['restore']: 85 | print('Resuming from checkpoint.') 86 | assert os.path.isfile(config['train']['resume']), 'Error: no checkpoint directory found!' 87 | checkpoint = torch.load(config['train']['resume']) 88 | start_epoch = checkpoint['epoch'] 89 | model.load_state_dict(checkpoint['state_dict']) 90 | optimizer.load_state_dict(checkpoint['optimizer']) 91 | log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['train']['backbone'], resume=True) 92 | else: 93 | print('Training from scratch.') 94 | log_write = Logger(os.path.join(checkpoints,'log.txt'), title=config['train']['backbone']) 95 | log_write.set_names([' epoch', 'Total loss', ' Bce loss', 'Thresh loss', ' L1 loss', 'Binary Acc', 'Binary IoU', ' rescall',' precision',' hmean']) 96 | max_hmean = -1 97 | for epoch in range(start_epoch,config['train']['n_epoch']): 98 | model.train() 99 | 100 | bce_loss_list = [] 101 | thresh_loss_list = [] 102 | l1_loss_list = [] 103 | total_loss_list = [] 104 | 105 | if(config['train']['decay_method']=='e_decay'): 106 | adjust_learning_rate_poly(config['train']['base_lr'], optimizer, epoch, max_epoch=config['train']['n_epoch'], factor=0.9) 107 | else: 108 | adjust_learning_rate(config, optimizer, epoch,config['train']['gama']) 109 | 110 | for batch_idx, (imgs, gts, gt_masks, thresh_maps, thresh_masks) in enumerate(train_loader): 111 | imgs = Variable(imgs.cuda()) 112 | gts = Variable(gts.cuda()) 113 | gt_masks = Variable(gt_masks.cuda()) 114 | thresh_maps = Variable(thresh_maps.cuda()) 115 | thresh_masks = Variable(thresh_masks.cuda()) 116 | batch = {} 117 | batch['gt'] = gts 118 | batch['mask'] = gt_masks 119 | batch['thresh_map'] = thresh_maps 120 | batch['thresh_mask'] = thresh_masks 121 | 122 | pre = model(imgs) 123 | loss, metrics = criterion(pre, batch) 124 | 125 | optimizer.zero_grad() 126 | loss.backward() 127 | if(config['train']['use_sr']): 128 | updateBN(model,config) 129 | optimizer.step() 130 | 131 | score_binary = cal_binary_score(pre['binary'], gts, gt_masks.unsqueeze(1), running_metric_binary) 132 | 133 | bce_loss_list.append(metrics['bce_loss'].item()) 134 | thresh_loss_list.append(metrics['thresh_loss'].item()) 135 | l1_loss_list.append(metrics['l1_loss'].item()) 136 | total_loss_list.append(loss.item()) 137 | if batch_idx % config['train']['show_step'] == 0: 138 | if(config['train']['print_format']=='linux'): 139 | headers = ['epoch/epochs','batch/batchs' ,'TotalLoss' ,'BceLoss',' ThreshLoss','L1Loss', 'Binary Acc','Binary IoU', 'Lr Rate'] 140 | show_item = [[str(epoch)+'/'+str(config['train']['n_epoch']), 141 | str(batch_idx + 1)+'/'+str(len(train_loader)), 142 | get_str(np.mean(total_loss_list)), 143 | get_str(np.mean(bce_loss_list)), 144 | get_str(np.mean(thresh_loss_list)), 145 | get_str(np.mean(l1_loss_list)), 146 | get_str(score_binary['Mean Acc']), 147 | get_str(score_binary['Mean IoU']), 148 | get_str(optimizer.param_groups[0]['lr']) 149 | ]] 150 | print_table(headers,show_item,type_str='train') 151 | else: 152 | output_log = '({epoch}/{epochs}/{batch}/{size}) | TotalLoss: {total_loss:.4f} | BceLoss: {bce_loss:.4f} | ThreshLoss: {thresh_loss: .4f} | L1Loss: {l1_loss: .4f} | Binary Acc: {bin_acc: .4f} | Binary IoU: {bin_iou: .4f} | Lr: {lr: .4f}'.format( 153 | epoch=epoch, 154 | epochs=config['train']['n_epoch'] , 155 | batch=batch_idx + 1, 156 | size=len(train_loader), 157 | total_loss=np.mean(total_loss_list), 158 | bce_loss=np.mean(bce_loss_list), 159 | 
thresh_loss=np.mean(thresh_loss_list), 160 | l1_loss=np.mean(l1_loss_list), 161 | bin_acc=score_binary['Mean Acc'], 162 | bin_iou=score_binary['Mean IoU'], 163 | lr=optimizer.param_groups[0]['lr'] 164 | ) 165 | print(output_log) 166 | sys.stdout.flush() 167 | 168 | if( epoch > config['train']['start_val_epoch']): 169 | result_dict = val(model,config) 170 | rescall,precision,hmean = result_dict['recall'],result_dict['precision'],result_dict['hmean'] 171 | print('epoch:',epoch,'rescall:',rescall,'precision:',precision,'hmean:',hmean) 172 | else: 173 | rescall = 0 174 | precision = 0 175 | hmean = 0 176 | log_write.append([epoch, np.mean(total_loss_list), np.mean(bce_loss_list), np.mean(thresh_loss_list), 177 | np.mean(l1_loss_list), score_binary['Mean Acc'], score_binary['Mean IoU'], 178 | rescall,precision,hmean]) 179 | if(hmean > max_hmean and config['train']['start_val_epoch'] < config['train']['n_epoch']): 180 | max_hmean = hmean 181 | save_checkpoint({ 182 | 'epoch': epoch + 1, 183 | 'state_dict': model.state_dict(), 184 | 'lr': config['train']['base_lr'], 185 | 'optimizer': optimizer.state_dict(), 186 | }, checkpoint=checkpoints,filename='best_model.pth.tar') 187 | 188 | save_checkpoint({ 189 | 'epoch': epoch + 1, 190 | 'state_dict': model.state_dict(), 191 | 'lr': config['train']['base_lr'], 192 | 'optimizer': optimizer.state_dict(), 193 | }, checkpoint=checkpoints) 194 | 195 | 196 | 197 | if __name__ == '__main__': 198 | stream = open('config.yaml', 'r', encoding='utf-8') 199 | config = yaml.load(stream,Loader=yaml.FullLoader) 200 | train_net(config) -------------------------------------------------------------------------------- /dataloader/random_thansform.py: -------------------------------------------------------------------------------- 1 | """ 2 | #!-*- coding=utf-8 -*- 3 | @author: BADBADBADBADBOY 4 | @contact: 2441124901@qq.com 5 | @software: PyCharm Community Edition 6 | @file: random_thansform.py 7 | @time: 2020/7/4 15:16 8 | 9 | """ 10 | import cv2 11 | import numpy as np 12 | import imgaug.augmenters as aug_img 13 | import imgaug 14 | import random 15 | def solve_polys(polys): 16 | len_max = 0 17 | for poly in polys: 18 | if(len(poly)//2>len_max): 19 | len_max = len(poly)//2 20 | new_polys = [] 21 | for poly in polys: 22 | new_poly = [] 23 | if(len(poly)//2 x + w: 69 | return False 70 | if poly[:, 1].min() < y or poly[:, 1].max() > y + h: 71 | return False 72 | return True 73 | 74 | def is_poly_outside_rect(self, poly, x, y, w, h): 75 | poly = np.array(poly) 76 | if poly[:, 0].max() < x or poly[:, 0].min() > x + w: 77 | return True 78 | if poly[:, 1].max() < y or poly[:, 1].min() > y + h: 79 | return True 80 | return False 81 | 82 | def split_regions(self, axis): 83 | regions = [] 84 | min_axis = 0 85 | for i in range(1, axis.shape[0]): 86 | if axis[i] != axis[i - 1] + 1: 87 | region = axis[min_axis:i] 88 | min_axis = i 89 | regions.append(region) 90 | return regions 91 | 92 | def random_select(self, axis, max_size): 93 | xx = np.random.choice(axis, size=2) 94 | xmin = np.min(xx) 95 | xmax = np.max(xx) 96 | xmin = np.clip(xmin, 0, max_size - 1) 97 | xmax = np.clip(xmax, 0, max_size - 1) 98 | return xmin, xmax 99 | 100 | def region_wise_random_select(self, regions, max_size): 101 | selected_index = list(np.random.choice(len(regions), 2)) 102 | selected_values = [] 103 | for index in selected_index: 104 | axis = regions[index] 105 | xx = int(np.random.choice(axis, size=1)) 106 | selected_values.append(xx) 107 | xmin = min(selected_values) 108 | xmax = max(selected_values) 
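# Note: both coordinates are drawn from text-free gap regions (possibly the
# same gap, since np.random.choice samples with replacement), so a crop edge
# never lands inside a text instance; crop_area() below retries up to
# max_tries until the crop is large enough and still contains a polygon.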
109 | return xmin, xmax 110 | 111 | def crop_area(self, img, polys): 112 | h, w, _ = img.shape 113 | h_array = np.zeros(h, dtype=np.int32) 114 | w_array = np.zeros(w, dtype=np.int32) 115 | for points in polys: 116 | points = np.round(points, decimals=0).astype(np.int32) 117 | minx = np.min(points[:, 0]) 118 | maxx = np.max(points[:, 0]) 119 | w_array[minx:maxx] = 1 120 | miny = np.min(points[:, 1]) 121 | maxy = np.max(points[:, 1]) 122 | h_array[miny:maxy] = 1 123 | # ensure the cropped area not across a text 124 | h_axis = np.where(h_array == 0)[0] 125 | w_axis = np.where(w_array == 0)[0] 126 | 127 | if len(h_axis) == 0 or len(w_axis) == 0: 128 | return 0, 0, w, h 129 | 130 | h_regions = self.split_regions(h_axis) 131 | w_regions = self.split_regions(w_axis) 132 | 133 | for i in range(self.max_tries): 134 | if len(w_regions) > 1: 135 | xmin, xmax = self.region_wise_random_select(w_regions, w) 136 | else: 137 | xmin, xmax = self.random_select(w_axis, w) 138 | if len(h_regions) > 1: 139 | ymin, ymax = self.region_wise_random_select(h_regions, h) 140 | else: 141 | ymin, ymax = self.random_select(h_axis, h) 142 | 143 | if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h: 144 | # area too small 145 | continue 146 | num_poly_in_rect = 0 147 | for poly in polys: 148 | if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin): 149 | num_poly_in_rect += 1 150 | break 151 | 152 | if num_poly_in_rect > 0: 153 | return xmin, ymin, xmax - xmin, ymax - ymin 154 | 155 | return 0, 0, w, h 156 | 157 | class Random_Augment(): 158 | def __init__(self): 159 | super(Random_Augment,self).__init__() 160 | self.random_crop_data = RandomCropData() 161 | 162 | def augment_poly(self,aug, img_shape, poly): 163 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 164 | keypoints = aug.augment_keypoints([imgaug.KeypointsOnImage(keypoints, shape=img_shape[:2])])[0].keypoints 165 | poly = [(p.x, p.y) for p in keypoints] 166 | return np.array(poly) 167 | 168 | def random_rotate(self,img, polys, random_range): 169 | angle = np.random.randint(random_range[0], random_range[1]) 170 | aug_bin = aug_img.Sequential([aug_img.Affine(rotate=angle)]) 171 | img = aug_bin.augment_image(img) 172 | new_polys = [] 173 | for poly in polys: 174 | poly = self.augment_poly(aug_bin, img.shape, poly) 175 | poly = np.maximum(poly, 0) 176 | new_polys.append(poly) 177 | return img, new_polys 178 | 179 | def random_scale(self,img, polys, min_size): 180 | polys,len_max = solve_polys(polys) 181 | h, w = img.shape[0:2] 182 | new_polys = [] 183 | for poly in polys: 184 | poly = np.asarray(poly) 185 | poly = poly / ([w * 1.0, h * 1.0] * len_max) 186 | new_polys.append(poly) 187 | new_polys = np.array(new_polys) 188 | if max(h, w) > 1280: 189 | scale = 1280.0 / max(h, w) 190 | img = cv2.resize(img, dsize=None, fx=scale, fy=scale) 191 | h, w = img.shape[0:2] 192 | random_scale = np.array([0.5, 1.0, 2.0, 3.0]) 193 | scale = np.random.choice(random_scale) 194 | if min(h, w) * scale <= min_size: 195 | scale = (min_size + 10) * 1.0 / min(h, w) 196 | img = cv2.resize(img, dsize=None, fx=scale, fy=scale) 197 | new_polys = np.reshape(new_polys * ([img.shape[1], img.shape[0]] * len_max), 198 | (new_polys.shape[0], polys.shape[1] // 2, 2)) 199 | return img, new_polys 200 | 201 | def random_flip(self,img, polys): 202 | if (np.random.rand(1)[0] > 0.5): 203 | aug_bin = aug_img.Sequential([aug_img.Fliplr((1))]) 204 | img = aug_bin.augment_image(img) 205 | new_polys = [] 206 | for poly in polys: 207 | 
poly = self.augment_poly(aug_bin, img.shape, poly) 208 | poly = np.maximum(poly, 0) 209 | new_polys.append(poly) 210 | else: 211 | new_polys = polys 212 | return img, new_polys 213 | 214 | def random_crop_db(self,img, polys, dont_care): 215 | img, new_polys,new_dotcare = self.random_crop_data.process(img, polys, dont_care) 216 | return img, new_polys,new_dotcare 217 | 218 | def random_crop_pse(self,imgs, img_size=(640, 640)): 219 | h, w = imgs[0].shape[0:2] 220 | th, tw = img_size 221 | if w == tw and h == th: 222 | return imgs 223 | 224 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 225 | tl = np.min(np.where(imgs[1] > 0), axis=1) - img_size 226 | tl[tl < 0] = 0 227 | br = np.max(np.where(imgs[1] > 0), axis=1) - img_size 228 | br[br < 0] = 0 229 | br[0] = min(br[0], h - th) 230 | br[1] = min(br[1], w - tw) 231 | 232 | i = random.randint(tl[0], br[0]) 233 | j = random.randint(tl[1], br[1]) 234 | else: 235 | i = random.randint(0, h - th) 236 | j = random.randint(0, w - tw) 237 | 238 | # return i, j, th, tw 239 | for idx in range(len(imgs)): 240 | if len(imgs[idx].shape) == 3: 241 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :] 242 | else: 243 | imgs[idx] = imgs[idx][i:i + th, j:j + tw] 244 | return imgs 245 | 246 | 247 | -------------------------------------------------------------------------------- /pruned/train_fintune.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: train.py.py 5 | @time: 2020/04/28 6 | """ 7 | import sys 8 | 9 | sys.path.append('/home/aistudio/external-libraries') 10 | sys.path.append('./') 11 | import os 12 | import torch 13 | import torch.nn as nn 14 | import yaml 15 | import argparse 16 | import numpy as np 17 | 18 | olderr = np.seterr(all='ignore') 19 | from models.DBNet import DBNet 20 | from torch.autograd import Variable 21 | from loss.loss import L1BalanceCELoss 22 | from dataloader.dataload import DataLoader 23 | from utils.Logger import Logger 24 | from utils.metrics import runningScore 25 | from utils.model_eval import val 26 | from utils.tools import * 27 | from utils.set_optimizer import * 28 | from pruned.get_pruned_model import load_prune_model 29 | 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.deterministic = True 32 | 33 | 34 | def set_seed(seed): 35 | import numpy as np 36 | import random 37 | import torch 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | 44 | 45 | GLOBAL_WORKER_ID = None 46 | GLOBAL_SEED = 2020 47 | 48 | 49 | def worker_init_fn(worker_id): 50 | global GLOBAL_WORKER_ID 51 | GLOBAL_WORKER_ID = worker_id 52 | set_seed(GLOBAL_SEED + worker_id) 53 | 54 | 55 | def train_net(config): 56 | os.environ["CUDA_VISIBLE_DEVICES"] = config['pruned']['gpu_id'] 57 | data_loader = DataLoader(config) 58 | train_loader = torch.utils.data.DataLoader( 59 | data_loader, 60 | batch_size=config['train']['batch_size'], 61 | shuffle=True, 62 | num_workers=config['train']['num_workers'], 63 | worker_init_fn=worker_init_fn, 64 | drop_last=True, 65 | pin_memory=False) 66 | 67 | start_epoch = 0 68 | running_metric_binary = runningScore(2) 69 | 70 | if not (os.path.exists(config['train']['checkpoints'])): 71 | os.mkdir(config['train']['checkpoints']) 72 | checkpoints = os.path.join(config['pruned']['save_checkpoints'], "DB_%s_bs_%d_ep_%d" % (config['train']['backbone'], 73 | config['train'][ 74 | 'batch_size'], 75 | 
config['train']['n_epoch'])) 76 | if not (os.path.exists(checkpoints)): 77 | os.mkdir(checkpoints) 78 | 79 | model = DBNet(config) 80 | criterion = L1BalanceCELoss() 81 | optimizer = torch.optim.SGD(model.parameters(), lr=config['pruned']['finetune_lr'], momentum=0.99, 82 | weight_decay=5e-4) 83 | 84 | if config['pruned']['restore']: 85 | print('Resuming from checkpoint.') 86 | assert os.path.isfile(config['pruned']['resume']), 'Error: no checkpoint directory found!' 87 | checkpoint = torch.load(config['pruned']['resume']) 88 | start_epoch = checkpoint['epoch'] 89 | model = load_prune_model(model, config['pruned']['checkpoints_dict']).cuda() 90 | model.load_state_dict(checkpoint['state_dict']) 91 | optimizer.load_state_dict(checkpoint['optimizer']) 92 | log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['train']['backbone'], resume=True) 93 | else: 94 | print('Training from scratch.') 95 | model_dict = torch.load(config['pruned']['pruned_checkpoints']) 96 | model = load_prune_model(model, config['pruned']['checkpoints_dict']).cuda() 97 | print(model) 98 | try: 99 | model.load_state_dict(model_dict) 100 | except: 101 | state = model.state_dict() 102 | for key in state.keys(): 103 | state[key] = model_dict['module.' + key] 104 | model.load_state_dict(state) 105 | log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['train']['backbone']) 106 | log_write.set_names( 107 | [' epoch', 'Total loss', ' Bce loss', 'Thresh loss', ' L1 loss', 'Binary Acc', 'Binary IoU', 108 | ' rescall', ' precision', ' hmean']) 109 | max_hmean = -1 110 | for epoch in range(start_epoch, config['pruned']['n_epoch']): 111 | model.train() 112 | 113 | bce_loss_list = [] 114 | thresh_loss_list = [] 115 | l1_loss_list = [] 116 | total_loss_list = [] 117 | 118 | if (config['train']['decay_method'] == 'e_decay'): 119 | adjust_learning_rate_poly(config['pruned']['finetune_lr'], optimizer, epoch, 120 | max_epoch=config['pruned']['n_epoch'], factor=0.9) 121 | else: 122 | adjust_learning_rate(config, optimizer, epoch, config['train']['gama']) 123 | 124 | for batch_idx, (imgs, gts, gt_masks, thresh_maps, thresh_masks) in enumerate(train_loader): 125 | imgs = Variable(imgs.cuda()) 126 | gts = Variable(gts.cuda()) 127 | gt_masks = Variable(gt_masks.cuda()) 128 | thresh_maps = Variable(thresh_maps.cuda()) 129 | thresh_masks = Variable(thresh_masks.cuda()) 130 | batch = {} 131 | batch['gt'] = gts 132 | batch['mask'] = gt_masks 133 | batch['thresh_map'] = thresh_maps 134 | batch['thresh_mask'] = thresh_masks 135 | 136 | pre = model(imgs) 137 | loss, metrics = criterion(pre, batch) 138 | 139 | optimizer.zero_grad() 140 | loss.backward() 141 | optimizer.step() 142 | 143 | score_binary = cal_binary_score(pre['binary'], gts, gt_masks.unsqueeze(1), running_metric_binary) 144 | 145 | bce_loss_list.append(metrics['bce_loss'].item()) 146 | thresh_loss_list.append(metrics['thresh_loss'].item()) 147 | l1_loss_list.append(metrics['l1_loss'].item()) 148 | total_loss_list.append(loss.item()) 149 | if batch_idx % config['train']['show_step'] == 0: 150 | if (config['train']['print_format'] == 'linux'): 151 | headers = ['epoch/epochs', 'batch/batchs', 'TotalLoss', 'BceLoss', ' ThreshLoss', 'L1Loss', 152 | 'Binary Acc', 'Binary IoU', 'Lr Rate'] 153 | show_item = [[str(epoch) + '/' + str(config['pruned']['n_epoch']), 154 | str(batch_idx + 1) + '/' + str(len(train_loader)), 155 | get_str(np.mean(total_loss_list)), 156 | get_str(np.mean(bce_loss_list)), 157 | get_str(np.mean(thresh_loss_list)), 158 | 
get_str(np.mean(l1_loss_list)), 159 | get_str(score_binary['Mean Acc']), 160 | get_str(score_binary['Mean IoU']), 161 | get_str(optimizer.param_groups[0]['lr']) 162 | ]] 163 | print_table(headers, show_item, type_str='train') 164 | else: 165 | output_log = '({epoch}/{epochs}/{batch}/{size}) | TotalLoss: {total_loss:.4f} | BceLoss: {bce_loss:.4f} | ThreshLoss: {thresh_loss: .4f} | L1Loss: {l1_loss: .4f} | Binary Acc: {bin_acc: .4f} | Binary IoU: {bin_iou: .4f} | Lr: {lr: .4f}'.format( 166 | epoch=epoch, 167 | epochs=config['pruned']['n_epoch'], 168 | batch=batch_idx + 1, 169 | size=len(train_loader), 170 | total_loss=np.mean(total_loss_list), 171 | bce_loss=np.mean(bce_loss_list), 172 | thresh_loss=np.mean(thresh_loss_list), 173 | l1_loss=np.mean(l1_loss_list), 174 | bin_acc=score_binary['Mean Acc'], 175 | bin_iou=score_binary['Mean IoU'], 176 | lr=optimizer.param_groups[0]['lr'] 177 | ) 178 | print(output_log) 179 | sys.stdout.flush() 180 | 181 | if (epoch > config['pruned']['start_val_epoch']): 182 | result_dict = val(model, config) 183 | rescall, precision, hmean = result_dict['recall'], result_dict['precision'], result_dict['hmean'] 184 | print('epoch:', epoch, 'rescall:', rescall, 'precision:', precision, 'hmean:', hmean) 185 | else: 186 | rescall = 0 187 | precision = 0 188 | hmean = 0 189 | log_write.append([epoch, np.mean(total_loss_list), np.mean(bce_loss_list), np.mean(thresh_loss_list), 190 | np.mean(l1_loss_list), score_binary['Mean Acc'], score_binary['Mean IoU'], 191 | rescall, precision, hmean]) 192 | if (hmean > max_hmean and config['pruned']['start_val_epoch'] < config['pruned']['n_epoch']): 193 | max_hmean = hmean 194 | save_checkpoint({ 195 | 'epoch': epoch + 1, 196 | 'state_dict': model.state_dict(), 197 | 'lr': optimizer.param_groups[0]['lr'], 198 | 'optimizer': optimizer.state_dict(), 199 | }, checkpoint=checkpoints, filename='best_model.pth.tar') 200 | 201 | save_checkpoint({ 202 | 'epoch': epoch + 1, 203 | 'state_dict': model.state_dict(), 204 | 'lr': optimizer.param_groups[0]['lr'], 205 | 'optimizer': optimizer.state_dict(), 206 | }, checkpoint=checkpoints) 207 | 208 | 209 | if __name__ == '__main__': 210 | stream = open('./config.yaml', 'r', encoding='utf-8') 211 | config = yaml.load(stream, Loader=yaml.FullLoader) 212 | train_net(config) -------------------------------------------------------------------------------- /pruned/prune.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: prune.py 5 | @time: 2020/07/17 6 | """ 7 | """ 8 | #!-*- coding=utf-8 -*- 9 | @author: BADBADBADBADBOY 10 | @contact: 2441124901@qq.com 11 | @software: PyCharm Community Edition 12 | @file: prune.py 13 | @time: 2020/6/27 10:23 14 | 15 | """ 16 | import sys 17 | 18 | sys.path.append('/home/aistudio/external-libraries') 19 | sys.path.append('./') 20 | import yaml 21 | from models.DBNet import DBNet 22 | import torch 23 | import torch.nn as nn 24 | import numpy as np 25 | import collections 26 | import torchvision.transforms as transforms 27 | import cv2 28 | import os 29 | import argparse 30 | import math 31 | from PIL import Image 32 | from torch.autograd import Variable 33 | 34 | 35 | def prune(config): 36 | os.environ["CUDA_VISIBLE_DEVICES"] = config['pruned']['gpu_id'] 37 | 38 | model = DBNet(config).cuda() 39 | model_dict = torch.load(config['pruned']['checkpoints'])['state_dict'] 40 | state = model.state_dict() 41 | for key in state.keys(): 42 | if key in model_dict.keys(): 43 | state[key] 
= model_dict[key] 44 | model.load_state_dict(state) 45 | 46 | 47 | bn_weights = [] 48 | for m in model.modules(): 49 | if (isinstance(m, nn.BatchNorm2d)): 50 | bn_weights.append(m.weight.data.abs().clone()) 51 | bn_weights = torch.cat(bn_weights, 0) 52 | 53 | sort_result, sort_index = torch.sort(bn_weights) 54 | 55 | thresh_index = int(config['pruned']['cut_percent'] * bn_weights.shape[0]) 56 | 57 | if (thresh_index == bn_weights.shape[0]): 58 | thresh_index = bn_weights.shape[0] - 1 59 | 60 | prued = 0 61 | prued_mask = [] 62 | bn_index = [] 63 | conv_index = [] 64 | remain_channel_nums = [] 65 | for k, m in enumerate(model.modules()): 66 | if (k > 69): 67 | break 68 | if (isinstance(m, nn.BatchNorm2d)): 69 | bn_weight = m.weight.data.clone() 70 | mask = bn_weight.abs().gt(sort_result[thresh_index]) 71 | remain_channel = mask.sum() 72 | 73 | if (remain_channel == 0): 74 | remain_channel = 1 75 | mask[int(torch.argmax(bn_weight))] = 1 76 | 77 | v = 0 78 | n = 1 79 | if (remain_channel % config['pruned']['base_num'] != 0): 80 | if (remain_channel > config['pruned']['base_num']): 81 | while (v < remain_channel): 82 | n += 1 83 | v = config['pruned']['base_num'] * n 84 | if (remain_channel - (v - config['pruned']['base_num']) < v - remain_channel): 85 | remain_channel = v - config['pruned']['base_num'] 86 | else: 87 | remain_channel = v 88 | if (remain_channel > bn_weight.size()[0]): 89 | remain_channel = bn_weight.size()[0] 90 | remain_channel = torch.tensor(remain_channel) 91 | result, index = torch.sort(bn_weight) 92 | mask = bn_weight.abs().ge(result[-remain_channel]) 93 | 94 | remain_channel_nums.append(int(mask.sum())) 95 | prued_mask.append(mask) 96 | bn_index.append(k) 97 | prued += mask.shape[0] - mask.sum() 98 | elif (isinstance(m, nn.Conv2d)): 99 | conv_index.append(k) 100 | print('remain_channel_nums', remain_channel_nums) 101 | print('total_prune_ratio:', float(prued) / bn_weights.shape[0]) 102 | print('bn_index', bn_index) 103 | 104 | new_model = DBNet(config).cuda() 105 | 106 | merge1_index = [3, 12, 18] 107 | merge2_index = [25, 28, 34] 108 | merge3_index = [41, 44, 50] 109 | merge4_index = [57, 60, 66] 110 | 111 | index_0 = [] 112 | for item in merge1_index: 113 | index_0.append(bn_index.index(item)) 114 | mask1 = prued_mask[index_0[0]] | prued_mask[index_0[1]] | prued_mask[index_0[2]] 115 | 116 | index_1 = [] 117 | for item in merge2_index: 118 | index_1.append(bn_index.index(item)) 119 | mask2 = prued_mask[index_1[0]] | prued_mask[index_1[1]] | prued_mask[index_1[2]] 120 | 121 | index_2 = [] 122 | for item in merge3_index: 123 | index_2.append(bn_index.index(item)) 124 | mask3 = prued_mask[index_2[0]] | prued_mask[index_2[1]] | prued_mask[index_2[2]] 125 | 126 | index_3 = [] 127 | for item in merge4_index: 128 | index_3.append(bn_index.index(item)) 129 | mask4 = prued_mask[index_3[0]] | prued_mask[index_3[1]] | prued_mask[index_3[2]] 130 | 131 | for index in index_0: 132 | prued_mask[index] = mask1 133 | 134 | for index in index_1: 135 | prued_mask[index] = mask2 136 | 137 | for index in index_2: 138 | prued_mask[index] = mask3 139 | 140 | for index in index_3: 141 | prued_mask[index] = mask4 142 | 143 | print(model) 144 | 145 | ############################################################## 146 | index_bn = 0 147 | index_conv = 0 148 | 149 | bn_mask = [] 150 | conv_in_mask = [] 151 | conv_out_mask = [] 152 | tag = 0 153 | for m in new_model.modules(): 154 | if (tag > 69): 155 | break 156 | if (isinstance(m, nn.BatchNorm2d)): 157 | m.num_features = prued_mask[index_bn].sum() 
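# Network-slimming step: each BatchNorm2d in the pruned copy is narrowed to
# its surviving-channel count (mask.sum()), and the masks are queued so the
# weight-copy loop further down can gather the matching conv/BN parameters;
# e.g. a 64-channel layer whose mask keeps 23 scales becomes BatchNorm2d(23).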
158 | bn_mask.append(prued_mask[index_bn]) 159 | index_bn += 1 160 | elif (isinstance(m, nn.Conv2d)): 161 | if (index_conv == 0): 162 | m.in_channels = 3 163 | conv_in_mask.append(torch.ones(3)) 164 | else: 165 | m.in_channels = prued_mask[index_conv - 1].sum() 166 | conv_in_mask.append(prued_mask[index_conv - 1]) 167 | m.out_channels = prued_mask[index_conv].sum() 168 | conv_out_mask.append(prued_mask[index_conv]) 169 | index_conv += 1 170 | tag += 1 171 | 172 | conv_change_index = [27, 43, 59] # 173 | change_conv_bn_index = [18, 34, 50] # 174 | tag = 0 175 | for m in new_model.modules(): 176 | if (tag > 69): 177 | break 178 | if (isinstance(m, nn.Conv2d)): 179 | if (tag in conv_change_index): 180 | index = conv_change_index.index(tag) 181 | index = change_conv_bn_index[index] 182 | index = bn_index.index(index) 183 | mask = prued_mask[index] 184 | conv_in_mask[index + 3] = mask 185 | m.in_channels = mask.sum() 186 | tag += 1 187 | 188 | ############################################################# 189 | bn_i = 0 190 | conv_i = 0 191 | scale_i = 0 192 | scale_mask = [mask4, mask3, mask2, mask1] 193 | # scale = [70,86,90,94] # FPN 194 | scale = config['pruned']['scale'] # DB 195 | for [m0, m1] in zip(model.modules(), new_model.modules()): 196 | if (scale_i > 69): 197 | if isinstance(m0, nn.Conv2d): 198 | if (scale_i in scale): 199 | index = scale.index(scale_i) 200 | m1.in_channels = scale_mask[index].sum() 201 | idx0 = np.squeeze(np.argwhere(np.asarray(scale_mask[index].cpu().numpy()))) 202 | idx1 = np.squeeze(np.argwhere(np.asarray(torch.ones(256).cpu().numpy()))) 203 | if idx0.size == 1: 204 | idx0 = np.resize(idx0, (1,)) 205 | if idx1.size == 1: 206 | idx1 = np.resize(idx1, (1,)) 207 | w = m0.weight.data[:, idx0, :, :].clone() 208 | m1.weight.data = w[idx1, :, :, :].clone() 209 | if m1.bias is not None: 210 | m1.bias.data = m0.bias.data[idx1].clone() 211 | 212 | else: 213 | m1.weight.data = m0.weight.data.clone() 214 | if m1.bias is not None: 215 | m1.bias.data = m0.bias.data.clone() 216 | 217 | elif isinstance(m0, nn.BatchNorm2d): 218 | m1.weight.data = m0.weight.data.clone() 219 | if m1.bias is not None: 220 | m1.bias.data = m0.bias.data.clone() 221 | m1.running_mean = m0.running_mean.clone() 222 | m1.running_var = m0.running_var.clone() 223 | 224 | else: 225 | if isinstance(m0, nn.BatchNorm2d): 226 | idx1 = np.squeeze(np.argwhere(np.asarray(bn_mask[bn_i].cpu().numpy()))) 227 | if idx1.size == 1: 228 | idx1 = np.resize(idx1, (1,)) 229 | m1.weight.data = m0.weight.data[idx1].clone() 230 | if m1.bias is not None: 231 | m1.bias.data = m0.bias.data[idx1].clone() 232 | m1.running_mean = m0.running_mean[idx1].clone() 233 | m1.running_var = m0.running_var[idx1].clone() 234 | bn_i += 1 235 | elif isinstance(m0, nn.Conv2d): 236 | if (isinstance(conv_in_mask[conv_i], list)): 237 | idx0 = np.squeeze(np.argwhere(np.asarray(torch.cat(conv_in_mask[conv_i], 0).cpu().numpy()))) 238 | else: 239 | idx0 = np.squeeze(np.argwhere(np.asarray(conv_in_mask[conv_i].cpu().numpy()))) 240 | idx1 = np.squeeze(np.argwhere(np.asarray(conv_out_mask[conv_i].cpu().numpy()))) 241 | if idx0.size == 1: 242 | idx0 = np.resize(idx0, (1,)) 243 | if idx1.size == 1: 244 | idx1 = np.resize(idx1, (1,)) 245 | w = m0.weight.data[:, idx0, :, :].clone() 246 | m1.weight.data = w[idx1, :, :, :].clone() 247 | if m1.bias is not None: 248 | m1.bias.data = m0.bias.data[idx1].clone() 249 | conv_i += 1 250 | 251 | scale_i += 1 252 | 253 | print(new_model) 254 | 255 | save_obj = {'prued_mask': prued_mask, 'bn_index': bn_index} 256 | 
torch.save(save_obj, os.path.join(config['pruned']['save_checkpoints'], 'pruned_dict.dict')) 257 | torch.save(new_model.state_dict(), os.path.join(config['pruned']['save_checkpoints'], 'pruned_dict.pth.tar')) 258 | 259 | 260 | if __name__ == '__main__': 261 | 262 | stream = open('./config.yaml', 'r', encoding='utf-8') 263 | config = yaml.load(stream, Loader=yaml.FullLoader) 264 | prune(config) 265 | -------------------------------------------------------------------------------- /models/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | #!-*- coding=utf-8 -*- 3 | @author: BADBADBADBADBOY 4 | @contact: 2441124901@qq.com 5 | @software: PyCharm Community Edition 6 | @file: resnet.py 7 | @time: 2020/7/4 15:16 8 | 9 | """ 10 | import torch.nn as nn 11 | import math 12 | import torch.utils.model_zoo as model_zoo 13 | 14 | BatchNorm2d = nn.BatchNorm2d 15 | 16 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 17 | 'resnet152'] 18 | 19 | model_urls = { 20 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 21 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 22 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 23 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 24 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 25 | } 26 | 27 | 28 | def constant_init(module, constant, bias=0): 29 | nn.init.constant_(module.weight, constant) 30 | if hasattr(module, 'bias'): 31 | nn.init.constant_(module.bias, bias) 32 | 33 | 34 | def conv3x3(in_planes, out_planes, stride=1): 35 | """3x3 convolution with padding""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 37 | padding=1, bias=False) 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 44 | super(BasicBlock, self).__init__() 45 | self.with_dcn = dcn is not None 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = BatchNorm2d(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.with_modulated_dcn = False 50 | if self.with_dcn: 51 | fallback_on_stride = dcn.get('fallback_on_stride', False) 52 | self.with_modulated_dcn = dcn.get('modulated', False) 53 | # self.conv2 = conv3x3(planes, planes) 54 | if not self.with_dcn or fallback_on_stride: 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 56 | padding=1, bias=False) 57 | else: 58 | deformable_groups = dcn.get('deformable_groups', 1) 59 | if not self.with_modulated_dcn: 60 | from models.dcn import DeformConv 61 | conv_op = DeformConv 62 | offset_channels = 18 63 | else: 64 | from models.dcn import ModulatedDeformConv 65 | conv_op = ModulatedDeformConv 66 | offset_channels = 27 67 | self.conv2_offset = nn.Conv2d( 68 | planes, 69 | deformable_groups * offset_channels, 70 | kernel_size=3, 71 | padding=1) 72 | self.conv2 = conv_op( 73 | planes, 74 | planes, 75 | kernel_size=3, 76 | padding=1, 77 | deformable_groups=deformable_groups, 78 | bias=False) 79 | self.bn2 = BatchNorm2d(planes) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | # out = self.conv2(out) 91 | if not self.with_dcn: 92 | out = self.conv2(out) 93 | elif self.with_modulated_dcn: 94 | offset_mask = self.conv2_offset(out) 95 | offset = 
offset_mask[:, :18, :, :] 96 | mask = offset_mask[:, -9:, :, :].sigmoid() 97 | out = self.conv2(out, offset, mask) 98 | else: 99 | offset = self.conv2_offset(out) 100 | out = self.conv2(out, offset) 101 | out = self.bn2(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class Bottleneck(nn.Module): 113 | expansion = 4 114 | 115 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 116 | super(Bottleneck, self).__init__() 117 | self.with_dcn = dcn is not None 118 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 119 | self.bn1 = BatchNorm2d(planes) 120 | fallback_on_stride = False 121 | self.with_modulated_dcn = False 122 | if self.with_dcn: 123 | fallback_on_stride = dcn.get('fallback_on_stride', False) 124 | self.with_modulated_dcn = dcn.get('modulated', False) 125 | if not self.with_dcn or fallback_on_stride: 126 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 127 | stride=stride, padding=1, bias=False) 128 | else: 129 | deformable_groups = dcn.get('deformable_groups', 1) 130 | if not self.with_modulated_dcn: 131 | from models.dcn import DeformConv 132 | conv_op = DeformConv 133 | offset_channels = 18 134 | else: 135 | from models.dcn import ModulatedDeformConv 136 | conv_op = ModulatedDeformConv 137 | offset_channels = 27 138 | self.conv2_offset = nn.Conv2d( 139 | planes, deformable_groups * offset_channels, 140 | kernel_size=3, 141 | padding=1) 142 | self.conv2 = conv_op( 143 | planes, planes, kernel_size=3, padding=1, stride=stride, 144 | deformable_groups=deformable_groups, bias=False) 145 | self.bn2 = BatchNorm2d(planes) 146 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 147 | self.bn3 = BatchNorm2d(planes * 4) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.downsample = downsample 150 | self.stride = stride 151 | self.dcn = dcn 152 | self.with_dcn = dcn is not None 153 | 154 | def forward(self, x): 155 | residual = x 156 | 157 | out = self.conv1(x) 158 | out = self.bn1(out) 159 | out = self.relu(out) 160 | 161 | # out = self.conv2(out) 162 | if not self.with_dcn: 163 | out = self.conv2(out) 164 | elif self.with_modulated_dcn: 165 | offset_mask = self.conv2_offset(out) 166 | offset = offset_mask[:, :18, :, :] 167 | mask = offset_mask[:, -9:, :, :].sigmoid() 168 | out = self.conv2(out, offset, mask) 169 | else: 170 | offset = self.conv2_offset(out) 171 | out = self.conv2(out, offset) 172 | out = self.bn2(out) 173 | out = self.relu(out) 174 | 175 | out = self.conv3(out) 176 | out = self.bn3(out) 177 | 178 | if self.downsample is not None: 179 | residual = self.downsample(x) 180 | 181 | out += residual 182 | out = self.relu(out) 183 | 184 | return out 185 | 186 | 187 | class ResNet(nn.Module): 188 | def __init__(self, block, layers, num_classes=1000, 189 | dcn=None, stage_with_dcn=(False, False, False, False)): 190 | self.dcn = dcn 191 | self.stage_with_dcn = stage_with_dcn 192 | self.inplanes = 64 193 | super(ResNet, self).__init__() 194 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 195 | bias=False) 196 | self.bn1 = BatchNorm2d(64) 197 | self.relu = nn.ReLU(inplace=True) 198 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 199 | self.layer1 = self._make_layer(block, 64, layers[0]) 200 | self.layer2 = self._make_layer( 201 | block, 128, layers[1], stride=2, dcn=dcn) 202 | self.layer3 = self._make_layer( 203 | block, 256, layers[2], stride=2, 
dcn=dcn) 204 | self.layer4 = self._make_layer( 205 | block, 512, layers[3], stride=2, dcn=dcn) 206 | #self.avgpool = nn.AvgPool2d(7, stride=1) 207 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 208 | 209 | #self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1) 210 | 211 | for m in self.modules(): 212 | if isinstance(m, nn.Conv2d): 213 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 214 | m.weight.data.normal_(0, math.sqrt(2. / n)) 215 | elif isinstance(m, BatchNorm2d): 216 | m.weight.data.fill_(1) 217 | m.bias.data.zero_() 218 | if self.dcn is not None: 219 | for m in self.modules(): 220 | if isinstance(m, Bottleneck) or isinstance(m, BasicBlock): 221 | if hasattr(m, 'conv2_offset'): 222 | constant_init(m.conv2_offset, 0) 223 | 224 | def _make_layer(self, block, planes, blocks, stride=1, dcn=None): 225 | downsample = None 226 | if stride != 1 or self.inplanes != planes * block.expansion: 227 | downsample = nn.Sequential( 228 | nn.Conv2d(self.inplanes, planes * block.expansion, 229 | kernel_size=1, stride=stride, bias=False), 230 | BatchNorm2d(planes * block.expansion), 231 | ) 232 | 233 | layers = [] 234 | layers.append(block(self.inplanes, planes, 235 | stride, downsample, dcn=dcn)) 236 | self.inplanes = planes * block.expansion 237 | for i in range(1, blocks): 238 | layers.append(block(self.inplanes, planes, dcn=dcn)) 239 | 240 | return nn.Sequential(*layers) 241 | 242 | def forward(self, x): 243 | x = self.conv1(x) 244 | x = self.bn1(x) 245 | x = self.relu(x) 246 | x = self.maxpool(x) 247 | 248 | x2 = self.layer1(x) 249 | x3 = self.layer2(x2) 250 | x4 = self.layer3(x3) 251 | x5 = self.layer4(x4) 252 | 253 | return x2, x3, x4, x5 254 | 255 | 256 | def resnet18(pretrained=True, **kwargs): 257 | """Constructs a ResNet-18 model. 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | """ 261 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 262 | if pretrained: 263 | model.load_state_dict(model_zoo.load_url( 264 | model_urls['resnet18']), strict=False) 265 | return model 266 | 267 | 268 | def deformable_resnet18(pretrained=True, **kwargs): 269 | """Constructs a ResNet-18 model. 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | """ 273 | model = ResNet(BasicBlock, [2, 2, 2, 2], 274 | dcn=dict(modulated=True, 275 | deformable_groups=1, 276 | fallback_on_stride=False), 277 | stage_with_dcn=[False, True, True, True], **kwargs) 278 | if pretrained: 279 | model.load_state_dict(model_zoo.load_url( 280 | model_urls['resnet18']), strict=False) 281 | return model 282 | 283 | 284 | def resnet34(pretrained=True, **kwargs): 285 | """Constructs a ResNet-34 model. 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | """ 289 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 290 | if pretrained: 291 | model.load_state_dict(model_zoo.load_url( 292 | model_urls['resnet34']), strict=False) 293 | return model 294 | 295 | 296 | def resnet50(pretrained=True, **kwargs): 297 | """Constructs a ResNet-50 model. 298 | Args: 299 | pretrained (bool): If True, returns a model pre-trained on ImageNet 300 | """ 301 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url( 304 | model_urls['resnet50']), strict=False) 305 | return model 306 | 307 | 308 | def deformable_resnet50(pretrained=True, **kwargs): 309 | """Constructs a ResNet-50 model with deformable conv. 
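    When built this way, layer2-layer4 replace their 3x3 convs with modulated
    deformable convolutions (DCNv2, deformable_groups=1), while layer1 keeps
    plain convolutions (see _make_layer above).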
310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | """ 313 | model = ResNet(Bottleneck, [3, 4, 6, 3], 314 | dcn=dict(modulated=True, 315 | deformable_groups=1, 316 | fallback_on_stride=False), 317 | stage_with_dcn=[False, True, True, True], 318 | **kwargs) 319 | if pretrained: 320 | model.load_state_dict(model_zoo.load_url( 321 | model_urls['resnet50']), strict=False) 322 | return model 323 | 324 | 325 | def resnet101(pretrained=True, **kwargs): 326 | """Constructs a ResNet-101 model. 327 | Args: 328 | pretrained (bool): If True, returns a model pre-trained on ImageNet 329 | """ 330 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 331 | if pretrained: 332 | model.load_state_dict(model_zoo.load_url( 333 | model_urls['resnet101']), strict=False) 334 | return model 335 | 336 | 337 | def resnet152(pretrained=True, **kwargs): 338 | """Constructs a ResNet-152 model. 339 | Args: 340 | pretrained (bool): If True, returns a model pre-trained on ImageNet 341 | """ 342 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 343 | if pretrained: 344 | model.load_state_dict(model_zoo.load_url( 345 | model_urls['resnet152']), strict=False) 346 | return model 347 | 348 | -------------------------------------------------------------------------------- /cal_rescall/script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import namedtuple 3 | from . import rrc_evaluation_funcs 4 | import Polygon as plg 5 | import numpy as np 6 | 7 | 8 | def default_evaluation_params(): 9 | """ 10 | default_evaluation_params: Default parameters to use for the validation and evaluation. 11 | """ 12 | return { 13 | 'IOU_CONSTRAINT': 0.5, 14 | 'AREA_PRECISION_CONSTRAINT': 0.5, 15 | 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', 16 | 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', 17 | 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) 18 | 'CRLF': False, # Lines are delimited by Windows CRLF format 19 | 'CONFIDENCES': False, # Detections must include confidence value. AP will be calculated 20 | 'PER_SAMPLE_RESULTS': True # Generate per sample results and produce data for visualization 21 | } 22 | 23 | 24 | def validate_data(gtFilePath, submFilePath, evaluationParams): 25 | """ 26 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents). 27 | Validates also that there are no missing files in the folder. 28 | If some error detected, the method raises the error 29 | """ 30 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) 31 | 32 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) 33 | 34 | # Validate format of GroundTruth 35 | for k in gt: 36 | rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True) 37 | 38 | # Validate format of results 39 | for k in subm: 40 | if (k in gt) == False: 41 | raise Exception("The sample %s not present in GT" % k) 42 | 43 | rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'], 44 | False, evaluationParams['CONFIDENCES']) 45 | 46 | 47 | def evaluate_method(gtFilePath, submFilePath, evaluationParams): 48 | """ 49 | Method evaluate_method: evaluate method and returns the results 50 | Results. 
Dictionary with the following values: 51 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } 52 | - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } 53 | """ 54 | 55 | def polygon_from_points(points): 56 | """ 57 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 58 | """ 59 | resBoxes = np.empty([1, 8], dtype='int32') 60 | resBoxes[0, 0] = int(points[0]) 61 | resBoxes[0, 4] = int(points[1]) 62 | resBoxes[0, 1] = int(points[2]) 63 | resBoxes[0, 5] = int(points[3]) 64 | resBoxes[0, 2] = int(points[4]) 65 | resBoxes[0, 6] = int(points[5]) 66 | resBoxes[0, 3] = int(points[6]) 67 | resBoxes[0, 7] = int(points[7]) 68 | pointMat = resBoxes[0].reshape([2, 4]).T 69 | return plg.Polygon(pointMat) 70 | 71 | def rectangle_to_polygon(rect): 72 | resBoxes = np.empty([1, 8], dtype='int32') 73 | resBoxes[0, 0] = int(rect.xmin) 74 | resBoxes[0, 4] = int(rect.ymax) 75 | resBoxes[0, 1] = int(rect.xmin) 76 | resBoxes[0, 5] = int(rect.ymin) 77 | resBoxes[0, 2] = int(rect.xmax) 78 | resBoxes[0, 6] = int(rect.ymin) 79 | resBoxes[0, 3] = int(rect.xmax) 80 | resBoxes[0, 7] = int(rect.ymax) 81 | 82 | pointMat = resBoxes[0].reshape([2, 4]).T 83 | 84 | return plg.Polygon(pointMat) 85 | 86 | def rectangle_to_points(rect): 87 | points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), 88 | int(rect.xmin), int(rect.ymin)] 89 | return points 90 | 91 | def get_union(pD, pG): 92 | areaA = pD.area(); 93 | areaB = pG.area(); 94 | return areaA + areaB - get_intersection(pD, pG); 95 | 96 | def get_intersection_over_union(pD, pG): 97 | try: 98 | return get_intersection(pD, pG) / get_union(pD, pG); 99 | except: 100 | return 0 101 | 102 | def get_intersection(pD, pG): 103 | pInt = pD & pG 104 | if len(pInt) == 0: 105 | return 0 106 | return pInt.area() 107 | 108 | def compute_ap(confList, matchList, numGtCare): 109 | correct = 0 110 | AP = 0 111 | if len(confList) > 0: 112 | confList = np.array(confList) 113 | matchList = np.array(matchList) 114 | sorted_ind = np.argsort(-confList) 115 | confList = confList[sorted_ind] 116 | matchList = matchList[sorted_ind] 117 | for n in range(len(confList)): 118 | match = matchList[n] 119 | if match: 120 | correct += 1 121 | AP += float(correct) / (n + 1) 122 | 123 | if numGtCare > 0: 124 | AP /= numGtCare 125 | 126 | return AP 127 | 128 | perSampleMetrics = {} 129 | 130 | matchedSum = 0 131 | 132 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') 133 | 134 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) 135 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) 136 | 137 | numGlobalCareGt = 0; 138 | numGlobalCareDet = 0; 139 | 140 | arrGlobalConfidences = []; 141 | arrGlobalMatches = []; 142 | 143 | for resFile in gt: 144 | 145 | gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile]) 146 | recall = 0 147 | precision = 0 148 | hmean = 0 149 | 150 | detMatched = 0 151 | 152 | iouMat = np.empty([1, 1]) 153 | 154 | gtPols = [] 155 | detPols = [] 156 | 157 | gtPolPoints = [] 158 | detPolPoints = [] 159 | 160 | # Array of Ground Truth Polygons' keys marked as don't Care 161 | gtDontCarePolsNum = [] 162 | # Array of Detected Polygons' matched with a don't Care GT 163 | detDontCarePolsNum = [] 164 | 165 | pairs = [] 166 | detMatchedNums = [] 167 | 168 | 
arrSampleConfidences = []; 169 | arrSampleMatch = []; 170 | sampleAP = 0; 171 | 172 | evaluationLog = "" 173 | 174 | pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, 175 | evaluationParams[ 176 | 'CRLF'], 177 | evaluationParams[ 178 | 'LTRB'], 179 | True, False) 180 | for n in range(len(pointsList)): 181 | points = pointsList[n] 182 | transcription = transcriptionsList[n] 183 | dontCare = transcription == "###" 184 | if evaluationParams['LTRB']: 185 | gtRect = Rectangle(*points) 186 | gtPol = rectangle_to_polygon(gtRect) 187 | else: 188 | gtPol = polygon_from_points(points) 189 | gtPols.append(gtPol) 190 | gtPolPoints.append(points) 191 | if dontCare: 192 | gtDontCarePolsNum.append(len(gtPols) - 1) 193 | 194 | evaluationLog += "GT polygons: " + str(len(gtPols)) + ( 195 | " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n") 196 | 197 | if resFile in subm: 198 | 199 | detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile]) 200 | 201 | pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, 202 | evaluationParams[ 203 | 'CRLF'], 204 | evaluationParams[ 205 | 'LTRB'], 206 | False, 207 | evaluationParams[ 208 | 'CONFIDENCES']) 209 | for n in range(len(pointsList)): 210 | points = pointsList[n] 211 | 212 | if evaluationParams['LTRB']: 213 | detRect = Rectangle(*points) 214 | detPol = rectangle_to_polygon(detRect) 215 | else: 216 | detPol = polygon_from_points(points) 217 | detPols.append(detPol) 218 | detPolPoints.append(points) 219 | if len(gtDontCarePolsNum) > 0: 220 | for dontCarePol in gtDontCarePolsNum: 221 | dontCarePol = gtPols[dontCarePol] 222 | intersected_area = get_intersection(dontCarePol, detPol) 223 | pdDimensions = detPol.area() 224 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions 225 | if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']): 226 | detDontCarePolsNum.append(len(detPols) - 1) 227 | break 228 | 229 | evaluationLog += "DET polygons: " + str(len(detPols)) + ( 230 | " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n") 231 | 232 | if len(gtPols) > 0 and len(detPols) > 0: 233 | # Calculate IoU and precision matrixs 234 | outputShape = [len(gtPols), len(detPols)] 235 | iouMat = np.empty(outputShape) 236 | gtRectMat = np.zeros(len(gtPols), np.int8) 237 | detRectMat = np.zeros(len(detPols), np.int8) 238 | for gtNum in range(len(gtPols)): 239 | for detNum in range(len(detPols)): 240 | pG = gtPols[gtNum] 241 | pD = detPols[detNum] 242 | iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) 243 | 244 | for gtNum in range(len(gtPols)): 245 | for detNum in range(len(detPols)): 246 | if gtRectMat[gtNum] == 0 and detRectMat[ 247 | detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: 248 | if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']: 249 | gtRectMat[gtNum] = 1 250 | detRectMat[detNum] = 1 251 | detMatched += 1 252 | pairs.append({'gt': gtNum, 'det': detNum}) 253 | detMatchedNums.append(detNum) 254 | evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" 255 | 256 | if evaluationParams['CONFIDENCES']: 257 | for detNum in range(len(detPols)): 258 | if detNum not in detDontCarePolsNum: 259 | # we exclude the don't care detections 260 | match = detNum in detMatchedNums 261 | 262 | arrSampleConfidences.append(confidencesList[detNum]) 263 | arrSampleMatch.append(match) 264 | 
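# Bookkeeping for average precision: the per-sample confidence/match pairs
# feed compute_ap() when PER_SAMPLE_RESULTS is on, the global pairs feed the
# dataset-level AP; matched counts then drive the F-measure below, e.g.
# precision 0.8 and recall 0.9 give hmean = 2 * 0.8 * 0.9 / (0.8 + 0.9) ≈ 0.847.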
265 | arrGlobalConfidences.append(confidencesList[detNum]); 266 | arrGlobalMatches.append(match); 267 | 268 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) 269 | numDetCare = (len(detPols) - len(detDontCarePolsNum)) 270 | if numGtCare == 0: 271 | recall = float(1) 272 | precision = float(0) if numDetCare > 0 else float(1) 273 | sampleAP = precision 274 | else: 275 | recall = float(detMatched) / numGtCare 276 | precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare 277 | if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']: 278 | sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare) 279 | 280 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 281 | 282 | matchedSum += detMatched 283 | numGlobalCareGt += numGtCare 284 | numGlobalCareDet += numDetCare 285 | 286 | if evaluationParams['PER_SAMPLE_RESULTS']: 287 | perSampleMetrics[resFile] = { 288 | 'precision': precision, 289 | 'recall': recall, 290 | 'hmean': hmean, 291 | 'pairs': pairs, 292 | 'AP': sampleAP, 293 | 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), 294 | 'gtPolPoints': gtPolPoints, 295 | 'detPolPoints': detPolPoints, 296 | 'gtDontCare': gtDontCarePolsNum, 297 | 'detDontCare': detDontCarePolsNum, 298 | 'evaluationParams': evaluationParams, 299 | 'evaluationLog': evaluationLog 300 | } 301 | 302 | # Compute MAP and MAR 303 | AP = 0 304 | if evaluationParams['CONFIDENCES']: 305 | AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) 306 | 307 | methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt 308 | methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet 309 | methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( 310 | methodRecall + methodPrecision) 311 | 312 | methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP} 313 | 314 | resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics} 315 | 316 | return resDict; 317 | 318 | 319 | def cal_recall_precison_f1(gt_path, result_path, show_result=False): 320 | p = {'g': gt_path, 's': result_path} 321 | result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method, 322 | show_result) 323 | return result['method'] -------------------------------------------------------------------------------- /models/dcn/src/deform_pool_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | /*! 
--------------------------------------------------------------------------------
/models/dcn/src/deform_pool_cuda_kernel.cu:
--------------------------------------------------------------------------------
/*!
 * Copyright (c) 2017 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file deformable_psroi_pooling.cu
 * \brief
 * \author Yi Li, Guodong Zhang, Jifeng Dai
 */
/***************** Adapted by Charles Shang *********************/
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu

// NOTE: the angle-bracketed header names were stripped during extraction;
// the includes below are the ones this file normally carries.
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <algorithm>

using namespace at;

// Grid-stride loop: each thread handles indices i, i + blockDim.x * gridDim.x, ...
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n);                                       \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

// The template parameter lists were also stripped during extraction; restored
// as <typename scalar_t> to match the AT_DISPATCH usage further down.
template <typename scalar_t>
__device__ scalar_t bilinear_interp(
    const scalar_t *data,
    const scalar_t x,
    const scalar_t y,
    const int width,
    const int height)
{
  int x1 = floor(x);
  int x2 = ceil(x);
  int y1 = floor(y);
  int y2 = ceil(y);
  scalar_t dist_x = (scalar_t)(x - x1);
  scalar_t dist_y = (scalar_t)(y - y1);
  scalar_t value11 = data[y1 * width + x1];
  scalar_t value12 = data[y2 * width + x1];
  scalar_t value21 = data[y1 * width + x2];
  scalar_t value22 = data[y2 * width + x2];
  scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
  return value;
}

template <typename scalar_t>
__global__ void DeformablePSROIPoolForwardKernel(
    const int count,
    const scalar_t *bottom_data,
    const scalar_t spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const scalar_t *bottom_rois, const scalar_t *bottom_trans,
    const int no_trans,
    const scalar_t trans_std,
    const int sample_per_part,
    const int output_dim,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class,
    scalar_t *top_data,
    scalar_t *top_count)
{
  CUDA_KERNEL_LOOP(index, count)
  {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;

    // [start, end) interval for spatial sampling
    const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
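    // ROI corners arrive in input-image pixels; spatial_scale maps them onto
    // the feature map, with a half-pixel shift for the sampling grid.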

    // Force too small ROIs to be 1x1
    scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0
    scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Compute w and h at bottom
    scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
    scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);

    scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
    scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);

    int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
    int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
    int class_id = ctop / channels_each_class;
    scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
    scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;

    scalar_t wstart = (scalar_t)(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    scalar_t hstart = (scalar_t)(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;

    scalar_t sum = 0;
    int count = 0; // number of valid samples; shadows the kernel argument `count` inside this block
    int gw = floor((scalar_t)(pw) * group_size / pooled_width);
    int gh = floor((scalar_t)(ph) * group_size / pooled_height);
    gw = min(max(gw, 0), group_size - 1);
    gh = min(max(gh, 0), group_size - 1);

    const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
    for (int ih = 0; ih < sample_per_part; ih++)
    {
      for (int iw = 0; iw < sample_per_part; iw++)
      {
        scalar_t w = wstart + iw * sub_bin_size_w;
        scalar_t h = hstart + ih * sub_bin_size_h;
        // bilinear interpolation
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
        {
          continue;
        }
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);
        int c = (ctop * group_size + gh) * group_size + gw;
        scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
        sum += val;
        count++;
      }
    }
    top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
    top_count[index] = count;
  }
}
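
// The backward kernel re-derives the same sampling locations as the forward
// pass, then scatters each output gradient to the four neighbouring feature
// texels via atomicAdd; when offsets are learned (no_trans == 0) it also
// accumulates gradients w.r.t. the per-part (trans_x, trans_y) offsets.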
template <typename scalar_t>
__global__ void DeformablePSROIPoolBackwardAccKernel(
    const int count,
    const scalar_t *top_diff,
    const scalar_t *top_count,
    const int num_rois,
    const scalar_t spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const int output_dim,
    scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
    const scalar_t *bottom_data,
    const scalar_t *bottom_rois,
    const scalar_t *bottom_trans,
    const int no_trans,
    const scalar_t trans_std,
    const int sample_per_part,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class)
{
  CUDA_KERNEL_LOOP(index, count)
  {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;

    // [start, end) interval for spatial sampling
    const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

    // Force too small ROIs to be 1x1
    scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0
    scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);

    // Compute w and h at bottom
    scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
    scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);

    scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
    scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);

    int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
    int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
    int class_id = ctop / channels_each_class;
    scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
    scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
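
    // The trans_x/trans_y lookups above index bottom_trans laid out as
    // (n, num_classes * 2, part_size, part_size): channel 2 * class_id holds
    // the x-offset, 2 * class_id + 1 the y-offset, both scaled by trans_std.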
    scalar_t wstart = (scalar_t)(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    scalar_t hstart = (scalar_t)(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;

    if (top_count[index] <= 0)
    {
      continue;
    }
    scalar_t diff_val = top_diff[index] / top_count[index];
    const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
    scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
    int gw = floor((scalar_t)(pw) * group_size / pooled_width);
    int gh = floor((scalar_t)(ph) * group_size / pooled_height);
    gw = min(max(gw, 0), group_size - 1);
    gh = min(max(gh, 0), group_size - 1);

    for (int ih = 0; ih < sample_per_part; ih++)
    {
      for (int iw = 0; iw < sample_per_part; iw++)
      {
        scalar_t w = wstart + iw * sub_bin_size_w;
        scalar_t h = hstart + ih * sub_bin_size_h;
        // bilinear interpolation
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
        {
          continue;
        }
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);
        int c = (ctop * group_size + gh) * group_size + gw;
        // backward on feature
        int x0 = floor(w);
        int x1 = ceil(w);
        int y0 = floor(h);
        int y1 = ceil(h);
        scalar_t dist_x = w - x0, dist_y = h - y0;
        scalar_t q00 = (1 - dist_x) * (1 - dist_y);
        scalar_t q01 = (1 - dist_x) * dist_y;
        scalar_t q10 = dist_x * (1 - dist_y);
        scalar_t q11 = dist_x * dist_y;
        int bottom_index_base = c * height * width;
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);

        if (no_trans)
        {
          continue;
        }
        // backward on the learned offsets (chain rule through bilinear sampling)
        scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
        scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
        scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
        scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
        scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
        diff_x *= roi_width;
        scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
        diff_y *= roi_height;

        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
      }
    }
  }
}
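
// Host-side launchers. AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the
// kernel for the tensor's element type (float/double/half), and GET_BLOCKS
// sizes a 1-D grid so that CUDA_KERNEL_LOOP covers all `count` outputs.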
void DeformablePSROIPoolForward(const at::Tensor data,
                                const at::Tensor bbox,
                                const at::Tensor trans,
                                at::Tensor out,
                                at::Tensor top_count,
                                const int batch,
                                const int channels,
                                const int height,
                                const int width,
                                const int num_bbox,
                                const int channels_trans,
                                const int no_trans,
                                const float spatial_scale,
                                const int output_dim,
                                const int group_size,
                                const int pooled_size,
                                const int part_size,
                                const int sample_per_part,
                                const float trans_std)
{
  const int pooled_height = pooled_size;
  const int pooled_width = pooled_size;
  const int count = num_bbox * output_dim * pooled_height * pooled_width;
  const int num_classes = no_trans ? 1 : channels_trans / 2;
  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data.type(), "deformable_psroi_pool_forward", ([&] {
        // NOTE: the <scalar_t> template arguments and the <<<grid, block>>>
        // launch configuration were stripped during extraction; restored here.
        const scalar_t *bottom_data = data.data<scalar_t>();
        const scalar_t *bottom_rois = bbox.data<scalar_t>();
        const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
        scalar_t *top_data = out.data<scalar_t>();
        scalar_t *top_count_data = top_count.data<scalar_t>();

        DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
            count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
            bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
            group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
  }
}
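
// The matching backward launcher; both entry points are presumably declared
// and bound to Python in src/deform_pool_cuda.cpp alongside this file.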
void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
                                    const at::Tensor data,
                                    const at::Tensor bbox,
                                    const at::Tensor trans,
                                    const at::Tensor top_count,
                                    at::Tensor in_grad,
                                    at::Tensor trans_grad,
                                    const int batch,
                                    const int channels,
                                    const int height,
                                    const int width,
                                    const int num_bbox,
                                    const int channels_trans,
                                    const int no_trans,
                                    const float spatial_scale,
                                    const int output_dim,
                                    const int group_size,
                                    const int pooled_size,
                                    const int part_size,
                                    const int sample_per_part,
                                    const float trans_std)
{
  // LOG(INFO) << "DeformablePSROIPoolBackward";
  const int num_rois = num_bbox;
  const int pooled_height = pooled_size;
  const int pooled_width = pooled_size;
  const int count = num_bbox * output_dim * pooled_height * pooled_width;
  const int num_classes = no_trans ? 1 : channels_trans / 2;
  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] {
        const scalar_t *top_diff = out_grad.data<scalar_t>();
        const scalar_t *bottom_data = data.data<scalar_t>();
        const scalar_t *bottom_rois = bbox.data<scalar_t>();
        const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
        scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
        scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
        const scalar_t *top_count_data = top_count.data<scalar_t>();

        DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
            count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
            pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
            bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
            group_size, part_size, num_classes, channels_each_class);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in DeformablePSROIPoolBackwardAcc: %s\n", cudaGetErrorString(err));
  }
}
--------------------------------------------------------------------------------