├── cal_rescall
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── script.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   └── rrc_evaluation_funcs.cpython-37.pyc
│   └── script.py
├── models
│   ├── dcn
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── deform_conv.cpython-37.pyc
│   │   │   │   └── deform_pool.cpython-37.pyc
│   │   │   ├── deform_pool.py
│   │   │   └── deform_conv.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── deform_conv.cpython-37.pyc
│   │   │   │   └── deform_pool.cpython-37.pyc
│   │   │   ├── deform_conv.py
│   │   │   └── deform_pool.py
│   │   ├── make.sh
│   │   ├── setup.py
│   │   ├── __init__.py
│   │   └── src
│   │       ├── deform_pool_cuda.cpp
│   │       └── deform_pool_cuda_kernel.cu
│   ├── __pycache__
│   │   ├── DBNet.cpython-37.pyc
│   │   └── __init__.cpython-37.pyc
│   ├── head
│   │   ├── __pycache__
│   │   │   ├── DB_Head.cpython-37.pyc
│   │   │   ├── FPN_Head.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── seg_detector.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── DB_Head.py
│   │   ├── FPN_Head.py
│   │   └── seg_detector.py
│   ├── backbone
│   │   ├── __pycache__
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   └── __init__.cpython-37.pyc
│   │   ├── __init__.py
│   │   └── resnet.py
│   ├── __init__.py
│   └── DBNet.py
├── show
│   ├── 1.jpg
│   └── 2.jpg
├── loss
│   ├── __pycache__
│   │   ├── loss.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── l1_loss.cpython-37.pyc
│   │   ├── dice_loss.cpython-37.pyc
│   │   └── balance_cross_entropy_loss.cpython-37.pyc
│   ├── __init__.py
│   ├── l1_loss.py
│   ├── loss.py
│   ├── balance_cross_entropy_loss.py
│   └── dice_loss.py
├── utils
│   ├── __pycache__
│   │   ├── tools.cpython-37.pyc
│   │   ├── Logger.cpython-37.pyc
│   │   ├── metrics.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── fuse_model.cpython-37.pyc
│   │   ├── model_eval.cpython-37.pyc
│   │   ├── DB_postprocesss.cpython-37.pyc
│   │   └── set_optimizer.cpython-37.pyc
│   ├── __init__.py
│   ├── set_optimizer.py
│   ├── fuse_model.py
│   ├── metrics.py
│   ├── Logger.py
│   ├── tools.py
│   ├── model_eval.py
│   └── DB_postprocesss.py
├── dataloader
│   ├── __pycache__
│   │   ├── MakeSegMap.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dataload.cpython-37.pyc
│   │   ├── MakeBorderMap.cpython-37.pyc
│   │   └── random_thansform.cpython-37.pyc
│   ├── __init__.py
│   ├── MakeSegMap.py
│   ├── dataload.py
│   ├── MakeBorderMap.py
│   └── random_thansform.py
├── pruned
│   ├── __init__.py
│   ├── prune_inference.py
│   ├── get_pruned_model.py
│   ├── train_fintune.py
│   └── prune.py
├── requirement.txt
├── config.yaml
├── README.md
├── README_en.md
├── inference.py
└── train.py
/cal_rescall/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/dcn/functions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/dcn/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/dcn/make.sh:
--------------------------------------------------------------------------------
1 | python3 setup.py build_ext --inplace
--------------------------------------------------------------------------------
/show/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BADBADBADBOY/DBnet-lite.pytorch/HEAD/show/1.jpg
--------------------------------------------------------------------------------
/show/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BADBADBADBOY/DBnet-lite.pytorch/HEAD/show/2.jpg
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: __init__.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: __init__.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
--------------------------------------------------------------------------------
/pruned/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: __init__.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: __init__.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: __init__.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
--------------------------------------------------------------------------------
/models/head/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: __init__.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
--------------------------------------------------------------------------------
/models/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: __init__.py
7 | @time: 2020/6/29 21:39
8 |
9 | """
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | pyyaml
2 | tqdm
3 | tensorboardX
4 | opencv-python==4.1.2.30
5 | anyconfig
6 | munch
7 | scipy
8 | sortedcontainers
9 | shapely
10 | pyclipper
11 | gevent
12 | gevent-websocket
13 | flask
14 | editdistance
15 | scikit-image
16 | imgaug
17 | tabulate
--------------------------------------------------------------------------------
/models/dcn/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='deform_conv',
6 | ext_modules=[
7 | CUDAExtension('deform_conv_cuda', [
8 | 'src/deform_conv_cuda.cpp',
9 | 'src/deform_conv_cuda_kernel.cu',
10 | ]),
11 | CUDAExtension('deform_pool_cuda', [
12 | 'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
13 | ]),
14 | ],
15 | cmdclass={'build_ext': BuildExtension})
16 |
--------------------------------------------------------------------------------
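If the ahead-of-time build in make.sh fails (usually a CUDA/PyTorch version mismatch), the extension can also be JIT-compiled. A minimal sketch, assuming it is run from models/dcn/ with an nvcc that matches the installed torch build; the repo itself expects the in-place build, so this is illustrative only:

```python
# JIT alternative to `python3 setup.py build_ext --inplace`
from torch.utils.cpp_extension import load

deform_pool_cuda = load(
    name='deform_pool_cuda',
    sources=['src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'],
    verbose=True)
```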
/models/dcn/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions.deform_conv import deform_conv, modulated_deform_conv
2 | from .functions.deform_pool import deform_roi_pooling
3 | from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
4 | DeformConvPack, ModulatedDeformConvPack)
5 | from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
6 | ModulatedDeformRoIPoolingPack)
7 |
8 | __all__ = [
9 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
10 | 'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
11 | 'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
12 | 'deform_roi_pooling'
13 | ]
14 |
--------------------------------------------------------------------------------
/utils/set_optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: set_optimizer.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 |
11 | def lr_poly(base_lr, epoch, max_epoch=1200, factor=0.9):
12 |     return base_lr * ((1 - float(epoch) / max_epoch) ** factor)
13 |
14 | def adjust_learning_rate_poly(base_lr, optimizer, epoch, max_epoch=1200, factor=0.9):
15 |     lr = lr_poly(base_lr, epoch, max_epoch, factor)
16 |     optimizer.param_groups[0]['lr'] = lr
17 |
18 |
19 | def adjust_learning_rate(config, optimizer, epoch, gama=0.1):
20 |     if epoch in config['train']['schedule']:
21 |         config['train']['base_lr'] = config['train']['base_lr'] * gama
22 |         for param_group in optimizer.param_groups:
23 |             param_group['lr'] = config['train']['base_lr']
--------------------------------------------------------------------------------
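For reference, the poly schedule above can be exercised standalone. A minimal sketch; the linear model, base_lr and epoch count are placeholders, not values taken from train.py:

```python
import torch
from utils.set_optimizer import adjust_learning_rate_poly  # run from the repo root

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.002)
for epoch in range(1200):
    adjust_learning_rate_poly(0.002, optimizer, epoch, max_epoch=1200, factor=0.9)
    # e.g. at epoch 600: lr = 0.002 * (1 - 600/1200) ** 0.9 ≈ 1.07e-3
```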
/models/DBNet.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: DBNet.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import torch.nn as nn
11 | from models.head.seg_detector import SegDetector
12 | from models.backbone.resnet import resnet18, resnet50, deformable_resnet50, deformable_resnet18
13 |
14 | class DBNet(nn.Module):
15 |     def __init__(self, config, is_train=True):
16 |         super(DBNet, self).__init__()
17 |         self.backbone = globals().get(config['train']['backbone'])(pretrained=config['train']['pretrained'])
18 |         if is_train is False:
19 |             config['train']['adaptive'] = config['test']['adaptive']
20 |         self.decode = SegDetector(headname=config['train']['HeadName'],
21 |                                   in_channels=config['train']['in_channels'],
22 |                                   inner_channels=config['train']['inner_channels'],
23 |                                   k=config['train']['k'],
24 |                                   bias=False, adaptive=config['train']['adaptive'], smooth=False, serial=False)
25 |     def forward(self, x):
26 |         x = self.backbone(x)
27 |         out = self.decode(x)
28 |         return out
29 |
30 |
31 |
--------------------------------------------------------------------------------
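A minimal construction sketch, assuming it is run from the repo root. The config stub mirrors the train section of config.yaml (pretrained=False avoids the weight download); it is not the loading code from train.py:

```python
import torch
from models.DBNet import DBNet

config = {'train': {'backbone': 'resnet18', 'pretrained': False,
                    'HeadName': 'DB', 'in_channels': [64, 128, 256, 512],
                    'inner_channels': 256, 'k': 50, 'adaptive': True}}
net = DBNet(config, is_train=True)
out = net(torch.randn(1, 3, 640, 640))  # the map dict consumed by loss/loss.py
```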
/loss/l1_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MaskL1Loss(nn.Module):
6 | def __init__(self):
7 | super(MaskL1Loss, self).__init__()
8 |
9 | def forward(self, pred: torch.Tensor, gt, mask):
10 | mask_sum = mask.sum()
11 | if mask_sum.item() == 0:
12 | return mask_sum, dict(l1_loss=mask_sum)
13 | else:
14 | loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sum
15 | return loss, dict(l1_loss=loss)
16 |
17 |
18 | class BalanceL1Loss(nn.Module):
19 | def __init__(self, negative_ratio=3.):
20 | super(BalanceL1Loss, self).__init__()
21 | self.negative_ratio = negative_ratio
22 |
23 | def forward(self, pred: torch.Tensor, gt, mask):
24 | '''
25 | Args:
26 | pred: (N, 1, H, W).
27 | gt: (N, H, W).
28 | mask: (N, H, W).
29 | '''
30 | loss = torch.abs(pred[:, 0] - gt)
31 | positive = loss * mask
32 | negative = loss * (1 - mask)
33 | positive_count = int(mask.sum())
34 | negative_count = min(
35 | int((1 - mask).sum()),
36 | int(positive_count * self.negative_ratio))
37 | negative_loss, _ = torch.topk(negative.view(-1), negative_count)
38 | negative_loss = negative_loss.sum() / negative_count
39 | positive_loss = positive.sum() / positive_count
40 | return positive_loss + negative_loss,\
41 | dict(l1_loss=positive_loss, nge_l1_loss=negative_loss)
42 |
--------------------------------------------------------------------------------
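A quick numeric check of MaskL1Loss, assuming it is run from the repo root: with a constant gap of 0.25 and an all-ones mask, the loss is exactly 0.25:

```python
import torch
from loss.l1_loss import MaskL1Loss

l1 = MaskL1Loss()
pred = torch.full((1, 1, 2, 2), 0.75)  # (N, 1, H, W)
gt = torch.full((1, 2, 2), 0.5)        # (N, H, W)
mask = torch.ones(1, 2, 2)             # every pixel counts
loss, metric = l1(pred, gt, mask)
print(loss.item())  # 0.25 == (|pred - gt| * mask).sum() / mask.sum()
```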
/loss/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: loss.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import torch.nn as nn
11 | from loss.dice_loss import DiceLoss,dice_loss
12 | from loss.l1_loss import MaskL1Loss
13 | from loss.balance_cross_entropy_loss import BalanceCrossEntropyLoss
14 |
15 | class L1BalanceCELoss(nn.Module):
16 | '''
17 | Balanced CrossEntropy Loss on `binary`,
18 | MaskL1Loss on `thresh`,
19 | DiceLoss on `thresh_binary`.
20 | Note: The meaning of inputs can be figured out in `SegDetectorLossBuilder`.
21 | '''
22 |
23 | def __init__(self, eps=1e-6, l1_scale=10, bce_scale=1):
24 | super(L1BalanceCELoss, self).__init__()
25 | # self.dice_loss = DiceLoss(eps=eps)
26 | self.dice_loss = dice_loss
27 | self.l1_loss = MaskL1Loss()
28 | self.bce_loss = BalanceCrossEntropyLoss()
29 | self.l1_scale = l1_scale
30 | self.bce_scale = bce_scale
31 |
32 | def forward(self, pred, batch):
33 | bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])
34 | metrics = dict(bce_loss=bce_loss)
35 | if 'thresh' in pred:
36 | l1_loss, l1_metric = self.l1_loss(pred['thresh'], batch['thresh_map'], batch['thresh_mask'])
37 | dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], batch['mask'])
38 | metrics['thresh_loss'] = dice_loss
39 | loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale
40 | metrics.update(**l1_metric)
41 | else:
42 | loss = bce_loss
43 | return loss, metrics
--------------------------------------------------------------------------------
/utils/fuse_model.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: fuse_model.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import torch
11 | import torch.nn as nn
12 | import time
13 | import sys
14 | import numpy as np
15 | import torchvision
16 | import torch.nn.functional as F
17 |
18 |
19 | class DummyModule(nn.Module):
20 | def __init__(self):
21 | super(DummyModule, self).__init__()
22 |
23 | def forward(self, x):
24 | return x
25 |
26 | def fuse(conv, bn):
27 | # *******************conv参数********************
28 | w = conv.weight
29 |
30 | # ********************BN参数*********************
31 | mean = bn.running_mean
32 | var_sqrt = torch.sqrt(bn.running_var + bn.eps)
33 | gamma = bn.weight
34 | beta = bn.bias
35 |
36 | if conv.bias is not None:
37 | b = conv.bias
38 | else:
39 | b = mean.new_zeros(mean.shape)
40 |
41 | w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
42 | b = (b - mean)/var_sqrt * gamma + beta
43 |
44 | fused_conv = nn.Conv2d(conv.in_channels,
45 | conv.out_channels,
46 | conv.kernel_size,
47 | conv.stride,
48 | conv.padding,
49 | bias=True)
50 | fused_conv.weight = nn.Parameter(w)
51 | fused_conv.bias = nn.Parameter(b)
52 | return fused_conv
53 |
54 | def fuse_module(m):
55 | children = list(m.named_children())
56 | c = None
57 | cn = None
58 | for name, child in children:
59 | if isinstance(child, nn.BatchNorm2d) and c is not None:
60 | bc = fuse(c, child)
61 | m._modules[cn] = bc
62 | m._modules[name] = DummyModule()
63 | c = None
64 | elif isinstance(child, nn.Conv2d):
65 | c = child
66 | cn = name
67 | else:
68 | fuse_module(child)
69 |
--------------------------------------------------------------------------------
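A sanity-check sketch for fuse, assuming it is run from the repo root: in eval mode the fused convolution should reproduce conv followed by bn up to float error. Note that fuse only carries over in/out channels, kernel size, stride and padding, so grouped or dilated convolutions are out of its scope:

```python
import torch
import torch.nn as nn
from utils.fuse_model import fuse

conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
conv.eval(); bn.eval()  # fusion is only valid with frozen running statistics

x = torch.randn(1, 3, 16, 16)
fused = fuse(conv, bn)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```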
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Adapted from score written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 |
4 | import numpy as np
5 |
6 | class runningScore(object):
7 |
8 | def __init__(self, n_classes):
9 | self.n_classes = n_classes
10 | self.confusion_matrix = np.zeros((n_classes, n_classes))
11 |
12 | def _fast_hist(self, label_true, label_pred, n_class):
13 | mask = (label_true >= 0) & (label_true < n_class)
14 |
15 | if np.sum((label_pred[mask] < 0)) > 0:
16 | print (label_pred[label_pred < 0])
17 | hist = np.bincount(
18 | n_class * label_true[mask].astype(int) +
19 | label_pred[mask], minlength=n_class**2).reshape(n_class, n_class)
20 | return hist
21 |
22 | def update(self, label_trues, label_preds):
23 | # print label_trues.dtype, label_preds.dtype
24 | for lt, lp in zip(label_trues, label_preds):
25 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
26 |
27 | def get_scores(self):
28 | """Returns accuracy score evaluation result.
29 | - overall accuracy
30 | - mean accuracy
31 | - mean IU
32 | - fwavacc
33 | """
34 | hist = self.confusion_matrix
35 | acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
36 | acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
37 | acc_cls = np.nanmean(acc_cls)
38 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001)
39 | mean_iu = np.nanmean(iu)
40 | freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
41 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
42 | cls_iu = dict(zip(range(self.n_classes), iu))
43 |
44 | return {'Overall Acc': acc,
45 | 'Mean Acc': acc_cls,
46 | 'FreqW Acc': fwavacc,
47 | 'Mean IoU': mean_iu,}, cls_iu
48 |
49 | def reset(self):
50 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
--------------------------------------------------------------------------------
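A usage sketch for runningScore with two 2x2 toy label maps (run from the repo root); three of the four pixels agree, so overall accuracy comes out near 0.75:

```python
import numpy as np
from utils.metrics import runningScore

score = runningScore(n_classes=2)
gt = np.array([[0, 0], [1, 1]])
pred = np.array([[0, 1], [1, 1]])
score.update([gt], [pred])  # lists of per-image label maps
metrics, per_class_iou = score.get_scores()
print(metrics['Overall Acc'], per_class_iou)
```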
/loss/balance_cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BalanceCrossEntropyLoss(nn.Module):
6 | '''
7 | Balanced cross entropy loss.
8 | Shape:
9 | - Input: :math:`(N, 1, H, W)`
10 | - GT: :math:`(N, 1, H, W)`, same shape as the input
11 | - Mask: :math:`(N, H, W)`, same spatial shape as the input
12 | - Output: scalar.
13 |
14 | Examples::
15 |
16 | >>> m = nn.Sigmoid()
17 | >>> loss = nn.BCELoss()
18 | >>> input = torch.randn(3, requires_grad=True)
19 | >>> target = torch.empty(3).random_(2)
20 | >>> output = loss(m(input), target)
21 | >>> output.backward()
22 | '''
23 |
24 | def __init__(self, negative_ratio=3.0, eps=1e-6):
25 | super(BalanceCrossEntropyLoss, self).__init__()
26 | self.negative_ratio = negative_ratio
27 | self.eps = eps
28 |
29 | def forward(self,
30 | pred: torch.Tensor,
31 | gt: torch.Tensor,
32 | mask: torch.Tensor,
33 | return_origin=False):
34 | '''
35 | Args:
36 | pred: shape :math:`(N, 1, H, W)`, the prediction of network
37 | gt: shape :math:`(N, 1, H, W)`, the target
38 | mask: shape :math:`(N, H, W)`, the mask indicates positive regions
39 | '''
40 | positive = (gt * mask).byte()
41 | negative = ((1 - gt) * mask).byte()
42 | positive_count = int(positive.float().sum())
43 | negative_count = min(int(negative.float().sum()),
44 | int(positive_count * self.negative_ratio))
45 | loss = nn.functional.binary_cross_entropy(
46 | pred, gt, reduction='none')[:, 0, :, :]
47 | positive_loss = loss * positive.float()
48 | negative_loss = loss * negative.float()
49 | negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
50 |
51 | balance_loss = (positive_loss.sum() + negative_loss.sum()) /\
52 | (positive_count + negative_count + self.eps)
53 |
54 | if return_origin:
55 | return balance_loss, loss
56 | return balance_loss
57 |
--------------------------------------------------------------------------------
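The heart of the class is online hard negative mining: keep every positive-pixel loss, but only the top-k negatives with k = negative_ratio * (number of positives). A self-contained sketch of just that selection step, on toy tensors rather than the repo's data:

```python
import torch

per_pixel_loss = torch.rand(1, 64, 64)               # stand-in for the BCE map
gt = (torch.rand(1, 64, 64) > 0.9).float()           # sparse positive pixels
pos_count = int(gt.sum())
neg_count = min(int((1 - gt).sum()), 3 * pos_count)  # negative_ratio = 3
top_neg, _ = torch.topk((per_pixel_loss * (1 - gt)).view(-1), neg_count)
balanced = ((per_pixel_loss * gt).sum() + top_neg.sum()) / (pos_count + neg_count + 1e-6)
print(balanced.item())
```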
/config.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | gpu_id: '0'
3 | backbone: 'resnet18'
4 | pretrained: True
5 | HeadName: 'DB'
6 | in_channels: [64, 128, 256, 512] #[256, 512, 1024, 2048]
7 | inner_channels: 256
8 | k: 50
9 | adaptive: True
10 | start_val_epoch: 1000
11 | n_epoch: 1200
12 | batch_size: 16
13 | use_sr: True
14 | sr_lr: 0.00001
15 | base_lr: 0.002
16 | num_workers: 0
17 | show_step: 5
18 | print_format: 'linux' # linux or windows
19 | restore: True
20 | resume: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
21 | checkpoints: './checkpoints'
22 | is_icdar2015: True
23 | is_transform: True
24 | is_show: False
25 | train_img_format: '.jpg'
26 | val_img_format: '.jpg'
27 | train_img_dir: '/home/aistudio/work/data/icdar/train_img/'
28 | train_gt_dir: '/home/aistudio/work/data/icdar/train_gt/'
29 | val_img_dir: '/home/aistudio/work/data/icdar/test_img/'
30 | val_gt_dir: '/home/aistudio/work/data/icdar/test_gt/'
31 | radom_angle: [-10, 10]
32 | output_path: './outputs_val'
33 |   decay_method: 'e_decay' # e_decay: exponential decay, s_decay: decay at the epochs in schedule
34 | schedule: [500,800,1000]
35 | gama: 0.1
36 | test:
37 | gpu_id: '0'
38 | pretrained: False
39 | merge_conv_bn: False
40 | adaptive: False
41 | short_side: 736
42 | thresh: 0.5
43 | box_thresh: 0.6
44 | unclip_ratio: 2
45 | min_size: 3
46 | max_candidates: 1000
47 | is_poly: False
48 | is_icdar2015: True
49 | test_img_format: '.jpg'
50 | test_img_dir: '/home/aistudio/work/data/icdar/test_img/'
51 | test_gt_dir: '/home/aistudio/work/data/icdar/test_gt/'
52 | checkpoints: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
53 | out_dir: './outputs_test'
54 |
55 | pruned:
56 | gpu_id: '0'
57 | scale: [73, 77, 81, 85]
58 | base_num: 8
59 | cut_percent: 0.8
60 | pruned_checkpoints: './pruned/checkpoint/pruned_dict.pth.tar'
61 | checkpoints_dict: './pruned/checkpoint/pruned_dict.dict'
62 | save_checkpoints: './pruned/checkpoint'
63 | checkpoints: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
64 | finetune_lr: 0.0005
65 | resume: './checkpoints/DB_resnet18_bs_16_ep_1200/DB.pth.tar'
66 | restore: True
67 | n_epoch: 100
68 | start_val_epoch: 40
69 |
--------------------------------------------------------------------------------
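The config is plain YAML, so any loader works. A minimal sketch with pyyaml (the exact loading code in train.py is not shown in this dump):

```python
import yaml

with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)
print(config['train']['backbone'], config['train']['base_lr'])  # resnet18 0.002
```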
/README.md:
--------------------------------------------------------------------------------
1 | [English](README_en.md) | 简体中文
2 |
3 | # DBNet-lite-pytorch
4 |
5 |
6 | ## Future updates of this project will land in [pytorchOCR](https://github.com/BADBADBADBOY/pytorchOCR), where I have consolidated my previous projects
7 |
8 | ## Setup
9 |
10 | ```
11 | pip install -r requirement.txt
12 | cd models/dcn/
13 | sh make.sh
14 | ```
15 |
16 | ***
17 |
18 | ## Horizontal or slanted text format
19 |
20 | Follow the icdar2015 format: x1,y1,x2,y2,x3,y3,x4,y4,label
21 | ```
22 |
23 | image
24 | │ 1.jpg
25 | │ 2.jpg
26 | │ ...
27 | label
28 | │ gt_1.txt
29 | │ gt_2.txt
30 | | ...
31 | ```
32 | ***
33 |
34 | ## Curved text format
35 |
36 | Label format: x1,y1,x2,y2,x3,y3,x4,y4 ...xn,yn,label
37 |
38 | Each polygon consists of n points, and n may vary from line to line
39 |
40 | ```
41 | image
42 | │ 1.jpg
43 | │ 2.jpg
44 | │ ...
45 | label
46 | │ gt_1.txt
47 | │ gt_2.txt
48 | | ...
49 | ```
50 |
51 | ***
52 |
53 |
54 | ## Training
55 |
56 | Configure the train section of config.yaml in the root directory (e.g. the image paths). If your image and gt files share the same base name, you can set is_icdar2015=False.
57 | If you do not want to run validation, set start_val_epoch greater than n_epoch; if validation is enabled, the model with the highest hmean is saved as the best checkpoint.
58 |
59 | ```
60 | python3 train.py
61 | ```
62 | ***
63 |
64 |
65 | ## Testing
66 |
67 | For testing, configure the test section of config.yaml: set is_poly=True for curved text and False for everything else
68 |
69 | ```
70 | python3 inference.py
71 | ```
72 | ***
73 | ## Model compression via channel pruning
74 |
75 | ### Training
76 | 1. Run sparsity training first: edit config.yaml to set use_sr to True and pick sr_lr. The larger this value, the more gets pruned; note that training may not converge if it is set too large.
77 |
78 | ```
79 | python3 train.py
80 | ```
81 | 2. Prune the model
82 | Set the parameters of the pruned section in config.yaml, then run
83 | ```
84 | python3 ./pruned/prune.py
85 | ```
86 | 3. Finetune the pruned model
87 | Accuracy recovers quickly at this stage; 50-100 epochs is usually enough, but experiment for your own data
88 | ```
89 | python3 ./pruned/train_fintune.py
90 | ```
91 |
92 | ### Testing
93 |
94 | ```
95 | python3 ./pruned/prune_inference.py
96 | ```
97 |
98 | ***
99 |
100 |
101 | ## Test results on icdar2015
102 |
103 | |Method| head|extra data|prune ratio|model size(M)|precision(%)| recall(%) | hmean(%)|model_file|
104 | | - | - | - | - | - | - |- | - |- |
105 | | Resnet18|FPN|no|0|62.6|86.11| 76.45| 80.99|[baiduyun](https://pan.baidu.com/s/1wmbGMoluWlZ97LCqOnwjOg) (extract code: p0bk)|
106 | | Resnet18|DB|no|0.8|20.1|85.55| 76.40| 80.72||
107 | ***
108 | ## Test result images on icdar2015
109 |
110 | ![](show/1.jpg)
111 | ![](show/2.jpg)
112 | ***
113 |
114 | #### Planned work
115 | - [x] Restructure the author's code for easier reading and debugging
116 | - [x] Show some training results
117 | - [ ] Add lightweight backbones to shrink the model
118 | - [ ] Compress the DB model by channel pruning with almost no accuracy loss
119 | - [ ] Further improve the pruned model with knowledge distillation
120 |
121 |
122 |
123 |
124 | # References
125 |
126 | 1. https://github.com/whai362/PSENet
127 | 2. https://github.com/MhLiao/DB
128 | 3. https://github.com/Jzz24/pytorch_quantization
129 |
130 |
131 |
--------------------------------------------------------------------------------
/models/dcn/functions/deform_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 | from .. import deform_pool_cuda
5 |
6 |
7 | class DeformRoIPoolingFunction(Function):
8 |
9 | @staticmethod
10 | def forward(ctx,
11 | data,
12 | rois,
13 | offset,
14 | spatial_scale,
15 | out_size,
16 | out_channels,
17 | no_trans,
18 | group_size=1,
19 | part_size=None,
20 | sample_per_part=4,
21 | trans_std=.0):
22 | ctx.spatial_scale = spatial_scale
23 | ctx.out_size = out_size
24 | ctx.out_channels = out_channels
25 | ctx.no_trans = no_trans
26 | ctx.group_size = group_size
27 | ctx.part_size = out_size if part_size is None else part_size
28 | ctx.sample_per_part = sample_per_part
29 | ctx.trans_std = trans_std
30 |
31 | assert 0.0 <= ctx.trans_std <= 1.0
32 | if not data.is_cuda:
33 | raise NotImplementedError
34 |
35 | n = rois.shape[0]
36 | output = data.new_empty(n, out_channels, out_size, out_size)
37 | output_count = data.new_empty(n, out_channels, out_size, out_size)
38 | deform_pool_cuda.deform_psroi_pooling_cuda_forward(
39 | data, rois, offset, output, output_count, ctx.no_trans,
40 | ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
41 | ctx.part_size, ctx.sample_per_part, ctx.trans_std)
42 |
43 | if data.requires_grad or rois.requires_grad or offset.requires_grad:
44 | ctx.save_for_backward(data, rois, offset)
45 | ctx.output_count = output_count
46 |
47 | return output
48 |
49 | @staticmethod
50 | def backward(ctx, grad_output):
51 | if not grad_output.is_cuda:
52 | raise NotImplementedError
53 |
54 | data, rois, offset = ctx.saved_tensors
55 | output_count = ctx.output_count
56 | grad_input = torch.zeros_like(data)
57 | grad_rois = None
58 | grad_offset = torch.zeros_like(offset)
59 |
60 | deform_pool_cuda.deform_psroi_pooling_cuda_backward(
61 | grad_output, data, rois, offset, output_count, grad_input,
62 | grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
63 | ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
64 | ctx.trans_std)
65 | return (grad_input, grad_rois, grad_offset, None, None, None, None,
66 | None, None, None, None)
67 |
68 |
69 | deform_roi_pooling = DeformRoIPoolingFunction.apply
70 |
--------------------------------------------------------------------------------
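A call sketch for deform_roi_pooling, which requires a CUDA device and the built extension. The rois layout (batch index, x1, y1, x2, y2) and the empty offset tensor for the no_trans case follow the mmdetection-style code above; treat the exact values as illustrative:

```python
import torch
from models.dcn.functions.deform_pool import deform_roi_pooling

data = torch.randn(1, 16, 32, 32, device='cuda')
rois = torch.tensor([[0., 4., 4., 20., 20.]], device='cuda')  # (batch_idx, x1, y1, x2, y2)
offset = data.new_empty(0)  # no learned offsets when no_trans is True
# spatial_scale=1.0, out_size=7, out_channels=16, no_trans=True
out = deform_roi_pooling(data, rois, offset, 1.0, 7, 16, True)
print(out.shape)  # (1, 16, 7, 7)
```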
/utils/Logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # A simple torch style logger
3 | # (C) Wei YANG 2017
4 | from __future__ import absolute_import
5 |
6 | class Logger(object):
7 | def __init__(self, fpath, title=None, resume=False):
8 | self.file = None
9 | self.resume = resume
10 |         self.title = '' if title is None else title
11 | if fpath is not None:
12 | if resume:
13 | self.file = open(fpath, 'r')
14 | name = self.file.readline()
15 | self.names = name.rstrip().split('\t')
16 | self.numbers = {}
17 | for _, name in enumerate(self.names):
18 | self.numbers[name] = []
19 |
20 | for numbers in self.file:
21 | numbers = numbers.rstrip().split('\t')
22 | for i in range(0, len(numbers)):
23 | self.numbers[self.names[i]].append(numbers[i])
24 | self.file.close()
25 | self.file = open(fpath, 'a')
26 | else:
27 | self.file = open(fpath, 'w')
28 |
29 | def set_names(self, names):
30 | if self.resume:
31 | pass
32 | self.numbers = {}
33 | self.names = names
34 | for _, name in enumerate(self.names):
35 | self.file.write(name)
36 | self.file.write('\t')
37 | self.numbers[name] = []
38 | self.file.write('\n')
39 | self.file.flush()
40 |
41 | def set_split(self, names):
42 | if self.resume:
43 | pass
44 | self.numbers = {}
45 | self.names = names
46 | for _, name in enumerate(self.names):
47 | self.file.write(name)
48 | self.numbers[name] = []
49 | self.file.write('\n')
50 | self.file.flush()
51 |
52 | def append(self, numbers):
53 | assert len(self.names) == len(numbers), 'Numbers do not match names'
54 | for index, num in enumerate(numbers):
55 | self.file.write("{0:.6f}".format(num))
56 | self.file.write('\t')
57 | self.numbers[self.names[index]].append(num)
58 | self.file.write('\n')
59 | self.file.flush()
60 |
61 | def close(self):
62 | if self.file is not None:
63 | self.file.close()
64 |
65 | # if __name__ == '__main__':
66 | # import numpy as np
67 | # logger = Logger('test.txt')
68 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
69 | #
70 | # length = 100
71 | # t = np.arange(length)
72 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
73 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
74 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
75 | #
76 | # for i in range(0, length):
77 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
78 | # logger.set_split(['----------','------------','--------'])
79 |
80 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | English | [简体中文](README.md)
2 |
3 | # DBNet-lite-pytorch
4 |
5 | ## setup
6 |
7 | ```
8 | pip install -r requirement.txt
9 | cd models/dcn/
10 | sh make.sh
11 | ```
12 |
13 | ***
14 |
15 | ## data format for Horizontal or slanted text
16 |
17 | Follow the icdar2015 dataset format: x1,y1,x2,y2,x3,y3,x4,y4,label
18 | ```
19 |
20 | image
21 | │ 1.jpg
22 | │ 2.jpg
23 | │ ...
24 | label
25 | │ gt_1.txt
26 | │ gt_2.txt
27 | | ...
28 | ```
29 | ***
30 |
31 | ## data format for curved text
32 |
33 | dataset format: x1,y1,x2,y2,x3,y3,x4,y4 ...xn,yn,label
34 |
35 | The number of points n may vary from line to line; the points are arranged clockwise or counterclockwise
36 |
37 | ```
38 | image
39 | │ 1.jpg
40 | │ 2.jpg
41 | │ ...
42 | label
43 | │ gt_1.txt
44 | │ gt_2.txt
45 | | ...
46 | ```
47 |
48 | ***
49 |
50 |
51 | ## train
52 |
53 | Configure the train section of config.yaml in the root directory
54 |
55 | ```
56 | python3 train.py
57 | ```
58 | ***
59 |
60 |
61 | ## test
62 |
63 | Set is_poly=True in config.yaml for curved text; otherwise set is_poly=False
64 |
65 | ```
66 | python3 inference.py
67 | ```
68 | ***
69 | ## Channel pruning for model compression
70 |
71 | ### Training section
72 | 1. Run sparsity training first: modify config.yaml to set use_sr to True, and choose sr_lr. The larger this setting, the more gets pruned;
73 | note that training may not converge if it is set too large.
74 |
75 | ```
76 | python3 train.py
77 | ```
78 | 2. Prune the model
79 | Set the parameters of the pruned section in config.yaml, then run
80 | ```
81 | python3 ./pruned/prune.py
82 | ```
83 | 3. Finetune the pruned model
84 | Accuracy recovers quickly here. Training for 50-100 epochs is generally enough; experiment for your own data
85 | ```
86 | python3 ./pruned/train_fintune.py
87 | ```
88 |
89 | ### Test section
90 |
91 | ```
92 | python3 ./pruned/prune_inference.py
93 | ```
94 |
95 | ## Performance on icdar2015
96 |
97 | |Method| head|extra data|prune ratio|model size(M)|precision(%)| recall(%) | hmean(%)|model_file|
98 | | - | - | - | - | - | - |- | - |- |
99 | | Resnet18|FPN|no|0|62.6|86.11| 76.45| 80.99|[baiduyun](https://pan.baidu.com/s/1wmbGMoluWlZ97LCqOnwjOg) (extract code: p0bk)|
100 | | Resnet18|DB|no|0.8|20.1|85.55| 76.40| 80.72||
101 | | mobilev3|DB|no|0|2.5|85.55| 76.40| 74.71||
102 |
103 | ***
104 | ## Some results
105 |
106 | ![](show/1.jpg)
107 | ![](show/2.jpg)
108 | ***
109 |
110 | #### ToDoList
111 | - [x] transform the DB code format from MhLiao/DB
112 | - [x] add some performance numbers
113 | - [ ] add light backbones
114 | - [ ] prune the big model by channel pruning
115 | - [ ] model distillation
116 |
117 |
118 |
119 |
120 | # reference
121 |
122 | 1. https://github.com/whai362/PSENet
123 | 2. https://github.com/MhLiao/DB
124 | 3. https://github.com/Jzz24/pytorch_quantization
125 |
126 |
127 |
--------------------------------------------------------------------------------
/models/head/DB_Head.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: DB_Head.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import torch
11 | import torch.nn as nn
12 |
13 | class ConvBnRelu(nn.Module):
14 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding,bias):
15 | super(ConvBnRelu,self).__init__()
16 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,bias=bias) # Reduce channels
17 | self.bn = nn.BatchNorm2d(out_channels)
18 | self.relu = nn.ReLU(inplace=True)
19 | def forward(self, x):
20 | x = self.conv(x)
21 | x = self.bn(x)
22 | x = self.relu(x)
23 | return x
24 |
25 | class DB_Head(nn.Module):
26 | def __init__(self,in_channels,inner_channels,bias=False):
27 | super(DB_Head,self).__init__()
28 |
29 | self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
30 | self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
31 | self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
32 |
33 | self.in5 = ConvBnRelu(in_channels[-1], inner_channels, 1,stride=1, padding=0, bias=bias)
34 | self.in4 = ConvBnRelu(in_channels[-2], inner_channels, 1,stride=1, padding=0, bias=bias)
35 | self.in3 = ConvBnRelu(in_channels[-3], inner_channels, 1, stride=1, padding=0,bias=bias)
36 | self.in2 = ConvBnRelu(in_channels[-4], inner_channels, 1,stride=1, padding=0, bias=bias)
37 |
38 | self.out5 = nn.Sequential(
39 | ConvBnRelu(inner_channels, inner_channels //4, 3,stride=1, padding=1, bias=bias),
40 | nn.Upsample(scale_factor=8, mode='nearest'))
41 | self.out4 = nn.Sequential(
42 | ConvBnRelu(inner_channels, inner_channels //4, 3, stride=1,padding=1, bias=bias),
43 | nn.Upsample(scale_factor=4, mode='nearest'))
44 | self.out3 = nn.Sequential(
45 | ConvBnRelu(inner_channels, inner_channels //4, 3, stride=1,padding=1, bias=bias),
46 | nn.Upsample(scale_factor=2, mode='nearest'))
47 | self.out2 = ConvBnRelu(inner_channels, inner_channels//4, 3, stride=1,padding=1, bias=bias)
48 |
49 | for m in self.modules():
50 | if isinstance(m, nn.Conv2d):
51 | nn.init.kaiming_normal_(m.weight.data)
52 | elif isinstance(m, nn.BatchNorm2d):
53 | m.weight.data.fill_(1.)
54 | m.bias.data.fill_(1e-4)
55 |
56 | def forward(self, x):
57 |
58 | c2, c3, c4, c5 = x
59 | in5 = self.in5(c5)
60 | in4 = self.in4(c4)
61 | in3 = self.in3(c3)
62 | in2 = self.in2(c2)
63 |
64 | out4 = self.up5(in5) + in4 # 1/16
65 | out3 = self.up4(out4) + in3 # 1/8
66 | out2 = self.up3(out3) + in2 # 1/4
67 |
68 | p5 = self.out5(in5)
69 | p4 = self.out4(out4)
70 | p3 = self.out3(out3)
71 | p2 = self.out2(out2)
72 | fuse = torch.cat((p5, p4, p3, p2), 1)
73 | return fuse
--------------------------------------------------------------------------------
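A shape-check sketch, assuming it is run from the repo root: each pyramid level is reduced to inner_channels // 4 channels and brought to 1/4 resolution, so the fused map has inner_channels channels at the c2 scale:

```python
import torch
from models.head.DB_Head import DB_Head

head = DB_Head(in_channels=[64, 128, 256, 512], inner_channels=256)
c2 = torch.randn(1, 64, 160, 160)   # strides 4, 8, 16, 32 of a 640x640 input
c3 = torch.randn(1, 128, 80, 80)
c4 = torch.randn(1, 256, 40, 40)
c5 = torch.randn(1, 512, 20, 20)
fuse = head([c2, c3, c4, c5])
print(fuse.shape)  # torch.Size([1, 256, 160, 160])
```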
/utils/tools.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: tools.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 |
11 | import copy
12 | import math
13 | import cv2
14 | import os
15 | import torch
16 | import numpy as np
17 | from tabulate import tabulate
18 |
19 | def judgePoint(point1,point2,center_point):
20 |     if(point1[0]<=0 and point2[0]>0):
21 |         return True
22 |     if(point1[0]==0 and point2[0]==0):
23 |         return point1[1]<point2[1]
24 |     det = (point1[0]-center_point[0])*(point2[1]-center_point[1])-(point2[0]-center_point[0])*(point1[1]-center_point[1])
25 |     if(det>0):
26 |         return True
27 |     if(det<0):
28 |         return False
29 |     d1 = (point1[0]-center_point[0])*(point1[0]-center_point[0])-(point1[1]-center_point[1])*(point1[1]-center_point[1])
30 |     d2 = (point2[0]-center_point[0])*(point2[0]-center_point[0])-(point2[1]-center_point[1])*(point2[1]-center_point[1])
31 |     return d1<d2
74 |     pred_binary[pred_binary > thresh] = 1
75 | pred_binary = pred_binary.astype(np.int32)
76 | gt_binary = gt_binarys.data.cpu().numpy() * training_masks
77 | gt_binary = gt_binary.astype(np.int32)
78 | running_metric_binary.update(gt_binary, pred_binary)
79 | score_binary, _ = running_metric_binary.get_scores()
80 | return score_binary
81 |
82 |
83 | def save_checkpoint(state, checkpoint='checkpoints', filename='DB.pth.tar'):
84 | filepath = os.path.join(checkpoint, filename)
85 | torch.save(state, filepath)
--------------------------------------------------------------------------------
/models/head/FPN_Head.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: FPN_Head.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | class ConvBnRelu(nn.Module):
15 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding):
16 | super(ConvBnRelu,self).__init__()
17 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) # Reduce channels
18 | self.bn = nn.BatchNorm2d(out_channels)
19 | self.relu = nn.ReLU(inplace=True)
20 | def forward(self, x):
21 | x = self.conv(x)
22 | x = self.bn(x)
23 | x = self.relu(x)
24 | return x
25 |
26 |
27 | class FPN_Head(nn.Module):
28 | def __init__(self,in_channels,inner_channels):
29 | super(FPN_Head,self).__init__()
30 | # Top layer
31 | self.toplayer = ConvBnRelu(in_channels[-1], inner_channels, kernel_size=1, stride=1, padding=0) # Reduce channels
32 | # Smooth layers
33 | self.smooth1 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1)
34 | self.smooth2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1)
35 | self.smooth3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1)
36 | # Lateral layers
37 | self.latlayer1 = ConvBnRelu(in_channels[-2], inner_channels, kernel_size=1, stride=1, padding=0)
38 | self.latlayer2 = ConvBnRelu(in_channels[-3], inner_channels, kernel_size=1, stride=1, padding=0)
39 | self.latlayer3 = ConvBnRelu(in_channels[-4], inner_channels, kernel_size=1, stride=1, padding=0)
40 | # Out map
41 | self.conv_out= nn.Conv2d(inner_channels*4, inner_channels, kernel_size=3, stride=1, padding=1)
42 |
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d):
45 | nn.init.kaiming_normal_(m.weight.data)
46 | elif isinstance(m, nn.BatchNorm2d):
47 | m.weight.data.fill_(1.)
48 | m.bias.data.fill_(1e-4)
49 |
50 | def _upsample(self, x, y, scale=1):
51 | _, _, H, W = y.size()
52 | # return F.upsample(x, size=(H // scale, W // scale), mode='nearest')
53 | return F.interpolate(x, size=(H// scale, W// scale), mode='bilinear', align_corners=True)
54 |
55 | def _upsample_add(self, x, y):
56 | _, _, H, W = y.size()
57 | # return F.upsample(x, size=(H, W), mode='nearest') + y
58 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
59 |
60 | def forward(self, x):
61 | c2, c3, c4, c5 = x
62 | ##
63 | p5 = self.toplayer(c5)
64 | c4 = self.latlayer1(c4)
65 | p4 = self._upsample_add(p5, c4)
66 | p4 = self.smooth1(p4)
67 | c3 = self.latlayer2(c3)
68 | p3 = self._upsample_add(p4, c3)
69 | p3 = self.smooth2(p3)
70 | c2 = self.latlayer3(c2)
71 | p2 = self._upsample_add(p3, c2)
72 | p2 = self.smooth3(p2)
73 | ##
74 | p3 = self._upsample(p3, p2)
75 | p4 = self._upsample(p4, p2)
76 | p5 = self._upsample(p5, p2)
77 |
78 | out = torch.cat((p2, p3, p4, p5), 1)
79 | out = self.conv_out(out)
80 | return out
81 |
--------------------------------------------------------------------------------
/dataloader/MakeSegMap.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: MakeSegMap.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 |
11 | import cv2
12 | import pyclipper
13 | from shapely.geometry import Polygon
14 | import numpy as np
15 |
16 | class MakeSegDetectionData():
17 | r'''
18 | Making binary mask from detection data with ICDAR format.
19 | Typically following the process of class `MakeICDARData`.
20 | '''
21 | def __init__(self, min_text_size = 8,shrink_ratio = 0.4,is_training = True):
22 | self.min_text_size =min_text_size
23 | self.shrink_ratio = shrink_ratio
24 | self.is_training = is_training
25 | def process(self, img,polys,dontcare):
26 | '''
27 |         required keys:
28 | image, polygons, ignore_tags, filename
29 | adding keys:
30 | mask
31 | '''
32 | h, w = img.shape[:2]
33 | if self.is_training:
34 | polys, dontcare = self.validate_polygons(
35 | polys, dontcare, h, w)
36 | gt = np.zeros((1, h, w), dtype=np.float32)
37 | mask = np.ones((h, w), dtype=np.float32)
38 | for i in range(len(polys)):
39 | polygon = polys[i]
40 | height = max(polygon[:, 1]) - min(polygon[:, 1])
41 | width = max(polygon[:, 0]) - min(polygon[:, 0])
42 | if dontcare[i] or min(height, width) < self.min_text_size:
43 | cv2.fillPoly(mask, polygon.astype(
44 | np.int32)[np.newaxis, :, :], 0)
45 | dontcare[i] = True
46 | else:
47 | polygon_shape = Polygon(polygon)
48 | distance = polygon_shape.area * \
49 | (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
50 | subject = [tuple(l) for l in polys[i]]
51 | padding = pyclipper.PyclipperOffset()
52 | padding.AddPath(subject, pyclipper.JT_ROUND,
53 | pyclipper.ET_CLOSEDPOLYGON)
54 | shrinked = padding.Execute(-distance)
55 | if shrinked == []:
56 | cv2.fillPoly(mask, polygon.astype(
57 | np.int32)[np.newaxis, :, :], 0)
58 | dontcare[i] = True
59 | continue
60 | shrinked = np.array(shrinked[0]).reshape(-1, 2)
61 | cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
62 | return img,gt,mask
63 |
64 | def validate_polygons(self, polygons, ignore_tags, h, w):
65 | '''
66 | polygons (numpy.array, required): of shape (num_instances, num_points, 2)
67 | '''
68 | if len(polygons) == 0:
69 | return polygons, ignore_tags
70 | assert len(polygons) == len(ignore_tags)
71 | for polygon in polygons:
72 | polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
73 | polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
74 |
75 | for i in range(len(polygons)):
76 | area = self.polygon_area(polygons[i])
77 | if abs(area) < 1:
78 | ignore_tags[i] = True
79 | if area > 0:
80 | polygons[i] = polygons[i][::-1, :]
81 | return polygons, ignore_tags
82 |
83 | def polygon_area(self, polygon):
84 | edge = 0
85 | for i in range(polygon.shape[0]):
86 | next_index = (i + 1) % polygon.shape[0]
87 | edge += (polygon[next_index, 0] - polygon[i, 0]) * (polygon[next_index, 1] - polygon[i, 1])
88 |
89 | return edge / 2.
90 |
--------------------------------------------------------------------------------
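The shrink offset follows the DB paper: D = A * (1 - r^2) / L, where A is the polygon area, L its perimeter and r the shrink_ratio (0.4 above). A worked example for a 100x40 axis-aligned box:

```python
# the pyclipper offset above is Execute(-D)
w, h, r = 100.0, 40.0, 0.4
A, L = w * h, 2 * (w + h)
D = A * (1 - r ** 2) / L
print(D)  # 12.0 -> the box is shrunk inward by 12 pixels on every side
```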
/utils/model_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: model_eval.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import sys
11 | import cv2
12 | import torch
13 | import math
14 | import os
15 | import glob
16 | import argparse
17 | import pyclipper
18 | import torchvision.transforms as transforms
19 | import numpy as np
20 | from shapely.geometry import Polygon
21 | from tqdm import tqdm
22 | import time
23 | from cal_rescall.script import cal_recall_precison_f1
24 | from .DB_postprocesss import DBPostProcess
25 | import copy
26 | import yaml
27 | from utils.tools import *
28 | from PIL import Image
29 |
30 | def val(model,config):
31 | model.eval()
32 | files = glob.glob(os.path.join(config['train']['val_img_dir'],'*'+config['train']['val_img_format']))
33 | if not (os.path.exists(config['train']['output_path'])):
34 | os.mkdir(config['train']['output_path'])
35 |
36 | if not (os.path.exists(os.path.join(config['train']['output_path'],'img_text'))):
37 | os.mkdir(os.path.join(config['train']['output_path'],'img_text'))
38 |
39 | if not (os.path.exists(os.path.join(config['train']['output_path'],'img_result'))):
40 | os.mkdir(os.path.join(config['train']['output_path'],'img_result'))
41 |
42 | bar = tqdm(total=len(files))
43 |
44 |
45 | params = {'thresh':config['test']['thresh'],
46 | 'box_thresh':config['test']['box_thresh'],
47 | 'max_candidates':config['test']['max_candidates'],
48 | 'is_poly':config['test']['is_poly'],
49 | 'unclip_ratio':config['test']['unclip_ratio'],
50 | 'min_size':config['test']['min_size']
51 | }
52 |
53 | dbprocess = DBPostProcess(params)
54 | total_frame = 0.0
55 | total_time = 0.0
56 | for file in files:
57 |
58 | bar.update(1)
59 | img = cv2.imread(file)
60 | img_ori = img.copy()
61 | img_name = file.split('/')[-1].split('.')[0]
62 | img = resize_image(img,config['test']['short_side'])
63 | img = Image.fromarray(img)
64 | img = img.convert('RGB')
65 | img = transforms.ToTensor()(img)
66 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img).unsqueeze(0).cuda()
67 |
68 | torch.cuda.synchronize()
69 | start = time.time()
70 |
71 | with torch.no_grad():
72 | out = model(img)
73 |
74 | scale = (img_ori.shape[1] * 1.0 / out.shape[3], img_ori.shape[0] * 1.0 / out.shape[2])
75 | bbox_batch,score_batch = dbprocess(out.cpu().numpy(),[scale])
76 |
77 | torch.cuda.synchronize()
78 | end = time.time()
79 | total_frame += 1
80 | total_time += (end - start)
81 | sys.stdout.flush()
82 |
83 | for bbox in bbox_batch[0]:
84 |             img_ori = cv2.drawContours(img_ori.copy(), [bbox.reshape(-1, 2).astype(np.int32)], -1, (0, 255, 0), 1)
85 |
86 | if config['test']['is_icdar2015']:
87 | text_file = 'res_' + img_name + '.txt'
88 | else:
89 | text_file = img_name + '.txt'
90 |
91 | with open(os.path.join(config['train']['output_path'],'img_text',text_file),'w+',encoding='utf-8') as fid:
92 | for bbox in bbox_batch[0]:
93 | if(len(bbox)==0):
94 | continue
95 | bbox = bbox.reshape(-1).tolist()
96 | bbox = [str(x) for x in bbox]
97 | bbox = ','.join(bbox)
98 | fid.write(bbox+'\n')
99 |
100 | cv2.imwrite(os.path.join(config['train']['output_path'],'img_result',img_name+'.jpg'),img_ori)
101 | bar.close()
102 | print('fps: %.2f'%(total_frame / total_time))
103 |     # cal_recall_precison_f1 is already imported at the top of this file
104 | result_dict = cal_recall_precison_f1(config['train']['val_gt_dir'], os.path.join(config['train']['output_path'], 'img_text'))
105 | return result_dict
--------------------------------------------------------------------------------
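The `resize_image` helper used above comes from utils/tools.py, which is not shown in this section. A plausible sketch of what such a helper does, assuming it scales the short side to `short_side` and keeps both dimensions divisible by 32 (the backbone downsamples by 32, so ragged sizes would break the feature fusion); this illustrates the idea rather than the repo's actual code:

import cv2

def resize_image(img, short_side=736):
    h, w = img.shape[:2]
    scale = short_side / min(h, w)
    # round both dimensions to multiples of 32 so every feature level divides evenly
    new_h = max(32, int(round(h * scale / 32)) * 32)
    new_w = max(32, int(round(w * scale / 32)) * 32)
    return cv2.resize(img, (new_w, new_h))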
/dataloader/dataload.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: dataload.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import numpy as np
11 | from PIL import Image
12 | from torch.utils import data
13 | import glob
14 | import cv2
15 | import random
16 | import os
17 | import torchvision.transforms as transforms
18 | import torch
19 | from .random_thansform import Random_Augment
20 | from .MakeSegMap import MakeSegDetectionData
21 | from .MakeBorderMap import MakeBorderMap
22 |
23 | def get_img(img_path):
24 | img = cv2.imread(img_path)
25 | return img
26 |
27 | def get_bboxes(gt_path,config):
28 | with open(gt_path,'r',encoding='utf-8') as fid:
29 | lines = fid.readlines()
30 | polys = []
31 | tags = []
32 | for line in lines:
33 | line = line.replace('\ufeff','').replace( '\xef\xbb\xbf','')
34 | gt = line.split(',')
35 | if "#" in gt[-1]:
36 | tags.append(True)
37 | else:
38 | tags.append(False)
39 | if(config['train']['is_icdar2015']):
40 | box = [int(gt[i]) for i in range(8)]
41 | else:
42 | box = [int(gt[i]) for i in range(len(gt)-1)]
43 | polys.append(box)
44 | return np.array(polys), tags
45 |
46 | class DataLoader(data.Dataset):
47 | def __init__(self, config):
48 | self.config = config
49 | self.ra = Random_Augment()
50 | self.ms = MakeSegDetectionData()
51 | self.mb = MakeBorderMap()
52 | img_paths = glob.glob(os.path.join(config['train']['train_img_dir'],'*'+config['train']['train_img_format']))
53 | gt_paths = []
54 | for img_path in img_paths:
55 | im_name = img_path.split('/')[-1].split('.')[0]
56 | if(config['train']['is_icdar2015']):
57 | gt_file_name = 'gt_'+im_name+'.txt'
58 | else:
59 | gt_file_name = im_name + '.txt'
60 | gt_paths.append(os.path.join(config['train']['train_gt_dir'],gt_file_name))
61 | self.img_paths = img_paths
62 | self.gt_paths = gt_paths
63 |
64 | def __len__(self):
65 | return len(self.img_paths)
66 |
67 | def __getitem__(self, index):
68 | img_path = self.img_paths[index]
69 | gt_path = self.gt_paths[index]
70 |
71 | img = get_img(img_path)
72 | polys, dontcare = get_bboxes(gt_path,self.config)
73 |
74 | if self.config['train']['is_transform']:
75 | img, polys = self.ra.random_scale(img, polys, 640)
76 | img, polys = self.ra.random_rotate(img, polys,self.config['train']['radom_angle'])
77 | img, polys = self.ra.random_flip(img, polys)
78 | img, polys, dontcare = self.ra.random_crop_db(img, polys, dontcare)
79 |
80 | img, gt, gt_mask = self.ms.process(img, polys, dontcare)
81 | img, thresh_map, thresh_mask = self.mb.process(img, polys, dontcare)
82 |
83 | if self.config['train']['is_show']:
84 | cv2.imwrite('img.jpg',img)
85 | cv2.imwrite('gt.jpg',gt[0]*255)
86 | cv2.imwrite('gt_mask.jpg',gt_mask[0]*255)
87 | cv2.imwrite('thresh_map.jpg',thresh_map*255)
88 | cv2.imwrite('thresh_mask.jpg',thresh_mask*255)
89 |
90 | if self.config['train']['is_transform']:
91 | img = Image.fromarray(img)
92 | img = img.convert('RGB')
93 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img)
94 | else:
95 | img = Image.fromarray(img)
96 | img = img.convert('RGB')
97 |
98 | img = transforms.ToTensor()(img)
99 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
100 |
101 | gt = torch.from_numpy(gt).float()
102 | gt_mask = torch.from_numpy(gt_mask).float()
103 | thresh_map = torch.from_numpy(thresh_map).float()
104 | thresh_mask = torch.from_numpy(thresh_mask).float()
105 |
106 |         return img, gt, gt_mask, thresh_map, thresh_mask
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
--------------------------------------------------------------------------------
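A minimal sketch of how this Dataset is consumed (the actual wiring lives in train.py; the batch size and worker count here are placeholders). Note that the class shadows torch's own DataLoader name, so an import alias keeps the two apart:

import yaml
import torch.utils.data
from dataloader.dataload import DataLoader as DBDataset

with open('config.yaml', 'r', encoding='utf-8') as stream:
    config = yaml.load(stream, Loader=yaml.FullLoader)

dataset = DBDataset(config)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,
                                     num_workers=4, drop_last=True)
img, gt, gt_mask, thresh_map, thresh_mask = next(iter(loader))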
/models/dcn/src/deform_pool_cuda.cpp:
--------------------------------------------------------------------------------
1 | // modify from
2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
3 |
4 | // based on
5 | // author: Charles Shang
6 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
7 |
8 | #include <torch/extension.h>
9 | 
10 | #include <cmath>
11 | #include <vector>
12 |
13 | void DeformablePSROIPoolForward(
14 | const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
15 | at::Tensor out, at::Tensor top_count, const int batch, const int channels,
16 | const int height, const int width, const int num_bbox,
17 | const int channels_trans, const int no_trans, const float spatial_scale,
18 | const int output_dim, const int group_size, const int pooled_size,
19 | const int part_size, const int sample_per_part, const float trans_std);
20 |
21 | void DeformablePSROIPoolBackwardAcc(
22 | const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
23 | const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
24 | at::Tensor trans_grad, const int batch, const int channels,
25 | const int height, const int width, const int num_bbox,
26 | const int channels_trans, const int no_trans, const float spatial_scale,
27 | const int output_dim, const int group_size, const int pooled_size,
28 | const int part_size, const int sample_per_part, const float trans_std);
29 |
30 | void deform_psroi_pooling_cuda_forward(
31 | at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
32 | at::Tensor top_count, const int no_trans, const float spatial_scale,
33 | const int output_dim, const int group_size, const int pooled_size,
34 | const int part_size, const int sample_per_part, const float trans_std) {
35 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
36 |
37 | const int batch = input.size(0);
38 | const int channels = input.size(1);
39 | const int height = input.size(2);
40 | const int width = input.size(3);
41 | const int channels_trans = no_trans ? 2 : trans.size(1);
42 |
43 | const int num_bbox = bbox.size(0);
44 | if (num_bbox != out.size(0))
45 |     AT_ERROR("Output shape and bbox number won't match: (%d vs %d).",
46 | out.size(0), num_bbox);
47 |
48 | DeformablePSROIPoolForward(
49 | input, bbox, trans, out, top_count, batch, channels, height, width,
50 | num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
51 | pooled_size, part_size, sample_per_part, trans_std);
52 | }
53 |
54 | void deform_psroi_pooling_cuda_backward(
55 | at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
56 | at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
57 | const int no_trans, const float spatial_scale, const int output_dim,
58 | const int group_size, const int pooled_size, const int part_size,
59 | const int sample_per_part, const float trans_std) {
60 | AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
61 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
62 |
63 | const int batch = input.size(0);
64 | const int channels = input.size(1);
65 | const int height = input.size(2);
66 | const int width = input.size(3);
67 | const int channels_trans = no_trans ? 2 : trans.size(1);
68 |
69 | const int num_bbox = bbox.size(0);
70 | if (num_bbox != out_grad.size(0))
71 |     AT_ERROR("Output shape and bbox number won't match: (%d vs %d).",
72 | out_grad.size(0), num_bbox);
73 |
74 | DeformablePSROIPoolBackwardAcc(
75 | out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
76 | channels, height, width, num_bbox, channels_trans, no_trans,
77 | spatial_scale, output_dim, group_size, pooled_size, part_size,
78 | sample_per_part, trans_std);
79 | }
80 |
81 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
82 | m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
83 | "deform psroi pooling forward(CUDA)");
84 | m.def("deform_psroi_pooling_cuda_backward",
85 | &deform_psroi_pooling_cuda_backward,
86 | "deform psroi pooling backward(CUDA)");
87 | }
--------------------------------------------------------------------------------
/models/head/seg_detector.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: seg_detector.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 |
11 | from collections import OrderedDict
12 | import torch
13 | import torch.nn as nn
14 | from models.head.FPN_Head import FPN_Head
15 | from models.head.DB_Head import DB_Head
16 |
17 | BatchNorm2d = nn.BatchNorm2d
18 |
19 | class SegDetector(nn.Module):
20 | def __init__(self,headname='DB',
21 | in_channels=[64, 128, 256, 512],
22 | inner_channels=256, k=10,
23 | bias=False, adaptive=False, smooth=False, serial=False,
24 | *args, **kwargs):
25 | '''
26 | bias: Whether conv layers have bias or not.
27 | adaptive: Whether to use adaptive threshold training or not.
28 | smooth: If true, use bilinear instead of deconv.
29 | serial: If true, thresh prediction will combine segmentation result as input.
30 | '''
31 | super(SegDetector, self).__init__()
32 | self.k = k
33 | self.serial = serial
34 |
35 | if(headname == 'FPN'):
36 | self.head = FPN_Head(in_channels, inner_channels)
37 | else:
38 | self.head = DB_Head(in_channels, inner_channels)
39 |
40 | self.binarize = nn.Sequential(
41 | nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
42 | BatchNorm2d(inner_channels // 4),
43 | nn.ReLU(inplace=True),
44 | nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
45 | BatchNorm2d(inner_channels // 4),
46 | nn.ReLU(inplace=True),
47 | nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2),
48 | nn.Sigmoid())
49 | self.binarize.apply(self.weights_init)
50 |
51 | self.adaptive = adaptive
52 | if adaptive:
53 | self.thresh = self._init_thresh(
54 | inner_channels, serial=serial, smooth=smooth, bias=bias)
55 | self.thresh.apply(self.weights_init)
56 |
57 | def weights_init(self, m):
58 | classname = m.__class__.__name__
59 | if classname.find('Conv') != -1:
60 | nn.init.kaiming_normal_(m.weight.data)
61 | elif classname.find('BatchNorm') != -1:
62 | m.weight.data.fill_(1.)
63 | m.bias.data.fill_(1e-4)
64 |
65 | def _init_thresh(self, inner_channels,
66 | serial=False, smooth=False, bias=False):
67 | in_channels = inner_channels
68 | if serial:
69 | in_channels += 1
70 | self.thresh = nn.Sequential(
71 | nn.Conv2d(in_channels, inner_channels //
72 | 4, 3, padding=1, bias=bias),
73 | BatchNorm2d(inner_channels // 4),
74 | nn.ReLU(inplace=True),
75 | self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
76 | BatchNorm2d(inner_channels // 4),
77 | nn.ReLU(inplace=True),
78 | self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
79 | nn.Sigmoid())
80 | return self.thresh
81 |
82 | def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
83 | if smooth:
84 | inter_out_channels = out_channels
85 | if out_channels == 1:
86 | inter_out_channels = in_channels
87 | module_list = [
88 | nn.Upsample(scale_factor=2, mode='nearest'),
89 | nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
90 | if out_channels == 1:
91 |                 module_list.append(
92 |                     nn.Conv2d(inter_out_channels, out_channels,
93 |                               kernel_size=1, stride=1, padding=0, bias=True))  # 1x1 conv: padding must be 0 to preserve the spatial size
94 | 
95 |             return nn.Sequential(*module_list)  # unpack the list; nn.Sequential does not accept a plain list
96 | else:
97 | return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
98 |
99 | def forward(self, features, gt=None, masks=None, training=False):
100 |
101 | fuse = self.head(features)
102 | binary = self.binarize(fuse)
103 |
104 | if self.training:
105 | result = OrderedDict(binary=binary)
106 | else:
107 | return binary
108 | if self.adaptive and self.training:
109 | if self.serial:
110 | fuse = torch.cat(
111 | (fuse, nn.functional.interpolate(
112 | binary, fuse.shape[2:])), 1)
113 | thresh = self.thresh(fuse)
114 | thresh_binary = self.step_function(binary, thresh)
115 | result.update(thresh=thresh, thresh_binary=thresh_binary)
116 | return result
117 |
118 | def step_function(self, x, y):
119 | return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
120 |
--------------------------------------------------------------------------------
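step_function above is the differentiable binarization at the heart of DB: B = 1 / (1 + exp(-k * (P - T))), a sigmoid over the probability map P minus the learned threshold map T, sharpened by k = 10. A quick numeric check:

import torch

k = 10
P = torch.tensor([0.2, 0.5, 0.8])  # probability map samples
T = torch.tensor([0.5, 0.5, 0.5])  # threshold map samples
B = torch.reciprocal(1 + torch.exp(-k * (P - T)))
print(B)  # tensor([0.0474, 0.5000, 0.9526]); approaches a hard 0/1 step as k grows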
/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: inference.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 |
11 | import sys
12 | sys.path.append('/home/aistudio/external-libraries')
13 | import cv2
14 | import torch
15 | import math
16 | import os
17 | import glob
18 | import time
19 | import copy
20 | import yaml
21 | import argparse
22 | import pyclipper
23 | import numpy as np
24 | from tqdm import tqdm
25 | from models.DBNet import DBNet
26 | import torchvision.transforms as transforms
27 | from shapely.geometry import Polygon
28 | from cal_rescall.script import cal_recall_precison_f1
29 | from utils.DB_postprocesss import *
30 | from utils.tools import *
31 | from utils.fuse_model import fuse_module
32 |
33 | def test_net(config):
34 | os.environ["CUDA_VISIBLE_DEVICES"] = config['test']['gpu_id']
35 |
36 | config['train']['pretrained'] = config['test']['pretrained']
37 | files = glob.glob(os.path.join(config['test']['test_img_dir'],'*'+config['test']['test_img_format']))
38 | model = DBNet(config).cuda()
39 |
40 | model_dict = torch.load(config['test']['checkpoints'])['state_dict']
41 | state = model.state_dict()
42 | for key in state.keys():
43 | if key in model_dict.keys():
44 | state[key] = model_dict[key]
45 | model.load_state_dict(state)
46 |
47 | if(config['test']['merge_conv_bn']):
48 | fuse_module(model)
49 | print('merge conv bn ok!!!')
50 |
51 | if not (os.path.exists(config['test']['out_dir'])):
52 | os.mkdir(config['test']['out_dir'])
53 |
54 | if not (os.path.exists(os.path.join(config['test']['out_dir'],'img_text'))):
55 | os.mkdir(os.path.join(config['test']['out_dir'],'img_text'))
56 |
57 | if not (os.path.exists(os.path.join(config['test']['out_dir'],'img_result'))):
58 | os.mkdir(os.path.join(config['test']['out_dir'],'img_result'))
59 |
60 | bar = tqdm(total=len(files))
61 | params = {'thresh':config['test']['thresh'],
62 | 'box_thresh':config['test']['box_thresh'],
63 | 'max_candidates':config['test']['max_candidates'],
64 | 'is_poly':config['test']['is_poly'],
65 | 'unclip_ratio':config['test']['unclip_ratio'],
66 | 'min_size':config['test']['min_size']
67 | }
68 |
69 | dbprocess = DBPostProcess(params)
70 |
71 | total_frame = 0.0
72 | total_time = 0.0
73 |
74 | for file in files:
75 | model.eval()
76 | bar.update(1)
77 | img = cv2.imread(file)
78 | img_ori = img.copy()
79 | img_name = file.split('/')[-1].split('.')[0]
80 | img = resize_image(img,config['test']['short_side'])
81 |
82 | img = transforms.ToTensor()(img)
83 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img).unsqueeze(0).cuda()
84 |
85 | torch.cuda.synchronize()
86 | start = time.time()
87 |
88 | with torch.no_grad():
89 | out = model(img)
90 |
91 | scale = (img_ori.shape[1] * 1.0 / out.shape[3], img_ori.shape[0] * 1.0 / out.shape[2])
92 | bbox_batch,score_batch = dbprocess(out.cpu().numpy(),[scale])
93 |
94 | torch.cuda.synchronize()
95 | end = time.time()
96 | total_frame += 1
97 | total_time += (end - start)
98 | sys.stdout.flush()
99 |
100 | for bbox in bbox_batch[0]:
101 |             bbox = bbox.reshape(-1, 2).astype(np.int32)
102 | # bbox = sort_coord(bbox)
103 | img_ori = cv2.drawContours(img_ori.copy(), [bbox], -1, (0, 255, 0), 1)
104 |
105 | if config['test']['is_icdar2015']:
106 | text_file = 'res_' + img_name + '.txt'
107 | else:
108 | text_file = img_name + '.txt'
109 |
110 | with open(os.path.join(config['test']['out_dir'],'img_text',text_file),'w+',encoding='utf-8') as fid:
111 | for bbox in bbox_batch[0]:
112 | if(len(bbox)==0):
113 | continue
114 |                 bbox = bbox.reshape(-1, 2).astype(np.int32)
115 | # bbox = sort_coord(bbox)
116 | bbox = bbox.reshape(-1).tolist()
117 | bbox = [str(x) for x in bbox]
118 | bbox = ','.join(bbox)
119 | fid.write(bbox+'\n')
120 |
121 | cv2.imwrite(os.path.join(config['test']['out_dir'],'img_result',img_name+'.jpg'),img_ori)
122 | bar.close()
123 | print('fps: %.2f'%(total_frame / total_time))
124 | result_dict = cal_recall_precison_f1(config['test']['test_gt_dir'], os.path.join(config['test']['out_dir'], 'img_text'))
125 | return result_dict
126 |
127 | if __name__ == '__main__':
128 | stream = open('config.yaml', 'r', encoding='utf-8')
129 | config = yaml.load(stream,Loader=yaml.FullLoader)
130 | result_dict = test_net(config)
131 | print(result_dict)
--------------------------------------------------------------------------------
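When `merge_conv_bn` is enabled, fuse_module folds each BatchNorm into the preceding convolution for faster inference. The algebra is standard: W' = W * gamma / sqrt(var + eps) and b' = (b - mean) * gamma / sqrt(var + eps) + beta. A sketch of that fold (utils/fuse_model.py is not shown in this section, so this illustrates the math rather than the repo's exact implementation):

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
    std = torch.sqrt(bn.running_var + bn.eps)
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    fused.weight.data = conv.weight.data * t  # W' = W * gamma / std
    b = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (b - bn.running_mean) * bn.weight / std + bn.bias
    return fused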
/pruned/prune_inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 _*-
2 | """
3 | @author:fxw
4 | @file: test.py
5 | @time: 2020/04/28
6 | """
7 |
8 | import sys
9 |
10 | sys.path.append('/home/aistudio/external-libraries')
11 | sys.path.append('./')
12 | import cv2
13 | import torch
14 | import math
15 | import os
16 | import glob
17 | import time
18 | import copy
19 | import yaml
20 | import argparse
21 | import pyclipper
22 | import numpy as np
23 | from tqdm import tqdm
24 | from models.DBNet import DBNet
25 | import torchvision.transforms as transforms
26 | from shapely.geometry import Polygon
27 | from cal_rescall.script import cal_recall_precison_f1
28 | from utils.DB_postprocesss import *
29 | from utils.tools import *
30 | from utils.fuse_model import fuse_module
31 | from pruned.get_pruned_model import load_prune_model
32 |
33 |
34 | def test_net(config):
35 | os.environ["CUDA_VISIBLE_DEVICES"] = config['test']['gpu_id']
36 |
37 | config['train']['pretrained'] = config['test']['pretrained']
38 | files = glob.glob(os.path.join(config['test']['test_img_dir'], '*' + config['test']['test_img_format']))
39 | model = DBNet(config)
40 | model = load_prune_model(model, config['pruned']['checkpoints_dict']).cuda()
41 |
42 | model_dict = torch.load(config['pruned']['checkpoints'])['state_dict']
43 | state = model.state_dict()
44 | for key in state.keys():
45 | if key in model_dict.keys():
46 | state[key] = model_dict[key]
47 | model.load_state_dict(state)
48 |
49 | if (config['test']['merge_conv_bn']):
50 | fuse_module(model)
51 | print('merge conv bn ok!!!')
52 |
53 | if not (os.path.exists(config['test']['out_dir'])):
54 | os.mkdir(config['test']['out_dir'])
55 |
56 | if not (os.path.exists(os.path.join(config['test']['out_dir'], 'img_text'))):
57 | os.mkdir(os.path.join(config['test']['out_dir'], 'img_text'))
58 |
59 | if not (os.path.exists(os.path.join(config['test']['out_dir'], 'img_result'))):
60 | os.mkdir(os.path.join(config['test']['out_dir'], 'img_result'))
61 |
62 | bar = tqdm(total=len(files))
63 | params = {'thresh': config['test']['thresh'],
64 | 'box_thresh': config['test']['box_thresh'],
65 | 'max_candidates': config['test']['max_candidates'],
66 | 'is_poly': config['test']['is_poly'],
67 | 'unclip_ratio': config['test']['unclip_ratio'],
68 | 'min_size': config['test']['min_size']
69 | }
70 |
71 | dbprocess = DBPostProcess(params)
72 |
73 | total_frame = 0.0
74 | total_time = 0.0
75 |
76 | for file in files:
77 | model.eval()
78 | bar.update(1)
79 | img = cv2.imread(file)
80 | img_ori = img.copy()
81 | img_name = file.split('/')[-1].split('.')[0]
82 | img = resize_image(img, config['test']['short_side'])
83 |
84 | img = transforms.ToTensor()(img)
85 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img).unsqueeze(0).cuda()
86 |
87 | torch.cuda.synchronize()
88 | start = time.time()
89 |
90 | with torch.no_grad():
91 | out = model(img)
92 |
93 | scale = (img_ori.shape[1] * 1.0 / out.shape[3], img_ori.shape[0] * 1.0 / out.shape[2])
94 | bbox_batch, score_batch = dbprocess(out.cpu().numpy(), [scale])
95 |
96 | torch.cuda.synchronize()
97 | end = time.time()
98 | total_frame += 1
99 | total_time += (end - start)
100 | sys.stdout.flush()
101 |
102 | for bbox in bbox_batch[0]:
103 |             bbox = bbox.reshape(-1, 2).astype(np.int32)
104 | # bbox = sort_coord(bbox)
105 | img_ori = cv2.drawContours(img_ori.copy(), [bbox], -1, (0, 255, 0), 1)
106 |
107 | if config['test']['is_icdar2015']:
108 | text_file = 'res_' + img_name + '.txt'
109 | else:
110 | text_file = img_name + '.txt'
111 |
112 | with open(os.path.join(config['test']['out_dir'], 'img_text', text_file), 'w+', encoding='utf-8') as fid:
113 | for bbox in bbox_batch[0]:
114 | if (len(bbox) == 0):
115 | continue
116 |                 bbox = bbox.reshape(-1, 2).astype(np.int32)
117 | # bbox = sort_coord(bbox)
118 | bbox = bbox.reshape(-1).tolist()
119 | bbox = [str(x) for x in bbox]
120 | bbox = ','.join(bbox)
121 | fid.write(bbox + '\n')
122 |
123 | cv2.imwrite(os.path.join(config['test']['out_dir'], 'img_result', img_name + '.jpg'), img_ori)
124 | bar.close()
125 | print('fps: %.2f' % (total_frame / total_time))
126 | result_dict = cal_recall_precison_f1(config['test']['test_gt_dir'],
127 | os.path.join(config['test']['out_dir'], 'img_text'))
128 | return result_dict
129 |
130 |
131 | if __name__ == '__main__':
132 | stream = open('./config.yaml', 'r', encoding='utf-8')
133 | config = yaml.load(stream, Loader=yaml.FullLoader)
134 | result_dict = test_net(config)
135 | print(result_dict)
--------------------------------------------------------------------------------
/dataloader/MakeBorderMap.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: MakeBorderMap.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import warnings
11 | import numpy as np
12 | import cv2
13 | from shapely.geometry import Polygon
14 | import pyclipper
15 | import pdb
16 | class MakeBorderMap():
17 | def __init__(self,shrink_ratio=0.4,thresh_min=0.3,thresh_max=0.7):
18 | super(MakeBorderMap,self).__init__()
19 | self.shrink_ratio = shrink_ratio
20 | self.thresh_min = thresh_min
21 | self.thresh_max = thresh_max
22 | def process(self, img,polys,dontcare):
23 | thresh_map = np.zeros(img.shape[:2], dtype=np.float32)
24 | thresh_mask = np.zeros(img.shape[:2], dtype=np.float32)
25 |
26 | for i in range(len(polys)):
27 | if dontcare[i]:
28 | continue
29 | self.draw_border_map(polys[i], thresh_map, mask=thresh_mask)
30 | thresh_map = thresh_map * (self.thresh_max - self.thresh_min) + self.thresh_min
31 | return img,thresh_map,thresh_mask
32 |
33 | def draw_border_map(self, polygon, canvas, mask):
34 | polygon = np.array(polygon)
35 | assert polygon.ndim == 2
36 | assert polygon.shape[1] == 2
37 |
38 | polygon_shape = Polygon(polygon)
39 | distance = polygon_shape.area * \
40 | (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
41 | subject = [tuple(l) for l in polygon]
42 | padding = pyclipper.PyclipperOffset()
43 | padding.AddPath(subject, pyclipper.JT_ROUND,
44 | pyclipper.ET_CLOSEDPOLYGON)
45 | padded_polygon = np.array(padding.Execute(distance)[0])
46 | cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
47 |
48 | xmin = padded_polygon[:, 0].min()
49 | xmax = padded_polygon[:, 0].max()
50 | ymin = padded_polygon[:, 1].min()
51 | ymax = padded_polygon[:, 1].max()
52 | width = xmax - xmin + 1
53 | height = ymax - ymin + 1
54 |
55 | polygon[:, 0] = polygon[:, 0] - xmin
56 | polygon[:, 1] = polygon[:, 1] - ymin
57 |
58 | xs = np.broadcast_to(
59 | np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
60 | ys = np.broadcast_to(
61 | np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
62 |
63 | distance_map = np.zeros(
64 | (polygon.shape[0], height, width), dtype=np.float32)
65 | for i in range(polygon.shape[0]):
66 | j = (i + 1) % polygon.shape[0]
67 | absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
68 | distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
69 | distance_map = distance_map.min(axis=0)
70 |
71 | xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
72 | xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
73 | ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
74 | ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
75 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
76 | 1 - distance_map[
77 | ymin_valid-ymin:ymax_valid-ymax+height,
78 | xmin_valid-xmin:xmax_valid-xmax+width],
79 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
80 |
81 | def distance(self, xs, ys, point_1, point_2):
82 | '''
83 |         compute the distance from each grid point to a line segment
84 |         ys: coordinates along the first axis
85 |         xs: coordinates along the second axis
86 |         point_1, point_2: (x, y), the endpoints of the segment
87 | '''
88 | height, width = xs.shape[:2]
89 | square_distance_1 = np.square(
90 | xs - point_1[0]) + np.square(ys - point_1[1])
91 | square_distance_2 = np.square(
92 | xs - point_2[0]) + np.square(ys - point_2[1])
93 | square_distance = np.square(
94 | point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])
95 |
96 | cosin = (square_distance - square_distance_1 - square_distance_2) / \
97 | (2 * np.sqrt(square_distance_1 * square_distance_2))
98 | square_sin = 1 - np.square(cosin)
99 | square_sin = np.nan_to_num(square_sin)
100 | result = np.sqrt(square_distance_1 * square_distance_2 *
101 | square_sin / square_distance)
102 |
103 | result[cosin < 0] = np.sqrt(np.fmin(
104 | square_distance_1, square_distance_2))[cosin < 0]
105 | # self.extend_line(point_1, point_2, result)
106 | return result
107 |
108 | def extend_line(self, point_1, point_2, result):
109 | ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
110 | int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
111 | cv2.line(result, tuple(ex_point_1), tuple(point_1),
112 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
113 | ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
114 | int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
115 | cv2.line(result, tuple(ex_point_2), tuple(point_2),
116 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
117 | return ex_point_1, ex_point_2
--------------------------------------------------------------------------------
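The distance() method above avoids an explicit projection: from the three squared distances it recovers the sine of the angle at the query point via the law of cosines, and the perpendicular distance falls out as sqrt(d1^2 * d2^2 * sin^2 / d^2); where the angle test fails it falls back to the nearer endpoint. A numeric sanity check:

import numpy as np

p1, p2 = np.array([0.0, 0.0]), np.array([10.0, 0.0])
xs, ys = np.array([[5.0]]), np.array([[3.0]])  # query point (5, 3)

d1_sq = np.square(xs - p1[0]) + np.square(ys - p1[1])
d2_sq = np.square(xs - p2[0]) + np.square(ys - p2[1])
d_sq = np.square(p1[0] - p2[0]) + np.square(p1[1] - p2[1])
cosin = (d_sq - d1_sq - d2_sq) / (2 * np.sqrt(d1_sq * d2_sq))
dist = np.sqrt(d1_sq * d2_sq * (1 - np.square(cosin)) / d_sq)
print(dist)  # [[3.]] -- the perpendicular distance from (5, 3) to the segment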
/models/dcn/modules/deform_conv.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules.utils import _pair
6 |
7 | from ..functions.deform_conv import deform_conv, modulated_deform_conv
8 |
9 |
10 | class DeformConv(nn.Module):
11 |
12 | def __init__(self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | stride=1,
17 | padding=0,
18 | dilation=1,
19 | groups=1,
20 | deformable_groups=1,
21 | bias=False):
22 | super(DeformConv, self).__init__()
23 |
24 | assert not bias
25 | assert in_channels % groups == 0, \
26 |             'in_channels {} is not divisible by groups {}'.format(
27 | in_channels, groups)
28 | assert out_channels % groups == 0, \
29 |             'out_channels {} is not divisible by groups {}'.format(
30 | out_channels, groups)
31 |
32 | self.in_channels = in_channels
33 | self.out_channels = out_channels
34 | self.kernel_size = _pair(kernel_size)
35 | self.stride = _pair(stride)
36 | self.padding = _pair(padding)
37 | self.dilation = _pair(dilation)
38 | self.groups = groups
39 | self.deformable_groups = deformable_groups
40 |
41 | self.weight = nn.Parameter(
42 | torch.Tensor(out_channels, in_channels // self.groups,
43 | *self.kernel_size))
44 |
45 | self.reset_parameters()
46 |
47 | def reset_parameters(self):
48 | n = self.in_channels
49 | for k in self.kernel_size:
50 | n *= k
51 | stdv = 1. / math.sqrt(n)
52 | self.weight.data.uniform_(-stdv, stdv)
53 |
54 | def forward(self, x, offset):
55 | return deform_conv(x, offset, self.weight, self.stride, self.padding,
56 | self.dilation, self.groups, self.deformable_groups)
57 |
58 |
59 | class DeformConvPack(DeformConv):
60 |
61 | def __init__(self, *args, **kwargs):
62 | super(DeformConvPack, self).__init__(*args, **kwargs)
63 |
64 | self.conv_offset = nn.Conv2d(
65 | self.in_channels,
66 | self.deformable_groups * 2 * self.kernel_size[0] *
67 | self.kernel_size[1],
68 | kernel_size=self.kernel_size,
69 | stride=_pair(self.stride),
70 | padding=_pair(self.padding),
71 | bias=True)
72 | self.init_offset()
73 |
74 | def init_offset(self):
75 | self.conv_offset.weight.data.zero_()
76 | self.conv_offset.bias.data.zero_()
77 |
78 | def forward(self, x):
79 | offset = self.conv_offset(x)
80 | return deform_conv(x, offset, self.weight, self.stride, self.padding,
81 | self.dilation, self.groups, self.deformable_groups)
82 |
83 |
84 | class ModulatedDeformConv(nn.Module):
85 |
86 | def __init__(self,
87 | in_channels,
88 | out_channels,
89 | kernel_size,
90 | stride=1,
91 | padding=0,
92 | dilation=1,
93 | groups=1,
94 | deformable_groups=1,
95 | bias=True):
96 | super(ModulatedDeformConv, self).__init__()
97 | self.in_channels = in_channels
98 | self.out_channels = out_channels
99 | self.kernel_size = _pair(kernel_size)
100 | self.stride = stride
101 | self.padding = padding
102 | self.dilation = dilation
103 | self.groups = groups
104 | self.deformable_groups = deformable_groups
105 | self.with_bias = bias
106 |
107 | self.weight = nn.Parameter(
108 | torch.Tensor(out_channels, in_channels // groups,
109 | *self.kernel_size))
110 | if bias:
111 | self.bias = nn.Parameter(torch.Tensor(out_channels))
112 | else:
113 | self.register_parameter('bias', None)
114 | self.reset_parameters()
115 |
116 | def reset_parameters(self):
117 | n = self.in_channels
118 | for k in self.kernel_size:
119 | n *= k
120 | stdv = 1. / math.sqrt(n)
121 | self.weight.data.uniform_(-stdv, stdv)
122 | if self.bias is not None:
123 | self.bias.data.zero_()
124 |
125 | def forward(self, x, offset, mask):
126 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
127 | self.stride, self.padding, self.dilation,
128 | self.groups, self.deformable_groups)
129 |
130 |
131 | class ModulatedDeformConvPack(ModulatedDeformConv):
132 |
133 | def __init__(self, *args, **kwargs):
134 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
135 |
136 | self.conv_offset_mask = nn.Conv2d(
137 | self.in_channels,
138 | self.deformable_groups * 3 * self.kernel_size[0] *
139 | self.kernel_size[1],
140 | kernel_size=self.kernel_size,
141 | stride=_pair(self.stride),
142 | padding=_pair(self.padding),
143 | bias=True)
144 | self.init_offset()
145 |
146 | def init_offset(self):
147 | self.conv_offset_mask.weight.data.zero_()
148 | self.conv_offset_mask.bias.data.zero_()
149 |
150 | def forward(self, x):
151 | out = self.conv_offset_mask(x)
152 | o1, o2, mask = torch.chunk(out, 3, dim=1)
153 | offset = torch.cat((o1, o2), dim=1)
154 | mask = torch.sigmoid(mask)
155 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
156 | self.stride, self.padding, self.dilation,
157 | self.groups, self.deformable_groups)
158 |
--------------------------------------------------------------------------------
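For orientation: DeformConvPack predicts its own offsets with a plain convolution whose output has deformable_groups * 2 * kh * kw channels (one 2-vector per kernel tap), and ModulatedDeformConvPack adds a sigmoid mask on top (hence the factor 3 and the torch.chunk into o1, o2, mask). A shape sketch, which needs the compiled deform_conv_cuda extension and a CUDA device to actually run:

import torch
from models.dcn.modules.deform_conv import DeformConvPack

conv = DeformConvPack(64, 64, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 64, 32, 32).cuda()
offset = conv.conv_offset(x)  # (1, 2*3*3, 32, 32): one 2-vector offset per kernel tap
out = conv(x)                 # (1, 64, 32, 32)
print(offset.shape, out.shape)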
/pruned/get_pruned_model.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: tt.py
7 | @time: 2020/6/20 10:51
8 |
9 | """
10 |
11 | import models
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 |
17 | def get_new_model(model, new_model, prued_mask, bn_index):
18 | merge1_index = [3, 12, 18]
19 | merge2_index = [25, 28, 34]
20 | merge3_index = [41, 44, 50]
21 | merge4_index = [57, 60, 66]
22 |
23 | index_0 = []
24 | for item in merge1_index:
25 | index_0.append(bn_index.index(item))
26 | mask1 = prued_mask[index_0[0]] | prued_mask[index_0[1]] | prued_mask[index_0[2]]
27 |
28 | index_1 = []
29 | for item in merge2_index:
30 | index_1.append(bn_index.index(item))
31 | mask2 = prued_mask[index_1[0]] | prued_mask[index_1[1]] | prued_mask[index_1[2]]
32 |
33 | index_2 = []
34 | for item in merge3_index:
35 | index_2.append(bn_index.index(item))
36 | mask3 = prued_mask[index_2[0]] | prued_mask[index_2[1]] | prued_mask[index_2[2]]
37 |
38 | index_3 = []
39 | for item in merge4_index:
40 | index_3.append(bn_index.index(item))
41 | mask4 = prued_mask[index_3[0]] | prued_mask[index_3[1]] | prued_mask[index_3[2]]
42 |
43 | for index in index_0:
44 | prued_mask[index] = mask1
45 |
46 | for index in index_1:
47 | prued_mask[index] = mask2
48 |
49 | for index in index_2:
50 | prued_mask[index] = mask3
51 |
52 | for index in index_3:
53 | prued_mask[index] = mask4
54 |
55 | ##############################################################
56 | index_bn = 0
57 | index_conv = 0
58 |
59 | bn_mask = []
60 | conv_in_mask = []
61 | conv_out_mask = []
62 | tag = 0
63 | for m in new_model.modules():
64 | if (tag > 69):
65 | break
66 | if (isinstance(m, nn.BatchNorm2d)):
67 |             m.num_features = int(prued_mask[index_bn].sum())
68 | bn_mask.append(prued_mask[index_bn])
69 | index_bn += 1
70 | elif (isinstance(m, nn.Conv2d)):
71 | if (index_conv == 0):
72 | m.in_channels = 3
73 | conv_in_mask.append(torch.ones(3))
74 | else:
75 |                 m.in_channels = int(prued_mask[index_conv - 1].sum())
76 | conv_in_mask.append(prued_mask[index_conv - 1])
77 | m.out_channels = int(prued_mask[index_conv].sum())
78 | conv_out_mask.append(prued_mask[index_conv])
79 | index_conv += 1
80 | tag += 1
81 |
82 |     conv_change_index = [27, 43, 59]  # convs whose input channels follow a merged stage mask
83 |     change_conv_bn_index = [18, 34, 50]  # BN layers whose (merged) masks feed those convs
84 | tag = 0
85 | for m in new_model.modules():
86 | if (tag > 69):
87 | break
88 | if (isinstance(m, nn.Conv2d)):
89 | if (tag in conv_change_index):
90 | index = conv_change_index.index(tag)
91 | index = change_conv_bn_index[index]
92 | index = bn_index.index(index)
93 | mask = prued_mask[index]
94 | conv_in_mask[index + 3] = mask
95 |                 m.in_channels = int(mask.sum())
96 | tag += 1
97 |
98 | #############################################################
99 | bn_i = 0
100 | conv_i = 0
101 | scale_i = 0
102 | scale_mask = [mask4, mask3, mask2, mask1]
103 | # scale = [70,86,90,94] # FPN
104 | scale = [73, 77, 81, 85] # DB
105 | for [m0, m1] in zip(model.modules(), new_model.modules()):
106 | if (scale_i > 69):
107 | if isinstance(m0, nn.Conv2d):
108 | if (scale_i in scale):
109 | index = scale.index(scale_i)
110 |                 m1.in_channels = int(scale_mask[index].sum())
111 | idx0 = np.squeeze(np.argwhere(np.asarray(scale_mask[index].cpu().numpy())))
112 | idx1 = np.squeeze(np.argwhere(np.asarray(torch.ones(256).cpu().numpy())))
113 | if idx0.size == 1:
114 | idx0 = np.resize(idx0, (1,))
115 | if idx1.size == 1:
116 | idx1 = np.resize(idx1, (1,))
117 | w = m0.weight.data[:, idx0, :, :].clone()
118 | m1.weight.data = w[idx1, :, :, :].clone()
119 | if m1.bias is not None:
120 | m1.bias.data = m0.bias.data[idx1].clone()
121 |
122 | else:
123 | m1.weight.data = m0.weight.data.clone()
124 | if m1.bias is not None:
125 | m1.bias.data = m0.bias.data.clone()
126 |
127 | elif isinstance(m0, nn.BatchNorm2d):
128 | m1.weight.data = m0.weight.data.clone()
129 | if m1.bias is not None:
130 | m1.bias.data = m0.bias.data.clone()
131 | m1.running_mean = m0.running_mean.clone()
132 | m1.running_var = m0.running_var.clone()
133 |
134 | else:
135 | if isinstance(m0, nn.BatchNorm2d):
136 | idx1 = np.squeeze(np.argwhere(np.asarray(bn_mask[bn_i].cpu().numpy())))
137 | if idx1.size == 1:
138 | idx1 = np.resize(idx1, (1,))
139 | m1.weight.data = m0.weight.data[idx1].clone()
140 | if m1.bias is not None:
141 | m1.bias.data = m0.bias.data[idx1].clone()
142 | m1.running_mean = m0.running_mean[idx1].clone()
143 | m1.running_var = m0.running_var[idx1].clone()
144 | bn_i += 1
145 | elif isinstance(m0, nn.Conv2d):
146 | if (isinstance(conv_in_mask[conv_i], list)):
147 | idx0 = np.squeeze(np.argwhere(np.asarray(torch.cat(conv_in_mask[conv_i], 0).cpu().numpy())))
148 | else:
149 | idx0 = np.squeeze(np.argwhere(np.asarray(conv_in_mask[conv_i].cpu().numpy())))
150 | idx1 = np.squeeze(np.argwhere(np.asarray(conv_out_mask[conv_i].cpu().numpy())))
151 | if idx0.size == 1:
152 | idx0 = np.resize(idx0, (1,))
153 | if idx1.size == 1:
154 | idx1 = np.resize(idx1, (1,))
155 | w = m0.weight.data[:, idx0, :, :].clone()
156 | m1.weight.data = w[idx1, :, :, :].clone()
157 | if m1.bias is not None:
158 | m1.bias.data = m0.bias.data[idx1].clone()
159 | conv_i += 1
160 |
161 | scale_i += 1
162 |
163 | return new_model
164 |
165 |
166 | def load_prune_model(model, pruned_model_dict_path):
167 | _load = torch.load(pruned_model_dict_path)
168 | prued_mask = _load['prued_mask']
169 | bn_index = _load['bn_index']
170 | prune_model = get_new_model(model, model, prued_mask, bn_index)
171 | return prune_model
172 |
173 |
174 |
175 |
--------------------------------------------------------------------------------
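The 'prued_mask'/'bn_index' pair consumed by load_prune_model comes from network-slimming style pruning (see pruned/prune.py): rank BatchNorm gamma magnitudes globally, threshold at the desired sparsity, and keep a boolean per-channel mask for each BN layer. A hedged sketch of that construction (prune_percent and the exact module traversal are illustrative, not the repo's code):

import torch
import torch.nn as nn

def build_bn_masks(model, prune_percent=0.5):
    gammas = torch.cat([m.weight.data.abs() for m in model.modules()
                        if isinstance(m, nn.BatchNorm2d)])
    thresh = torch.sort(gammas)[0][int(len(gammas) * prune_percent)]
    prued_mask, bn_index = [], []
    for idx, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            prued_mask.append(m.weight.data.abs() > thresh)  # bool mask per channel
            bn_index.append(idx)
    return prued_mask, bn_index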
/models/dcn/modules/deform_pool.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from ..functions.deform_pool import deform_roi_pooling
4 |
5 |
6 | class DeformRoIPooling(nn.Module):
7 |
8 | def __init__(self,
9 | spatial_scale,
10 | out_size,
11 | out_channels,
12 | no_trans,
13 | group_size=1,
14 | part_size=None,
15 | sample_per_part=4,
16 | trans_std=.0):
17 | super(DeformRoIPooling, self).__init__()
18 | self.spatial_scale = spatial_scale
19 | self.out_size = out_size
20 | self.out_channels = out_channels
21 | self.no_trans = no_trans
22 | self.group_size = group_size
23 | self.part_size = out_size if part_size is None else part_size
24 | self.sample_per_part = sample_per_part
25 | self.trans_std = trans_std
26 |
27 | def forward(self, data, rois, offset):
28 | if self.no_trans:
29 | offset = data.new_empty(0)
30 | return deform_roi_pooling(
31 | data, rois, offset, self.spatial_scale, self.out_size,
32 | self.out_channels, self.no_trans, self.group_size, self.part_size,
33 | self.sample_per_part, self.trans_std)
34 |
35 |
36 | class DeformRoIPoolingPack(DeformRoIPooling):
37 |
38 | def __init__(self,
39 | spatial_scale,
40 | out_size,
41 | out_channels,
42 | no_trans,
43 | group_size=1,
44 | part_size=None,
45 | sample_per_part=4,
46 | trans_std=.0,
47 | num_offset_fcs=3,
48 | deform_fc_channels=1024):
49 | super(DeformRoIPoolingPack,
50 | self).__init__(spatial_scale, out_size, out_channels, no_trans,
51 | group_size, part_size, sample_per_part, trans_std)
52 |
53 | self.num_offset_fcs = num_offset_fcs
54 | self.deform_fc_channels = deform_fc_channels
55 |
56 | if not no_trans:
57 | seq = []
58 | ic = self.out_size * self.out_size * self.out_channels
59 | for i in range(self.num_offset_fcs):
60 | if i < self.num_offset_fcs - 1:
61 | oc = self.deform_fc_channels
62 | else:
63 | oc = self.out_size * self.out_size * 2
64 | seq.append(nn.Linear(ic, oc))
65 | ic = oc
66 | if i < self.num_offset_fcs - 1:
67 | seq.append(nn.ReLU(inplace=True))
68 | self.offset_fc = nn.Sequential(*seq)
69 | self.offset_fc[-1].weight.data.zero_()
70 | self.offset_fc[-1].bias.data.zero_()
71 |
72 | def forward(self, data, rois):
73 | assert data.size(1) == self.out_channels
74 | if self.no_trans:
75 | offset = data.new_empty(0)
76 | return deform_roi_pooling(
77 | data, rois, offset, self.spatial_scale, self.out_size,
78 | self.out_channels, self.no_trans, self.group_size,
79 | self.part_size, self.sample_per_part, self.trans_std)
80 | else:
81 | n = rois.shape[0]
82 | offset = data.new_empty(0)
83 | x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
84 | self.out_size, self.out_channels, True,
85 | self.group_size, self.part_size,
86 | self.sample_per_part, self.trans_std)
87 | offset = self.offset_fc(x.view(n, -1))
88 | offset = offset.view(n, 2, self.out_size, self.out_size)
89 | return deform_roi_pooling(
90 | data, rois, offset, self.spatial_scale, self.out_size,
91 | self.out_channels, self.no_trans, self.group_size,
92 | self.part_size, self.sample_per_part, self.trans_std)
93 |
94 |
95 | class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
96 |
97 | def __init__(self,
98 | spatial_scale,
99 | out_size,
100 | out_channels,
101 | no_trans,
102 | group_size=1,
103 | part_size=None,
104 | sample_per_part=4,
105 | trans_std=.0,
106 | num_offset_fcs=3,
107 | num_mask_fcs=2,
108 | deform_fc_channels=1024):
109 | super(ModulatedDeformRoIPoolingPack, self).__init__(
110 | spatial_scale, out_size, out_channels, no_trans, group_size,
111 | part_size, sample_per_part, trans_std)
112 |
113 | self.num_offset_fcs = num_offset_fcs
114 | self.num_mask_fcs = num_mask_fcs
115 | self.deform_fc_channels = deform_fc_channels
116 |
117 | if not no_trans:
118 | offset_fc_seq = []
119 | ic = self.out_size * self.out_size * self.out_channels
120 | for i in range(self.num_offset_fcs):
121 | if i < self.num_offset_fcs - 1:
122 | oc = self.deform_fc_channels
123 | else:
124 | oc = self.out_size * self.out_size * 2
125 | offset_fc_seq.append(nn.Linear(ic, oc))
126 | ic = oc
127 | if i < self.num_offset_fcs - 1:
128 | offset_fc_seq.append(nn.ReLU(inplace=True))
129 | self.offset_fc = nn.Sequential(*offset_fc_seq)
130 | self.offset_fc[-1].weight.data.zero_()
131 | self.offset_fc[-1].bias.data.zero_()
132 |
133 | mask_fc_seq = []
134 | ic = self.out_size * self.out_size * self.out_channels
135 | for i in range(self.num_mask_fcs):
136 | if i < self.num_mask_fcs - 1:
137 | oc = self.deform_fc_channels
138 | else:
139 | oc = self.out_size * self.out_size
140 | mask_fc_seq.append(nn.Linear(ic, oc))
141 | ic = oc
142 | if i < self.num_mask_fcs - 1:
143 | mask_fc_seq.append(nn.ReLU(inplace=True))
144 | else:
145 | mask_fc_seq.append(nn.Sigmoid())
146 | self.mask_fc = nn.Sequential(*mask_fc_seq)
147 | self.mask_fc[-2].weight.data.zero_()
148 | self.mask_fc[-2].bias.data.zero_()
149 |
150 | def forward(self, data, rois):
151 | assert data.size(1) == self.out_channels
152 | if self.no_trans:
153 | offset = data.new_empty(0)
154 | return deform_roi_pooling(
155 | data, rois, offset, self.spatial_scale, self.out_size,
156 | self.out_channels, self.no_trans, self.group_size,
157 | self.part_size, self.sample_per_part, self.trans_std)
158 | else:
159 | n = rois.shape[0]
160 | offset = data.new_empty(0)
161 | x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
162 | self.out_size, self.out_channels, True,
163 | self.group_size, self.part_size,
164 | self.sample_per_part, self.trans_std)
165 | offset = self.offset_fc(x.view(n, -1))
166 | offset = offset.view(n, 2, self.out_size, self.out_size)
167 | mask = self.mask_fc(x.view(n, -1))
168 | mask = mask.view(n, 1, self.out_size, self.out_size)
169 | return deform_roi_pooling(
170 | data, rois, offset, self.spatial_scale, self.out_size,
171 | self.out_channels, self.no_trans, self.group_size,
172 | self.part_size, self.sample_per_part, self.trans_std) * mask
173 |
--------------------------------------------------------------------------------
/models/dcn/functions/deform_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.nn.modules.utils import _pair
4 |
5 | from .. import deform_conv_cuda
6 |
7 |
8 | class DeformConvFunction(Function):
9 |
10 | @staticmethod
11 | def forward(ctx,
12 | input,
13 | offset,
14 | weight,
15 | stride=1,
16 | padding=0,
17 | dilation=1,
18 | groups=1,
19 | deformable_groups=1,
20 | im2col_step=64):
21 | if input is not None and input.dim() != 4:
22 | raise ValueError(
23 | "Expected 4D tensor as input, got {}D tensor instead.".format(
24 | input.dim()))
25 | ctx.stride = _pair(stride)
26 | ctx.padding = _pair(padding)
27 | ctx.dilation = _pair(dilation)
28 | ctx.groups = groups
29 | ctx.deformable_groups = deformable_groups
30 | ctx.im2col_step = im2col_step
31 |
32 | ctx.save_for_backward(input, offset, weight)
33 |
34 | output = input.new_empty(
35 | DeformConvFunction._output_size(input, weight, ctx.padding,
36 | ctx.dilation, ctx.stride))
37 |
38 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
39 |
40 | if not input.is_cuda:
41 | raise NotImplementedError
42 | else:
43 | cur_im2col_step = min(ctx.im2col_step, input.shape[0])
44 | assert (input.shape[0] %
45 | cur_im2col_step) == 0, 'im2col step must divide batchsize'
46 | deform_conv_cuda.deform_conv_forward_cuda(
47 | input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
48 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
49 | ctx.padding[1], ctx.padding[0], ctx.dilation[1],
50 | ctx.dilation[0], ctx.groups, ctx.deformable_groups,
51 | cur_im2col_step)
52 | return output
53 |
54 | @staticmethod
55 | def backward(ctx, grad_output):
56 | input, offset, weight = ctx.saved_tensors
57 |
58 | grad_input = grad_offset = grad_weight = None
59 |
60 | if not grad_output.is_cuda:
61 | raise NotImplementedError
62 | else:
63 | cur_im2col_step = min(ctx.im2col_step, input.shape[0])
64 | assert (input.shape[0] %
65 | cur_im2col_step) == 0, 'im2col step must divide batchsize'
66 |
67 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
68 | grad_input = torch.zeros_like(input)
69 | grad_offset = torch.zeros_like(offset)
70 | deform_conv_cuda.deform_conv_backward_input_cuda(
71 | input, offset, grad_output, grad_input,
72 | grad_offset, weight, ctx.bufs_[0], weight.size(3),
73 | weight.size(2), ctx.stride[1], ctx.stride[0],
74 | ctx.padding[1], ctx.padding[0], ctx.dilation[1],
75 | ctx.dilation[0], ctx.groups, ctx.deformable_groups,
76 | cur_im2col_step)
77 |
78 | if ctx.needs_input_grad[2]:
79 | grad_weight = torch.zeros_like(weight)
80 | deform_conv_cuda.deform_conv_backward_parameters_cuda(
81 | input, offset, grad_output,
82 | grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
83 | weight.size(2), ctx.stride[1], ctx.stride[0],
84 | ctx.padding[1], ctx.padding[0], ctx.dilation[1],
85 | ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
86 | cur_im2col_step)
87 |
88 | return (grad_input, grad_offset, grad_weight, None, None, None, None,
89 | None)
90 |
91 | @staticmethod
92 | def _output_size(input, weight, padding, dilation, stride):
93 | channels = weight.size(0)
94 | output_size = (input.size(0), channels)
95 | for d in range(input.dim() - 2):
96 | in_size = input.size(d + 2)
97 | pad = padding[d]
98 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
99 | stride_ = stride[d]
100 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
101 | if not all(map(lambda s: s > 0, output_size)):
102 | raise ValueError(
103 | "convolution input is too small (output would be {})".format(
104 | 'x'.join(map(str, output_size))))
105 | return output_size
106 |
107 |
108 | class ModulatedDeformConvFunction(Function):
109 |
110 | @staticmethod
111 | def forward(ctx,
112 | input,
113 | offset,
114 | mask,
115 | weight,
116 | bias=None,
117 | stride=1,
118 | padding=0,
119 | dilation=1,
120 | groups=1,
121 | deformable_groups=1):
122 | ctx.stride = stride
123 | ctx.padding = padding
124 | ctx.dilation = dilation
125 | ctx.groups = groups
126 | ctx.deformable_groups = deformable_groups
127 | ctx.with_bias = bias is not None
128 | if not ctx.with_bias:
129 | bias = input.new_empty(1) # fake tensor
130 | if not input.is_cuda:
131 | raise NotImplementedError
132 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \
133 | or input.requires_grad:
134 | ctx.save_for_backward(input, offset, mask, weight, bias)
135 | output = input.new_empty(
136 | ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
137 | ctx._bufs = [input.new_empty(0), input.new_empty(0)]
138 | deform_conv_cuda.modulated_deform_conv_cuda_forward(
139 | input, weight, bias, ctx._bufs[0], offset, mask, output,
140 | ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
141 | ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
142 | ctx.groups, ctx.deformable_groups, ctx.with_bias)
143 | return output
144 |
145 | @staticmethod
146 | def backward(ctx, grad_output):
147 | if not grad_output.is_cuda:
148 | raise NotImplementedError
149 | input, offset, mask, weight, bias = ctx.saved_tensors
150 | grad_input = torch.zeros_like(input)
151 | grad_offset = torch.zeros_like(offset)
152 | grad_mask = torch.zeros_like(mask)
153 | grad_weight = torch.zeros_like(weight)
154 | grad_bias = torch.zeros_like(bias)
155 | deform_conv_cuda.modulated_deform_conv_cuda_backward(
156 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
157 | grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
158 | grad_output, weight.shape[2], weight.shape[3], ctx.stride,
159 | ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
160 | ctx.groups, ctx.deformable_groups, ctx.with_bias)
161 | if not ctx.with_bias:
162 | grad_bias = None
163 |
164 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
165 | None, None, None, None, None)
166 |
167 | @staticmethod
168 | def _infer_shape(ctx, input, weight):
169 | n = input.size(0)
170 | channels_out = weight.size(0)
171 | height, width = input.shape[2:4]
172 | kernel_h, kernel_w = weight.shape[2:4]
173 | height_out = (height + 2 * ctx.padding -
174 | (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
175 | width_out = (width + 2 * ctx.padding -
176 | (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
177 | return n, channels_out, height_out, width_out
178 |
179 |
180 | deform_conv = DeformConvFunction.apply
181 | modulated_deform_conv = ModulatedDeformConvFunction.apply
182 |
--------------------------------------------------------------------------------
/loss/dice_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import cv2
5 | from scipy import ndimage
6 |
7 |
8 | class DiceLoss(nn.Module):
9 | '''
10 | Loss function from https://arxiv.org/abs/1707.03237,
11 |     where an IoU-style computation is applied to heatmaps to measure the
12 |     overlap between two heatmaps.
13 | '''
14 |
15 | def __init__(self, eps=1e-6):
16 | super(DiceLoss, self).__init__()
17 | self.eps = eps
18 |
19 | def forward(self, pred: torch.Tensor, gt, mask, weights=None):
20 | '''
21 | pred: one or two heatmaps of shape (N, 1, H, W),
22 |             the losses of two heatmaps are added together.
23 | gt: (N, 1, H, W)
24 | mask: (N, H, W)
25 | '''
26 | assert pred.dim() == 4, pred.dim()
27 | return self._compute(pred, gt, mask, weights)
28 |
29 | def _compute(self, pred, gt, mask, weights):
30 | if pred.dim() == 4:
31 | pred = pred[:, 0, :, :]
32 | gt = gt[:, 0, :, :]
33 | assert pred.shape == gt.shape
34 | assert pred.shape == mask.shape
35 | if weights is not None:
36 | assert weights.shape == mask.shape
37 | mask = weights * mask
38 |
39 | intersection = (pred * gt * mask).sum()
40 | union = (pred * mask).sum() + (gt * mask).sum() + self.eps
41 | loss = 1 - 2.0 * intersection / union
42 | assert loss <= 1
43 | return loss
44 |
45 |
46 | def dice_loss(input, target, mask, eps=1e-6):
47 | input = input.contiguous().view(input.size()[0], -1)
48 | target = target.contiguous().view(target.size()[0], -1)
49 | mask = mask.contiguous().view(mask.size()[0], -1)
50 |
51 | input = input * mask
52 | target = target * mask
53 |
54 | a = torch.sum(input * target, 1)
55 | b = torch.sum(input * input, 1) + eps
56 | c = torch.sum(target * target, 1) + eps
57 | d = (2 * a) / (b + c)
58 | dice_loss = torch.mean(d)
59 | return 1 - dice_loss
60 |
61 |
62 | class LeakyDiceLoss(nn.Module):
63 | '''
64 | Variation from DiceLoss.
65 | The coverage and union are computed separately.
66 | '''
67 |
68 | def __init__(self, eps=1e-6, coverage_scale=5.0):
69 | super(LeakyDiceLoss, self).__init__()
70 | self.eps = eps
71 | self.coverage_scale = coverage_scale
72 |
73 | def forward(self, pred, gt, mask):
74 | if pred.dim() == 4:
75 | pred = pred[:, 0, :, :]
76 | gt = gt[:, 0, :, :]
77 | assert pred.shape == gt.shape
78 | assert pred.shape == mask.shape
79 |
80 | coverage = (pred * mask * gt).sum() / ((gt * mask).sum() + self.eps)
81 | assert coverage <= 1
82 | coverage = 1 - coverage
83 | excede = (pred * mask * gt).sum() / ((pred * mask).sum() + self.eps)
84 | assert excede <= 1
85 | excede = 1 - excede
86 | loss = coverage * self.coverage_scale + excede
87 | return loss, dict(coverage=coverage, excede=excede)
88 |
89 |
90 | class InstanceDiceLoss(DiceLoss):
91 | '''
92 | DiceLoss normalized on each instance.
93 | Input:
94 | pred: (N, 1, H, W)
95 | gt: (N, 1, H, W)
96 | mask: (N, H, W)
97 |     Note: This class assumes that input tensors are on GPU,
98 |         while CPU computation is required to find union areas.
99 | '''
100 | REDUCTION = ['mean', 'sum', 'none']
101 |
102 | def __init__(self, threshold=0.3, iou_thresh=0.2, reduction=None,
103 | max_regions=100, eps=1e-6):
104 | nn.Module.__init__(self)
105 | self.threshold = threshold
106 | self.iou_thresh = iou_thresh
107 | self.reduction = reduction
108 | if self.reduction is None:
109 | self.reduction = 'mean'
110 | assert self.reduction in self.REDUCTION
111 | self.max_regions = max_regions
112 | self.eps = eps
113 |
114 | def label(self, tensor_on_gpu, blur=None):
115 | '''
116 | Args:
117 | tensor_on_gpu: (N, 1, H, W)
118 | blur: Lambda. If exists, each instance will be blured using `blur`.
119 | '''
120 | tensor = tensor_on_gpu.cpu().detach().numpy()
121 |
122 | instance_maps = []
123 | instance_counts = []
124 | for batch_index in range(tensor_on_gpu.shape[0]):
125 | instance = tensor[batch_index]
126 | if blur is not None:
127 | instance = blur(instance)
128 |             label_map, instance_count = ndimage.label(instance[0])
129 |             instance_count = min(self.max_regions, instance_count)
130 |             instance_map = []
131 |             for index in range(1, instance_count + 1):  # ndimage labels run 1..instance_count inclusive
132 |                 instance = torch.from_numpy(label_map == index).to(tensor_on_gpu.device).type(torch.float32)
133 |                 instance_map.append(instance)
134 |             instance_maps.append(instance_map)
135 |             instance_counts.append(instance_count)  # was never appended, so the returned counts were empty
136 |         return instance_maps, instance_counts
137 |
138 | def iou(self, pred, gt):
139 | overlap = (pred * gt).sum()
140 | return max(overlap / pred.sum(), overlap / gt.sum())
141 |
142 | def replace_or_add(self, dest, value):
143 | if dest is None:
144 | return value
145 | if value is None:
146 | return dest
147 | return dest + value
148 |
149 | def forward(self, pred, gt, mask):
150 | # pred_label_maps: N, P, H, W, where P is the number of regions.
151 | torch.cuda.synchronize()
152 | pred_label_maps, _ = self.label(pred > self.threshold)
153 | gt_label_maps, _ = self.label(gt)
154 |
155 | losses = []
156 | for batch_index, gt_instance_maps in enumerate(gt_label_maps):
157 | pred_instance_maps = pred_label_maps[batch_index]
158 | if gt_instance_maps is None or pred_instance_maps is None:
159 | continue
160 |
161 | single_loss = None # loss on a single image in a batch
162 | mask_not_matched = set(range(len(pred_instance_maps)))
163 | for gt_instance_map in gt_instance_maps:
164 | instance_loss = None # loss on a specific gt region
165 | for instance_index, pred_instance_map in enumerate(pred_instance_maps):
166 | if self.iou(pred_instance_map, gt_instance_map) > self.iou_thresh:
167 | match_loss = self._compute(
168 | pred[batch_index][0], gt[batch_index][0],
169 | mask[batch_index] * (pred_instance_map + gt_instance_map > 0).type(torch.float32))
170 | instance_loss = self.replace_or_add(instance_loss, match_loss)
171 | if instance_index in mask_not_matched:
172 | mask_not_matched.remove(instance_index)
173 | if instance_loss is None:
174 | instance_loss = self._compute(
175 | pred[batch_index][0], gt[batch_index][0],
176 | mask[batch_index] * gt_instance_map)
177 | single_loss = self.replace_or_add(single_loss, instance_loss)
178 |
179 |         '''Whether to compute single loss on instances which contain no positive sample.
180 | if single_loss is None:
181 | single_loss = self._compute(
182 | pred[batch_index][0], gt[batch_index][0],
183 | mask[batch_index])
184 | '''
185 |
186 | for instance_index in mask_not_matched:
187 | single_loss = self.replace_or_add(
188 | single_loss,
189 | self._compute(
190 | pred[batch_index][0], gt[batch_index][0],
191 | mask[batch_index] * pred_instance_maps[instance_index]))
192 |
193 | if single_loss is not None:
194 | losses.append(single_loss)
195 |
196 | if self.reduction == 'none':
197 | loss = losses
198 | else:
199 | assert self.reduction in ['sum', 'mean']
200 | count = len(losses)
201 | loss = sum(losses)
202 | if self.reduction == 'mean':
203 | loss = loss / count
204 | return loss
205 |
--------------------------------------------------------------------------------
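
A minimal usage sketch for InstanceDiceLoss (the module path and dummy shapes are assumptions, not part of the repository). Per the class docstring the tensors should live on GPU, since forward() calls torch.cuda.synchronize() before labelling regions on the CPU:

    import torch
    from loss.dice_loss import InstanceDiceLoss  # assumed module path

    criterion = InstanceDiceLoss(threshold=0.3, iou_thresh=0.2, reduction='mean')
    pred = torch.rand(2, 1, 64, 64).cuda()                # probability map in [0, 1]
    gt = (torch.rand(2, 1, 64, 64) > 0.5).float().cuda()  # binary ground truth
    mask = torch.ones(2, 64, 64).cuda()                   # valid-pixel mask
    loss = criterion(pred, gt, mask)                      # scalar when reduction='mean'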
/utils/DB_postprocesss.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: DB_postprocesss.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import numpy as np
11 | import string
12 | import cv2
13 | from shapely.geometry import Polygon
14 | import pyclipper
15 |
16 |
17 | class DBPostProcess(object):
18 | """
19 | The post process for Differentiable Binarization (DB).
20 | """
21 |
22 | def __init__(self, params):
23 | self.thresh = params['thresh']
24 | self.box_thresh = params['box_thresh']
25 | self.max_candidates = params['max_candidates']
26 | self.is_poly = params['is_poly']
27 | self.unclip_ratio = params['unclip_ratio']
28 | self.min_size = params['min_size']
29 |
30 | def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
31 | '''
32 |         _bitmap: single map with shape (H, W),
33 | whose values are binarized as {0, 1}
34 | '''
35 |
36 |
37 | bitmap = _bitmap
38 |         # pred: probability map of the same (H, W) size, used to score boxes
39 | height, width = bitmap.shape
40 | boxes = []
41 | scores = []
42 |
43 | contours, _ = cv2.findContours(
44 | (bitmap*255).astype(np.uint8),
45 | cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
46 |
47 | for contour in contours[:self.max_candidates]:
48 | epsilon = 0.001 * cv2.arcLength(contour, True)
49 | approx = cv2.approxPolyDP(contour, epsilon, True)
50 | points = approx.reshape((-1, 2))
51 | if points.shape[0] < 4:
52 | continue
53 | # _, sside = self.get_mini_boxes(contour)
54 | # if sside < self.min_size:
55 | # continue
56 | score = self.box_score_fast(pred, points.reshape(-1, 2))
57 | if self.box_thresh > score:
58 | continue
59 |
60 | if points.shape[0] > 2:
61 | box = self.unclip(points, self.unclip_ratio)
62 | if len(box) > 1:
63 | continue
64 | else:
65 | continue
66 | box = box.reshape(-1, 2)
67 | _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
68 | if sside < self.min_size + 2:
69 | continue
70 |
71 | if not isinstance(dest_width, int):
72 | dest_width = dest_width.item()
73 | dest_height = dest_height.item()
74 |
75 | box[:, 0] = np.clip(
76 | np.round(box[:, 0] / width * dest_width), 0, dest_width)
77 | box[:, 1] = np.clip(
78 | np.round(box[:, 1] / height * dest_height), 0, dest_height)
79 | boxes.append(box.tolist())
80 | scores.append(score)
81 | return boxes, scores
82 |
83 | def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
84 | '''
85 |         _bitmap: single map with shape (H, W),
86 | whose values are binarized as {0, 1}
87 | '''
88 |
89 | bitmap = _bitmap
90 | height, width = bitmap.shape
91 |
92 | outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
93 |         if len(outs) == 3:  # OpenCV 3.x: (image, contours, hierarchy)
94 |             img, contours, _ = outs[0], outs[1], outs[2]
95 |         elif len(outs) == 2:  # OpenCV 4.x: (contours, hierarchy)
96 | contours, _ = outs[0], outs[1]
97 |
98 | num_contours = min(len(contours), self.max_candidates)
99 | boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
100 | scores = np.zeros((num_contours, ), dtype=np.float32)
101 |
102 | for index in range(num_contours):
103 | contour = contours[index]
104 | points, sside = self.get_mini_boxes(contour)
105 | if sside < self.min_size:
106 | continue
107 | points = np.array(points)
108 | score = self.box_score_fast(pred, points.reshape(-1, 2))
109 | if self.box_thresh > score:
110 | continue
111 |
112 | box = self.unclip(points,self.unclip_ratio).reshape(-1, 1, 2)
113 | box, sside = self.get_mini_boxes(box)
114 | if sside < self.min_size + 2:
115 | continue
116 | box = np.array(box)
117 | if not isinstance(dest_width, int):
118 | dest_width = dest_width.item()
119 | dest_height = dest_height.item()
120 |
121 | box[:, 0] = np.clip(
122 | np.round(box[:, 0] / width * dest_width), 0, dest_width)
123 | box[:, 1] = np.clip(
124 | np.round(box[:, 1] / height * dest_height), 0, dest_height)
125 | boxes[index, :, :] = box.astype(np.int16)
126 | scores[index] = score
127 | return boxes, scores
128 |
129 | def unclip(self, box, unclip_ratio=2):
130 | poly = Polygon(box)
131 | distance = poly.area * unclip_ratio / poly.length
132 | offset = pyclipper.PyclipperOffset()
133 | offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
134 | expanded = np.array(offset.Execute(distance))
135 | return expanded
136 |
137 | def get_mini_boxes(self, contour):
138 | bounding_box = cv2.minAreaRect(contour)
139 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
140 |
141 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3
142 | if points[1][1] > points[0][1]:
143 | index_1 = 0
144 | index_4 = 1
145 | else:
146 | index_1 = 1
147 | index_4 = 0
148 | if points[3][1] > points[2][1]:
149 | index_2 = 2
150 | index_3 = 3
151 | else:
152 | index_2 = 3
153 | index_3 = 2
154 |
155 | box = [
156 | points[index_1], points[index_2], points[index_3], points[index_4]
157 | ]
158 | return box, min(bounding_box[1])
159 |
160 | def box_score_fast(self, bitmap, _box):
161 | h, w = bitmap.shape[:2]
162 | box = _box.copy()
163 |         xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
164 |         xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
165 |         ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
166 |         ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
167 |
168 | mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
169 | box[:, 0] = box[:, 0] - xmin
170 | box[:, 1] = box[:, 1] - ymin
171 | cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
172 | return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
173 |
174 | def __call__(self, pred, ratio_list):
175 | pred = pred[:, 0, :, :]
176 | segmentation = pred > self.thresh
177 |
178 | boxes_batch = []
179 | score_batch = []
180 | for batch_index in range(pred.shape[0]):
181 | height, width = pred.shape[-2:]
182 | if(self.is_poly):
183 | tmp_boxes, tmp_scores = self.polygons_from_bitmap(
184 | pred[batch_index], segmentation[batch_index], width, height)
185 |
186 | boxes = []
187 | score = []
188 | for k in range(len(tmp_boxes)):
189 | if tmp_scores[k] > self.box_thresh:
190 | boxes.append(tmp_boxes[k])
191 | score.append(tmp_scores[k])
192 | if len(boxes) > 0:
193 | ratio_w, ratio_h = ratio_list[batch_index]
194 | for i in range(len(boxes)):
195 | boxes[i] = np.array(boxes[i])
196 | boxes[i][:, 0] = boxes[i][:, 0] * ratio_w
197 | boxes[i][:, 1] = boxes[i][:, 1] * ratio_h
198 |
199 | boxes_batch.append(boxes)
200 | score_batch.append(score)
201 | else:
202 | tmp_boxes, tmp_scores = self.boxes_from_bitmap(
203 | pred[batch_index], segmentation[batch_index], width, height)
204 |
205 | boxes = []
206 | score = []
207 | for k in range(len(tmp_boxes)):
208 | if tmp_scores[k] > self.box_thresh:
209 | boxes.append(tmp_boxes[k])
210 | score.append(tmp_scores[k])
211 | if len(boxes) > 0:
212 | boxes = np.array(boxes)
213 |
214 | ratio_w, ratio_h = ratio_list[batch_index]
215 | boxes[:, :, 0] = boxes[:, :, 0] * ratio_w
216 | boxes[:, :, 1] = boxes[:, :, 1] * ratio_h
217 |
218 | boxes_batch.append(boxes)
219 | score_batch.append(score)
220 | return boxes_batch,score_batch
--------------------------------------------------------------------------------
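
A minimal sketch of how DBPostProcess might be driven (the params values here are illustrative assumptions; config.yaml holds the real ones). pred is the (N, 1, H, W) probability map produced by the network, and ratio_list maps boxes back to the original image scale:

    import numpy as np
    from utils.DB_postprocesss import DBPostProcess

    params = {'thresh': 0.3, 'box_thresh': 0.5, 'max_candidates': 1000,
              'is_poly': False, 'unclip_ratio': 1.5, 'min_size': 3}
    post = DBPostProcess(params)

    pred = np.random.rand(1, 1, 640, 640).astype(np.float32)  # stand-in for model output
    ratio_list = [(1.0, 1.0)]                                 # (ratio_w, ratio_h) per image
    boxes_batch, score_batch = post(pred, ratio_list)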
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 _*-
2 | """
3 | @author:fxw
4 | @file: train.py
5 | @time: 2020/04/28
6 | """
7 | import sys
8 | sys.path.append('/home/aistudio/external-libraries')
9 | import os
10 | import torch
11 | import torch.nn as nn
12 | import yaml
13 | import argparse
14 | import numpy as np
15 | olderr = np.seterr(all='ignore')
16 | from models.DBNet import DBNet
17 | from torch.autograd import Variable
18 | from loss.loss import L1BalanceCELoss
19 | from dataloader.dataload import DataLoader
20 | from utils.Logger import Logger
21 | from utils.metrics import runningScore
22 | from utils.model_eval import val
23 | from utils.tools import *
24 | from utils.set_optimizer import *
25 |
26 | torch.backends.cudnn.benchmark = False
27 | torch.backends.cudnn.deterministic = True
28 |
29 | def updateBN(model,config):
30 | tag = 0
31 | for m in model.modules():
32 | if(tag>69):
33 | break
34 | if isinstance(m, nn.BatchNorm2d):
35 |             if m.weight.grad is not None:
36 |                 m.weight.grad.data.add_(config['train']['sr_lr']*torch.sign(m.weight.data))  # L1 sparsity penalty on BN scale factors (network slimming)
37 | tag+=1
38 |
39 | def set_seed(seed):
40 | import numpy as np
41 | import random
42 | import torch
43 | random.seed(seed)
44 | np.random.seed(seed)
45 | torch.manual_seed(seed)
46 | torch.cuda.manual_seed(seed)
47 | torch.cuda.manual_seed_all(seed)
48 |
49 | GLOBAL_WORKER_ID = None
50 | GLOBAL_SEED = 2020
51 |
52 | def worker_init_fn(worker_id):
53 | global GLOBAL_WORKER_ID
54 | GLOBAL_WORKER_ID = worker_id
55 | set_seed(GLOBAL_SEED + worker_id)
56 |
57 | def train_net(config):
58 | os.environ["CUDA_VISIBLE_DEVICES"] = config['train']['gpu_id']
59 | data_loader = DataLoader(config)
60 | train_loader = torch.utils.data.DataLoader(
61 | data_loader,
62 | batch_size=config['train']['batch_size'],
63 | shuffle=True,
64 | num_workers=config['train']['num_workers'],
65 | worker_init_fn = worker_init_fn,
66 | drop_last=True,
67 | pin_memory=False)
68 |
69 | start_epoch = 0
70 | running_metric_binary = runningScore(2)
71 |
72 | if not (os.path.exists(config['train']['checkpoints'])):
73 | os.mkdir(config['train']['checkpoints'])
74 | checkpoints = os.path.join(config['train']['checkpoints'],"DB_%s_bs_%d_ep_%d" % (config['train']['backbone'],
75 | config['train']['batch_size'], config['train']['n_epoch']))
76 | if not (os.path.exists(checkpoints)):
77 | os.mkdir(checkpoints)
78 |
79 |
80 | model = DBNet(config).cuda()
81 | criterion = L1BalanceCELoss()
82 | optimizer = torch.optim.SGD(model.parameters(), lr=config['train']['base_lr'], momentum=0.99, weight_decay=5e-4)
83 |
84 | if config['train']['restore']:
85 | print('Resuming from checkpoint.')
86 | assert os.path.isfile(config['train']['resume']), 'Error: no checkpoint directory found!'
87 | checkpoint = torch.load(config['train']['resume'])
88 | start_epoch = checkpoint['epoch']
89 | model.load_state_dict(checkpoint['state_dict'])
90 | optimizer.load_state_dict(checkpoint['optimizer'])
91 | log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['train']['backbone'], resume=True)
92 | else:
93 | print('Training from scratch.')
94 | log_write = Logger(os.path.join(checkpoints,'log.txt'), title=config['train']['backbone'])
95 |         log_write.set_names([' epoch', 'Total loss', ' Bce loss', 'Thresh loss', ' L1 loss', 'Binary Acc', 'Binary IoU', ' recall',' precision',' hmean'])
96 | max_hmean = -1
97 | for epoch in range(start_epoch,config['train']['n_epoch']):
98 | model.train()
99 |
100 | bce_loss_list = []
101 | thresh_loss_list = []
102 | l1_loss_list = []
103 | total_loss_list = []
104 |
105 | if(config['train']['decay_method']=='e_decay'):
106 | adjust_learning_rate_poly(config['train']['base_lr'], optimizer, epoch, max_epoch=config['train']['n_epoch'], factor=0.9)
107 | else:
108 | adjust_learning_rate(config, optimizer, epoch,config['train']['gama'])
109 |
110 | for batch_idx, (imgs, gts, gt_masks, thresh_maps, thresh_masks) in enumerate(train_loader):
111 | imgs = Variable(imgs.cuda())
112 | gts = Variable(gts.cuda())
113 | gt_masks = Variable(gt_masks.cuda())
114 | thresh_maps = Variable(thresh_maps.cuda())
115 | thresh_masks = Variable(thresh_masks.cuda())
116 | batch = {}
117 | batch['gt'] = gts
118 | batch['mask'] = gt_masks
119 | batch['thresh_map'] = thresh_maps
120 | batch['thresh_mask'] = thresh_masks
121 |
122 | pre = model(imgs)
123 | loss, metrics = criterion(pre, batch)
124 |
125 | optimizer.zero_grad()
126 | loss.backward()
127 | if(config['train']['use_sr']):
128 | updateBN(model,config)
129 | optimizer.step()
130 |
131 | score_binary = cal_binary_score(pre['binary'], gts, gt_masks.unsqueeze(1), running_metric_binary)
132 |
133 | bce_loss_list.append(metrics['bce_loss'].item())
134 | thresh_loss_list.append(metrics['thresh_loss'].item())
135 | l1_loss_list.append(metrics['l1_loss'].item())
136 | total_loss_list.append(loss.item())
137 | if batch_idx % config['train']['show_step'] == 0:
138 | if(config['train']['print_format']=='linux'):
139 |                 headers = ['epoch/epochs', 'batch/batches', 'TotalLoss', 'BceLoss', ' ThreshLoss', 'L1Loss', 'Binary Acc', 'Binary IoU', 'Lr Rate']
140 | show_item = [[str(epoch)+'/'+str(config['train']['n_epoch']),
141 | str(batch_idx + 1)+'/'+str(len(train_loader)),
142 | get_str(np.mean(total_loss_list)),
143 | get_str(np.mean(bce_loss_list)),
144 | get_str(np.mean(thresh_loss_list)),
145 | get_str(np.mean(l1_loss_list)),
146 | get_str(score_binary['Mean Acc']),
147 | get_str(score_binary['Mean IoU']),
148 | get_str(optimizer.param_groups[0]['lr'])
149 | ]]
150 | print_table(headers,show_item,type_str='train')
151 | else:
152 | output_log = '({epoch}/{epochs}/{batch}/{size}) | TotalLoss: {total_loss:.4f} | BceLoss: {bce_loss:.4f} | ThreshLoss: {thresh_loss: .4f} | L1Loss: {l1_loss: .4f} | Binary Acc: {bin_acc: .4f} | Binary IoU: {bin_iou: .4f} | Lr: {lr: .4f}'.format(
153 | epoch=epoch,
154 | epochs=config['train']['n_epoch'] ,
155 | batch=batch_idx + 1,
156 | size=len(train_loader),
157 | total_loss=np.mean(total_loss_list),
158 | bce_loss=np.mean(bce_loss_list),
159 | thresh_loss=np.mean(thresh_loss_list),
160 | l1_loss=np.mean(l1_loss_list),
161 | bin_acc=score_binary['Mean Acc'],
162 | bin_iou=score_binary['Mean IoU'],
163 | lr=optimizer.param_groups[0]['lr']
164 | )
165 | print(output_log)
166 | sys.stdout.flush()
167 |
168 | if( epoch > config['train']['start_val_epoch']):
169 | result_dict = val(model,config)
170 |         recall, precision, hmean = result_dict['recall'], result_dict['precision'], result_dict['hmean']
171 |         print('epoch:', epoch, 'recall:', recall, 'precision:', precision, 'hmean:', hmean)
172 | else:
173 |         recall = 0
174 | precision = 0
175 | hmean = 0
176 |         log_write.append([epoch, np.mean(total_loss_list), np.mean(bce_loss_list), np.mean(thresh_loss_list),
177 |                           np.mean(l1_loss_list), score_binary['Mean Acc'], score_binary['Mean IoU'],
178 |                           recall, precision, hmean])
179 | if(hmean > max_hmean and config['train']['start_val_epoch'] < config['train']['n_epoch']):
180 | max_hmean = hmean
181 | save_checkpoint({
182 | 'epoch': epoch + 1,
183 | 'state_dict': model.state_dict(),
184 | 'lr': config['train']['base_lr'],
185 | 'optimizer': optimizer.state_dict(),
186 | }, checkpoint=checkpoints,filename='best_model.pth.tar')
187 |
188 | save_checkpoint({
189 | 'epoch': epoch + 1,
190 | 'state_dict': model.state_dict(),
191 | 'lr': config['train']['base_lr'],
192 | 'optimizer': optimizer.state_dict(),
193 | }, checkpoint=checkpoints)
194 |
195 |
196 |
197 | if __name__ == '__main__':
198 | stream = open('config.yaml', 'r', encoding='utf-8')
199 | config = yaml.load(stream,Loader=yaml.FullLoader)
200 | train_net(config)
--------------------------------------------------------------------------------
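
The updateBN hook in train.py implements the network-slimming sparsity penalty: after loss.backward() it adds sr_lr * sign(gamma) to every BatchNorm scale gradient, which is the subgradient of an L1 term sr_lr * |gamma|, so that unimportant channels are driven toward zero before pruning. A standalone sketch of the same update (function name hypothetical):

    import torch
    import torch.nn as nn

    def add_bn_l1_subgradient(model, s):
        # equivalent to adding s * |gamma| to the loss for each BN scale factor
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d) and m.weight.grad is not None:
                m.weight.grad.data.add_(s * torch.sign(m.weight.data))

Call it between loss.backward() and optimizer.step(), exactly where train.py calls updateBN.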
/dataloader/random_thansform.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: random_thansform.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import cv2
11 | import numpy as np
12 | import imgaug.augmenters as aug_img
13 | import imgaug
14 | import random
15 | def solve_polys(polys):
16 | len_max = 0
17 | for poly in polys:
18 | if(len(poly)//2>len_max):
19 | len_max = len(poly)//2
20 | new_polys = []
21 | for poly in polys:
22 | new_poly = []
23 |         if(len(poly)//2 < len_max):
24 |             new_poly.extend(poly)
25 |             for i in range(len(poly)//2, len_max):
26 |                 new_poly.extend([poly[0], poly[1]])
27 |         else:
28 |             new_poly = poly
29 |         new_polys.append(new_poly)
30 |     return np.array(new_polys), len_max
31 | 
32 | 
33 | class RandomCropData():
34 |     def __init__(self, size=(640, 640), max_tries=10, min_crop_side_ratio=0.1):
35 |         self.size = size
36 |         self.max_tries = max_tries
37 |         self.min_crop_side_ratio = min_crop_side_ratio
38 | 
39 |     def process(self, img, polys, dont_care):
40 |         all_care_polys = []
41 |         for i in range(len(polys)):
42 |             if not dont_care[i]:
43 |                 all_care_polys.append(polys[i])
44 |         crop_x, crop_y, crop_w, crop_h = self.crop_area(
45 |             img, all_care_polys)
46 |         scale_w = self.size[0] / crop_w
47 |         scale_h = self.size[1] / crop_h
48 |         scale = min(scale_w, scale_h)
49 |         h = int(crop_h * scale)
50 |         w = int(crop_w * scale)
51 |         padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), img.dtype)
52 |         padimg[:h, :w] = cv2.resize(
53 |             img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
54 |         img = padimg
55 |         # keep only the polygons that still fall inside the cropped area
56 |         new_polys = []
57 |         new_dont_care = []
58 |         for i in range(len(polys)):
59 |             poly = polys[i]
60 |             poly = (np.array(poly) - (crop_x, crop_y)) * scale
61 |             if not self.is_poly_outside_rect(poly, 0, 0, w, h):
62 |                 new_polys.append(poly)
63 |                 new_dont_care.append(dont_care[i])
64 |         return img, new_polys, new_dont_care
65 | 
66 |     def is_poly_in_rect(self, poly, x, y, w, h):
67 |         poly = np.array(poly)
68 |         if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
69 | return False
70 | if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
71 | return False
72 | return True
73 |
74 | def is_poly_outside_rect(self, poly, x, y, w, h):
75 | poly = np.array(poly)
76 | if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
77 | return True
78 | if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
79 | return True
80 | return False
81 |
82 | def split_regions(self, axis):
83 | regions = []
84 | min_axis = 0
85 | for i in range(1, axis.shape[0]):
86 | if axis[i] != axis[i - 1] + 1:
87 | region = axis[min_axis:i]
88 | min_axis = i
89 | regions.append(region)
90 | return regions
91 |
92 | def random_select(self, axis, max_size):
93 | xx = np.random.choice(axis, size=2)
94 | xmin = np.min(xx)
95 | xmax = np.max(xx)
96 | xmin = np.clip(xmin, 0, max_size - 1)
97 | xmax = np.clip(xmax, 0, max_size - 1)
98 | return xmin, xmax
99 |
100 | def region_wise_random_select(self, regions, max_size):
101 | selected_index = list(np.random.choice(len(regions), 2))
102 | selected_values = []
103 | for index in selected_index:
104 | axis = regions[index]
105 | xx = int(np.random.choice(axis, size=1))
106 | selected_values.append(xx)
107 | xmin = min(selected_values)
108 | xmax = max(selected_values)
109 | return xmin, xmax
110 |
111 | def crop_area(self, img, polys):
112 | h, w, _ = img.shape
113 | h_array = np.zeros(h, dtype=np.int32)
114 | w_array = np.zeros(w, dtype=np.int32)
115 | for points in polys:
116 | points = np.round(points, decimals=0).astype(np.int32)
117 | minx = np.min(points[:, 0])
118 | maxx = np.max(points[:, 0])
119 | w_array[minx:maxx] = 1
120 | miny = np.min(points[:, 1])
121 | maxy = np.max(points[:, 1])
122 | h_array[miny:maxy] = 1
123 | # ensure the cropped area not across a text
124 | h_axis = np.where(h_array == 0)[0]
125 | w_axis = np.where(w_array == 0)[0]
126 |
127 | if len(h_axis) == 0 or len(w_axis) == 0:
128 | return 0, 0, w, h
129 |
130 | h_regions = self.split_regions(h_axis)
131 | w_regions = self.split_regions(w_axis)
132 |
133 | for i in range(self.max_tries):
134 | if len(w_regions) > 1:
135 | xmin, xmax = self.region_wise_random_select(w_regions, w)
136 | else:
137 | xmin, xmax = self.random_select(w_axis, w)
138 | if len(h_regions) > 1:
139 | ymin, ymax = self.region_wise_random_select(h_regions, h)
140 | else:
141 | ymin, ymax = self.random_select(h_axis, h)
142 |
143 | if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
144 | # area too small
145 | continue
146 | num_poly_in_rect = 0
147 | for poly in polys:
148 | if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
149 | num_poly_in_rect += 1
150 | break
151 |
152 | if num_poly_in_rect > 0:
153 | return xmin, ymin, xmax - xmin, ymax - ymin
154 |
155 | return 0, 0, w, h
156 |
157 | class Random_Augment():
158 | def __init__(self):
159 | super(Random_Augment,self).__init__()
160 | self.random_crop_data = RandomCropData()
161 |
162 | def augment_poly(self,aug, img_shape, poly):
163 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
164 | keypoints = aug.augment_keypoints([imgaug.KeypointsOnImage(keypoints, shape=img_shape[:2])])[0].keypoints
165 | poly = [(p.x, p.y) for p in keypoints]
166 | return np.array(poly)
167 |
168 | def random_rotate(self,img, polys, random_range):
169 | angle = np.random.randint(random_range[0], random_range[1])
170 | aug_bin = aug_img.Sequential([aug_img.Affine(rotate=angle)])
171 | img = aug_bin.augment_image(img)
172 | new_polys = []
173 | for poly in polys:
174 | poly = self.augment_poly(aug_bin, img.shape, poly)
175 | poly = np.maximum(poly, 0)
176 | new_polys.append(poly)
177 | return img, new_polys
178 |
179 | def random_scale(self,img, polys, min_size):
180 | polys,len_max = solve_polys(polys)
181 | h, w = img.shape[0:2]
182 | new_polys = []
183 | for poly in polys:
184 | poly = np.asarray(poly)
185 | poly = poly / ([w * 1.0, h * 1.0] * len_max)
186 | new_polys.append(poly)
187 | new_polys = np.array(new_polys)
188 | if max(h, w) > 1280:
189 | scale = 1280.0 / max(h, w)
190 | img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
191 | h, w = img.shape[0:2]
192 | random_scale = np.array([0.5, 1.0, 2.0, 3.0])
193 | scale = np.random.choice(random_scale)
194 | if min(h, w) * scale <= min_size:
195 | scale = (min_size + 10) * 1.0 / min(h, w)
196 | img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
197 | new_polys = np.reshape(new_polys * ([img.shape[1], img.shape[0]] * len_max),
198 | (new_polys.shape[0], polys.shape[1] // 2, 2))
199 | return img, new_polys
200 |
201 | def random_flip(self,img, polys):
202 | if (np.random.rand(1)[0] > 0.5):
203 | aug_bin = aug_img.Sequential([aug_img.Fliplr((1))])
204 | img = aug_bin.augment_image(img)
205 | new_polys = []
206 | for poly in polys:
207 | poly = self.augment_poly(aug_bin, img.shape, poly)
208 | poly = np.maximum(poly, 0)
209 | new_polys.append(poly)
210 | else:
211 | new_polys = polys
212 | return img, new_polys
213 |
214 | def random_crop_db(self,img, polys, dont_care):
215 |         img, new_polys, new_dont_care = self.random_crop_data.process(img, polys, dont_care)
216 |         return img, new_polys, new_dont_care
217 |
218 | def random_crop_pse(self,imgs, img_size=(640, 640)):
219 | h, w = imgs[0].shape[0:2]
220 | th, tw = img_size
221 | if w == tw and h == th:
222 | return imgs
223 |
224 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
225 | tl = np.min(np.where(imgs[1] > 0), axis=1) - img_size
226 | tl[tl < 0] = 0
227 | br = np.max(np.where(imgs[1] > 0), axis=1) - img_size
228 | br[br < 0] = 0
229 | br[0] = min(br[0], h - th)
230 | br[1] = min(br[1], w - tw)
231 |
232 | i = random.randint(tl[0], br[0])
233 | j = random.randint(tl[1], br[1])
234 | else:
235 | i = random.randint(0, h - th)
236 | j = random.randint(0, w - tw)
237 |
238 | # return i, j, th, tw
239 | for idx in range(len(imgs)):
240 | if len(imgs[idx].shape) == 3:
241 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
242 | else:
243 | imgs[idx] = imgs[idx][i:i + th, j:j + tw]
244 | return imgs
245 |
246 |
247 |
--------------------------------------------------------------------------------
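
A minimal sketch of chaining these augmentations (the shapes, the call order, and the flat x1,y1,...,x4,y4 polygon layout expected by solve_polys are assumptions based on the code above):

    import numpy as np
    from dataloader.random_thansform import Random_Augment

    augment = Random_Augment()
    img = (np.random.rand(800, 600, 3) * 255).astype(np.uint8)
    polys = [np.array([10, 10, 100, 10, 100, 50, 10, 50], dtype=np.float32)]
    dont_care = [False]

    img, polys = augment.random_scale(img, polys, min_size=640)
    img, polys = augment.random_rotate(img, polys, random_range=[-10, 10])
    img, polys = augment.random_flip(img, polys)
    img, polys, dont_care = augment.random_crop_db(img, polys, dont_care)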
/pruned/train_fintune.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 _*-
2 | """
3 | @author:fxw
4 | @file: train_fintune.py
5 | @time: 2020/04/28
6 | """
7 | import sys
8 |
9 | sys.path.append('/home/aistudio/external-libraries')
10 | sys.path.append('./')
11 | import os
12 | import torch
13 | import torch.nn as nn
14 | import yaml
15 | import argparse
16 | import numpy as np
17 |
18 | olderr = np.seterr(all='ignore')
19 | from models.DBNet import DBNet
20 | from torch.autograd import Variable
21 | from loss.loss import L1BalanceCELoss
22 | from dataloader.dataload import DataLoader
23 | from utils.Logger import Logger
24 | from utils.metrics import runningScore
25 | from utils.model_eval import val
26 | from utils.tools import *
27 | from utils.set_optimizer import *
28 | from pruned.get_pruned_model import load_prune_model
29 |
30 | torch.backends.cudnn.benchmark = False
31 | torch.backends.cudnn.deterministic = True
32 |
33 |
34 | def set_seed(seed):
35 | import numpy as np
36 | import random
37 | import torch
38 | random.seed(seed)
39 | np.random.seed(seed)
40 | torch.manual_seed(seed)
41 | torch.cuda.manual_seed(seed)
42 | torch.cuda.manual_seed_all(seed)
43 |
44 |
45 | GLOBAL_WORKER_ID = None
46 | GLOBAL_SEED = 2020
47 |
48 |
49 | def worker_init_fn(worker_id):
50 | global GLOBAL_WORKER_ID
51 | GLOBAL_WORKER_ID = worker_id
52 | set_seed(GLOBAL_SEED + worker_id)
53 |
54 |
55 | def train_net(config):
56 | os.environ["CUDA_VISIBLE_DEVICES"] = config['pruned']['gpu_id']
57 | data_loader = DataLoader(config)
58 | train_loader = torch.utils.data.DataLoader(
59 | data_loader,
60 | batch_size=config['train']['batch_size'],
61 | shuffle=True,
62 | num_workers=config['train']['num_workers'],
63 | worker_init_fn=worker_init_fn,
64 | drop_last=True,
65 | pin_memory=False)
66 |
67 | start_epoch = 0
68 | running_metric_binary = runningScore(2)
69 |
70 | if not (os.path.exists(config['train']['checkpoints'])):
71 | os.mkdir(config['train']['checkpoints'])
72 |     checkpoints = os.path.join(config['pruned']['save_checkpoints'],
73 |                                "DB_%s_bs_%d_ep_%d" % (config['train']['backbone'],
74 |                                                       config['train']['batch_size'],
75 |                                                       config['train']['n_epoch']))
76 | if not (os.path.exists(checkpoints)):
77 | os.mkdir(checkpoints)
78 |
79 | model = DBNet(config)
80 | criterion = L1BalanceCELoss()
81 | optimizer = torch.optim.SGD(model.parameters(), lr=config['pruned']['finetune_lr'], momentum=0.99,
82 | weight_decay=5e-4)
83 |
84 | if config['pruned']['restore']:
85 | print('Resuming from checkpoint.')
86 | assert os.path.isfile(config['pruned']['resume']), 'Error: no checkpoint directory found!'
87 | checkpoint = torch.load(config['pruned']['resume'])
88 | start_epoch = checkpoint['epoch']
89 | model = load_prune_model(model, config['pruned']['checkpoints_dict']).cuda()
90 | model.load_state_dict(checkpoint['state_dict'])
91 | optimizer.load_state_dict(checkpoint['optimizer'])
92 | log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['train']['backbone'], resume=True)
93 | else:
94 | print('Training from scratch.')
95 | model_dict = torch.load(config['pruned']['pruned_checkpoints'])
96 | model = load_prune_model(model, config['pruned']['checkpoints_dict']).cuda()
97 | print(model)
98 | try:
99 | model.load_state_dict(model_dict)
100 | except:
101 | state = model.state_dict()
102 | for key in state.keys():
103 | state[key] = model_dict['module.' + key]
104 | model.load_state_dict(state)
105 | log_write = Logger(os.path.join(checkpoints, 'log.txt'), title=config['train']['backbone'])
106 | log_write.set_names(
107 | [' epoch', 'Total loss', ' Bce loss', 'Thresh loss', ' L1 loss', 'Binary Acc', 'Binary IoU',
108 |          ' recall', ' precision', ' hmean'])
109 | max_hmean = -1
110 | for epoch in range(start_epoch, config['pruned']['n_epoch']):
111 | model.train()
112 |
113 | bce_loss_list = []
114 | thresh_loss_list = []
115 | l1_loss_list = []
116 | total_loss_list = []
117 |
118 | if (config['train']['decay_method'] == 'e_decay'):
119 | adjust_learning_rate_poly(config['pruned']['finetune_lr'], optimizer, epoch,
120 | max_epoch=config['pruned']['n_epoch'], factor=0.9)
121 | else:
122 | adjust_learning_rate(config, optimizer, epoch, config['train']['gama'])
123 |
124 | for batch_idx, (imgs, gts, gt_masks, thresh_maps, thresh_masks) in enumerate(train_loader):
125 | imgs = Variable(imgs.cuda())
126 | gts = Variable(gts.cuda())
127 | gt_masks = Variable(gt_masks.cuda())
128 | thresh_maps = Variable(thresh_maps.cuda())
129 | thresh_masks = Variable(thresh_masks.cuda())
130 | batch = {}
131 | batch['gt'] = gts
132 | batch['mask'] = gt_masks
133 | batch['thresh_map'] = thresh_maps
134 | batch['thresh_mask'] = thresh_masks
135 |
136 | pre = model(imgs)
137 | loss, metrics = criterion(pre, batch)
138 |
139 | optimizer.zero_grad()
140 | loss.backward()
141 | optimizer.step()
142 |
143 | score_binary = cal_binary_score(pre['binary'], gts, gt_masks.unsqueeze(1), running_metric_binary)
144 |
145 | bce_loss_list.append(metrics['bce_loss'].item())
146 | thresh_loss_list.append(metrics['thresh_loss'].item())
147 | l1_loss_list.append(metrics['l1_loss'].item())
148 | total_loss_list.append(loss.item())
149 | if batch_idx % config['train']['show_step'] == 0:
150 | if (config['train']['print_format'] == 'linux'):
151 |                     headers = ['epoch/epochs', 'batch/batches', 'TotalLoss', 'BceLoss', ' ThreshLoss', 'L1Loss',
152 |                                'Binary Acc', 'Binary IoU', 'Lr Rate']
153 | show_item = [[str(epoch) + '/' + str(config['pruned']['n_epoch']),
154 | str(batch_idx + 1) + '/' + str(len(train_loader)),
155 | get_str(np.mean(total_loss_list)),
156 | get_str(np.mean(bce_loss_list)),
157 | get_str(np.mean(thresh_loss_list)),
158 | get_str(np.mean(l1_loss_list)),
159 | get_str(score_binary['Mean Acc']),
160 | get_str(score_binary['Mean IoU']),
161 | get_str(optimizer.param_groups[0]['lr'])
162 | ]]
163 | print_table(headers, show_item, type_str='train')
164 | else:
165 | output_log = '({epoch}/{epochs}/{batch}/{size}) | TotalLoss: {total_loss:.4f} | BceLoss: {bce_loss:.4f} | ThreshLoss: {thresh_loss: .4f} | L1Loss: {l1_loss: .4f} | Binary Acc: {bin_acc: .4f} | Binary IoU: {bin_iou: .4f} | Lr: {lr: .4f}'.format(
166 | epoch=epoch,
167 | epochs=config['pruned']['n_epoch'],
168 | batch=batch_idx + 1,
169 | size=len(train_loader),
170 | total_loss=np.mean(total_loss_list),
171 | bce_loss=np.mean(bce_loss_list),
172 | thresh_loss=np.mean(thresh_loss_list),
173 | l1_loss=np.mean(l1_loss_list),
174 | bin_acc=score_binary['Mean Acc'],
175 | bin_iou=score_binary['Mean IoU'],
176 | lr=optimizer.param_groups[0]['lr']
177 | )
178 | print(output_log)
179 | sys.stdout.flush()
180 |
181 | if (epoch > config['pruned']['start_val_epoch']):
182 | result_dict = val(model, config)
183 |             recall, precision, hmean = result_dict['recall'], result_dict['precision'], result_dict['hmean']
184 |             print('epoch:', epoch, 'recall:', recall, 'precision:', precision, 'hmean:', hmean)
185 | else:
186 |             recall = 0
187 | precision = 0
188 | hmean = 0
189 | log_write.append([epoch, np.mean(total_loss_list), np.mean(bce_loss_list), np.mean(thresh_loss_list),
190 | np.mean(l1_loss_list), score_binary['Mean Acc'], score_binary['Mean IoU'],
191 |                           recall, precision, hmean])
192 | if (hmean > max_hmean and config['pruned']['start_val_epoch'] < config['pruned']['n_epoch']):
193 | max_hmean = hmean
194 | save_checkpoint({
195 | 'epoch': epoch + 1,
196 | 'state_dict': model.state_dict(),
197 | 'lr': optimizer.param_groups[0]['lr'],
198 | 'optimizer': optimizer.state_dict(),
199 | }, checkpoint=checkpoints, filename='best_model.pth.tar')
200 |
201 | save_checkpoint({
202 | 'epoch': epoch + 1,
203 | 'state_dict': model.state_dict(),
204 | 'lr': optimizer.param_groups[0]['lr'],
205 | 'optimizer': optimizer.state_dict(),
206 | }, checkpoint=checkpoints)
207 |
208 |
209 | if __name__ == '__main__':
210 | stream = open('./config.yaml', 'r', encoding='utf-8')
211 | config = yaml.load(stream, Loader=yaml.FullLoader)
212 | train_net(config)
--------------------------------------------------------------------------------
/pruned/prune.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 _*-
2 | """
3 | @author:fxw
4 | @file: prune.py
5 | @time: 2020/07/17
6 | """
7 | """
8 | #!-*- coding=utf-8 -*-
9 | @author: BADBADBADBADBOY
10 | @contact: 2441124901@qq.com
11 | @software: PyCharm Community Edition
12 | @file: prune.py
13 | @time: 2020/6/27 10:23
14 |
15 | """
16 | import sys
17 |
18 | sys.path.append('/home/aistudio/external-libraries')
19 | sys.path.append('./')
20 | import yaml
21 | from models.DBNet import DBNet
22 | import torch
23 | import torch.nn as nn
24 | import numpy as np
25 | import collections
26 | import torchvision.transforms as transforms
27 | import cv2
28 | import os
29 | import argparse
30 | import math
31 | from PIL import Image
32 | from torch.autograd import Variable
33 |
34 |
35 | def prune(config):
36 | os.environ["CUDA_VISIBLE_DEVICES"] = config['pruned']['gpu_id']
37 |
38 | model = DBNet(config).cuda()
39 | model_dict = torch.load(config['pruned']['checkpoints'])['state_dict']
40 | state = model.state_dict()
41 | for key in state.keys():
42 | if key in model_dict.keys():
43 | state[key] = model_dict[key]
44 | model.load_state_dict(state)
45 |
46 |
47 | bn_weights = []
48 | for m in model.modules():
49 | if (isinstance(m, nn.BatchNorm2d)):
50 | bn_weights.append(m.weight.data.abs().clone())
51 | bn_weights = torch.cat(bn_weights, 0)
52 |
53 | sort_result, sort_index = torch.sort(bn_weights)
54 |
55 | thresh_index = int(config['pruned']['cut_percent'] * bn_weights.shape[0])
56 |
57 | if (thresh_index == bn_weights.shape[0]):
58 | thresh_index = bn_weights.shape[0] - 1
59 |
60 | prued = 0
61 | prued_mask = []
62 | bn_index = []
63 | conv_index = []
64 | remain_channel_nums = []
65 | for k, m in enumerate(model.modules()):
66 | if (k > 69):
67 | break
68 | if (isinstance(m, nn.BatchNorm2d)):
69 | bn_weight = m.weight.data.clone()
70 | mask = bn_weight.abs().gt(sort_result[thresh_index])
71 | remain_channel = mask.sum()
72 |
73 | if (remain_channel == 0):
74 | remain_channel = 1
75 | mask[int(torch.argmax(bn_weight))] = 1
76 |
77 | v = 0
78 | n = 1
79 | if (remain_channel % config['pruned']['base_num'] != 0):
80 | if (remain_channel > config['pruned']['base_num']):
81 | while (v < remain_channel):
82 | n += 1
83 | v = config['pruned']['base_num'] * n
84 | if (remain_channel - (v - config['pruned']['base_num']) < v - remain_channel):
85 | remain_channel = v - config['pruned']['base_num']
86 | else:
87 | remain_channel = v
88 | if (remain_channel > bn_weight.size()[0]):
89 | remain_channel = bn_weight.size()[0]
90 | remain_channel = torch.tensor(remain_channel)
91 | result, index = torch.sort(bn_weight)
92 | mask = bn_weight.abs().ge(result[-remain_channel])
93 |
94 | remain_channel_nums.append(int(mask.sum()))
95 | prued_mask.append(mask)
96 | bn_index.append(k)
97 | prued += mask.shape[0] - mask.sum()
98 | elif (isinstance(m, nn.Conv2d)):
99 | conv_index.append(k)
100 | print('remain_channel_nums', remain_channel_nums)
101 | print('total_prune_ratio:', float(prued) / bn_weights.shape[0])
102 | print('bn_index', bn_index)
103 |
104 | new_model = DBNet(config).cuda()
105 |
106 | merge1_index = [3, 12, 18]
107 | merge2_index = [25, 28, 34]
108 | merge3_index = [41, 44, 50]
109 | merge4_index = [57, 60, 66]
110 |
111 | index_0 = []
112 | for item in merge1_index:
113 | index_0.append(bn_index.index(item))
114 | mask1 = prued_mask[index_0[0]] | prued_mask[index_0[1]] | prued_mask[index_0[2]]
115 |
116 | index_1 = []
117 | for item in merge2_index:
118 | index_1.append(bn_index.index(item))
119 | mask2 = prued_mask[index_1[0]] | prued_mask[index_1[1]] | prued_mask[index_1[2]]
120 |
121 | index_2 = []
122 | for item in merge3_index:
123 | index_2.append(bn_index.index(item))
124 | mask3 = prued_mask[index_2[0]] | prued_mask[index_2[1]] | prued_mask[index_2[2]]
125 |
126 | index_3 = []
127 | for item in merge4_index:
128 | index_3.append(bn_index.index(item))
129 | mask4 = prued_mask[index_3[0]] | prued_mask[index_3[1]] | prued_mask[index_3[2]]
130 |
131 | for index in index_0:
132 | prued_mask[index] = mask1
133 |
134 | for index in index_1:
135 | prued_mask[index] = mask2
136 |
137 | for index in index_2:
138 | prued_mask[index] = mask3
139 |
140 | for index in index_3:
141 | prued_mask[index] = mask4
142 |
143 | print(model)
144 |
145 | ##############################################################
146 | index_bn = 0
147 | index_conv = 0
148 |
149 | bn_mask = []
150 | conv_in_mask = []
151 | conv_out_mask = []
152 | tag = 0
153 | for m in new_model.modules():
154 | if (tag > 69):
155 | break
156 | if (isinstance(m, nn.BatchNorm2d)):
157 | m.num_features = prued_mask[index_bn].sum()
158 | bn_mask.append(prued_mask[index_bn])
159 | index_bn += 1
160 | elif (isinstance(m, nn.Conv2d)):
161 | if (index_conv == 0):
162 | m.in_channels = 3
163 | conv_in_mask.append(torch.ones(3))
164 | else:
165 | m.in_channels = prued_mask[index_conv - 1].sum()
166 | conv_in_mask.append(prued_mask[index_conv - 1])
167 | m.out_channels = prued_mask[index_conv].sum()
168 | conv_out_mask.append(prued_mask[index_conv])
169 | index_conv += 1
170 | tag += 1
171 |
172 |     conv_change_index = [27, 43, 59]  # 1x1 downsample convs at the stage boundaries
173 |     change_conv_bn_index = [18, 34, 50]  # last BN of the previous stage; its mask sets their in_channels
174 | tag = 0
175 | for m in new_model.modules():
176 | if (tag > 69):
177 | break
178 | if (isinstance(m, nn.Conv2d)):
179 | if (tag in conv_change_index):
180 | index = conv_change_index.index(tag)
181 | index = change_conv_bn_index[index]
182 | index = bn_index.index(index)
183 | mask = prued_mask[index]
184 | conv_in_mask[index + 3] = mask
185 | m.in_channels = mask.sum()
186 | tag += 1
187 |
188 | #############################################################
189 | bn_i = 0
190 | conv_i = 0
191 | scale_i = 0
192 | scale_mask = [mask4, mask3, mask2, mask1]
193 | # scale = [70,86,90,94] # FPN
194 | scale = config['pruned']['scale'] # DB
195 | for [m0, m1] in zip(model.modules(), new_model.modules()):
196 | if (scale_i > 69):
197 | if isinstance(m0, nn.Conv2d):
198 | if (scale_i in scale):
199 | index = scale.index(scale_i)
200 | m1.in_channels = scale_mask[index].sum()
201 | idx0 = np.squeeze(np.argwhere(np.asarray(scale_mask[index].cpu().numpy())))
202 | idx1 = np.squeeze(np.argwhere(np.asarray(torch.ones(256).cpu().numpy())))
203 | if idx0.size == 1:
204 | idx0 = np.resize(idx0, (1,))
205 | if idx1.size == 1:
206 | idx1 = np.resize(idx1, (1,))
207 | w = m0.weight.data[:, idx0, :, :].clone()
208 | m1.weight.data = w[idx1, :, :, :].clone()
209 | if m1.bias is not None:
210 | m1.bias.data = m0.bias.data[idx1].clone()
211 |
212 | else:
213 | m1.weight.data = m0.weight.data.clone()
214 | if m1.bias is not None:
215 | m1.bias.data = m0.bias.data.clone()
216 |
217 | elif isinstance(m0, nn.BatchNorm2d):
218 | m1.weight.data = m0.weight.data.clone()
219 | if m1.bias is not None:
220 | m1.bias.data = m0.bias.data.clone()
221 | m1.running_mean = m0.running_mean.clone()
222 | m1.running_var = m0.running_var.clone()
223 |
224 | else:
225 | if isinstance(m0, nn.BatchNorm2d):
226 | idx1 = np.squeeze(np.argwhere(np.asarray(bn_mask[bn_i].cpu().numpy())))
227 | if idx1.size == 1:
228 | idx1 = np.resize(idx1, (1,))
229 | m1.weight.data = m0.weight.data[idx1].clone()
230 | if m1.bias is not None:
231 | m1.bias.data = m0.bias.data[idx1].clone()
232 | m1.running_mean = m0.running_mean[idx1].clone()
233 | m1.running_var = m0.running_var[idx1].clone()
234 | bn_i += 1
235 | elif isinstance(m0, nn.Conv2d):
236 | if (isinstance(conv_in_mask[conv_i], list)):
237 | idx0 = np.squeeze(np.argwhere(np.asarray(torch.cat(conv_in_mask[conv_i], 0).cpu().numpy())))
238 | else:
239 | idx0 = np.squeeze(np.argwhere(np.asarray(conv_in_mask[conv_i].cpu().numpy())))
240 | idx1 = np.squeeze(np.argwhere(np.asarray(conv_out_mask[conv_i].cpu().numpy())))
241 | if idx0.size == 1:
242 | idx0 = np.resize(idx0, (1,))
243 | if idx1.size == 1:
244 | idx1 = np.resize(idx1, (1,))
245 | w = m0.weight.data[:, idx0, :, :].clone()
246 | m1.weight.data = w[idx1, :, :, :].clone()
247 | if m1.bias is not None:
248 | m1.bias.data = m0.bias.data[idx1].clone()
249 | conv_i += 1
250 |
251 | scale_i += 1
252 |
253 | print(new_model)
254 |
255 | save_obj = {'prued_mask': prued_mask, 'bn_index': bn_index}
256 | torch.save(save_obj, os.path.join(config['pruned']['save_checkpoints'], 'pruned_dict.dict'))
257 | torch.save(new_model.state_dict(), os.path.join(config['pruned']['save_checkpoints'], 'pruned_dict.pth.tar'))
258 |
259 |
260 | if __name__ == '__main__':
261 |
262 | stream = open('./config.yaml', 'r', encoding='utf-8')
263 | config = yaml.load(stream, Loader=yaml.FullLoader)
264 | prune(config)
265 |
--------------------------------------------------------------------------------
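
The heart of prune() is a global magnitude criterion from network slimming: concatenate |gamma| over every BatchNorm layer, sort, and cut the lowest cut_percent fraction. A compact sketch of just that threshold step (function name hypothetical):

    import torch
    import torch.nn as nn

    def global_bn_threshold(model, cut_percent):
        # gather |gamma| from all BN layers and pick the cut_percent quantile
        gammas = torch.cat([m.weight.data.abs().clone()
                            for m in model.modules()
                            if isinstance(m, nn.BatchNorm2d)])
        sorted_gammas, _ = torch.sort(gammas)
        idx = min(int(cut_percent * gammas.shape[0]), gammas.shape[0] - 1)
        return sorted_gammas[idx]

Channels whose gamma falls at or below this value are masked out, subject to the base_num rounding and the mask merging across residual connections shown above.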
/models/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | #!-*- coding=utf-8 -*-
3 | @author: BADBADBADBADBOY
4 | @contact: 2441124901@qq.com
5 | @software: PyCharm Community Edition
6 | @file: resnet.py
7 | @time: 2020/7/4 15:16
8 |
9 | """
10 | import torch.nn as nn
11 | import math
12 | import torch.utils.model_zoo as model_zoo
13 |
14 | BatchNorm2d = nn.BatchNorm2d
15 |
16 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
17 | 'resnet152']
18 |
19 | model_urls = {
20 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
21 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
22 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
23 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
24 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
25 | }
26 |
27 |
28 | def constant_init(module, constant, bias=0):
29 | nn.init.constant_(module.weight, constant)
30 |     if hasattr(module, 'bias') and module.bias is not None:
31 | nn.init.constant_(module.bias, bias)
32 |
33 |
34 | def conv3x3(in_planes, out_planes, stride=1):
35 | """3x3 convolution with padding"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
37 | padding=1, bias=False)
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | expansion = 1
42 |
43 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
44 | super(BasicBlock, self).__init__()
45 | self.with_dcn = dcn is not None
46 | self.conv1 = conv3x3(inplanes, planes, stride)
47 | self.bn1 = BatchNorm2d(planes)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.with_modulated_dcn = False
50 | if self.with_dcn:
51 | fallback_on_stride = dcn.get('fallback_on_stride', False)
52 | self.with_modulated_dcn = dcn.get('modulated', False)
53 | # self.conv2 = conv3x3(planes, planes)
54 | if not self.with_dcn or fallback_on_stride:
55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
56 | padding=1, bias=False)
57 | else:
58 | deformable_groups = dcn.get('deformable_groups', 1)
59 | if not self.with_modulated_dcn:
60 | from models.dcn import DeformConv
61 | conv_op = DeformConv
62 | offset_channels = 18
63 | else:
64 | from models.dcn import ModulatedDeformConv
65 | conv_op = ModulatedDeformConv
66 | offset_channels = 27
67 | self.conv2_offset = nn.Conv2d(
68 | planes,
69 | deformable_groups * offset_channels,
70 | kernel_size=3,
71 | padding=1)
72 | self.conv2 = conv_op(
73 | planes,
74 | planes,
75 | kernel_size=3,
76 | padding=1,
77 | deformable_groups=deformable_groups,
78 | bias=False)
79 | self.bn2 = BatchNorm2d(planes)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | residual = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | # out = self.conv2(out)
91 | if not self.with_dcn:
92 | out = self.conv2(out)
93 | elif self.with_modulated_dcn:
94 | offset_mask = self.conv2_offset(out)
95 | offset = offset_mask[:, :18, :, :]
96 | mask = offset_mask[:, -9:, :, :].sigmoid()
97 | out = self.conv2(out, offset, mask)
98 | else:
99 | offset = self.conv2_offset(out)
100 | out = self.conv2(out, offset)
101 | out = self.bn2(out)
102 |
103 | if self.downsample is not None:
104 | residual = self.downsample(x)
105 |
106 | out += residual
107 | out = self.relu(out)
108 |
109 | return out
110 |
111 |
112 | class Bottleneck(nn.Module):
113 | expansion = 4
114 |
115 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
116 | super(Bottleneck, self).__init__()
117 | self.with_dcn = dcn is not None
118 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
119 | self.bn1 = BatchNorm2d(planes)
120 | fallback_on_stride = False
121 | self.with_modulated_dcn = False
122 | if self.with_dcn:
123 | fallback_on_stride = dcn.get('fallback_on_stride', False)
124 | self.with_modulated_dcn = dcn.get('modulated', False)
125 | if not self.with_dcn or fallback_on_stride:
126 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
127 | stride=stride, padding=1, bias=False)
128 | else:
129 | deformable_groups = dcn.get('deformable_groups', 1)
130 | if not self.with_modulated_dcn:
131 | from models.dcn import DeformConv
132 | conv_op = DeformConv
133 | offset_channels = 18
134 | else:
135 | from models.dcn import ModulatedDeformConv
136 | conv_op = ModulatedDeformConv
137 | offset_channels = 27
138 | self.conv2_offset = nn.Conv2d(
139 | planes, deformable_groups * offset_channels,
140 | kernel_size=3,
141 | padding=1)
142 | self.conv2 = conv_op(
143 | planes, planes, kernel_size=3, padding=1, stride=stride,
144 | deformable_groups=deformable_groups, bias=False)
145 | self.bn2 = BatchNorm2d(planes)
146 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
147 | self.bn3 = BatchNorm2d(planes * 4)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.downsample = downsample
150 | self.stride = stride
151 | self.dcn = dcn
152 | self.with_dcn = dcn is not None
153 |
154 | def forward(self, x):
155 | residual = x
156 |
157 | out = self.conv1(x)
158 | out = self.bn1(out)
159 | out = self.relu(out)
160 |
161 | # out = self.conv2(out)
162 | if not self.with_dcn:
163 | out = self.conv2(out)
164 | elif self.with_modulated_dcn:
165 | offset_mask = self.conv2_offset(out)
166 | offset = offset_mask[:, :18, :, :]
167 | mask = offset_mask[:, -9:, :, :].sigmoid()
168 | out = self.conv2(out, offset, mask)
169 | else:
170 | offset = self.conv2_offset(out)
171 | out = self.conv2(out, offset)
172 | out = self.bn2(out)
173 | out = self.relu(out)
174 |
175 | out = self.conv3(out)
176 | out = self.bn3(out)
177 |
178 | if self.downsample is not None:
179 | residual = self.downsample(x)
180 |
181 | out += residual
182 | out = self.relu(out)
183 |
184 | return out
185 |
186 |
187 | class ResNet(nn.Module):
188 | def __init__(self, block, layers, num_classes=1000,
189 | dcn=None, stage_with_dcn=(False, False, False, False)):
190 | self.dcn = dcn
191 | self.stage_with_dcn = stage_with_dcn
192 | self.inplanes = 64
193 | super(ResNet, self).__init__()
194 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
195 | bias=False)
196 | self.bn1 = BatchNorm2d(64)
197 | self.relu = nn.ReLU(inplace=True)
198 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
199 | self.layer1 = self._make_layer(block, 64, layers[0])
200 | self.layer2 = self._make_layer(
201 | block, 128, layers[1], stride=2, dcn=dcn)
202 | self.layer3 = self._make_layer(
203 | block, 256, layers[2], stride=2, dcn=dcn)
204 | self.layer4 = self._make_layer(
205 | block, 512, layers[3], stride=2, dcn=dcn)
206 | #self.avgpool = nn.AvgPool2d(7, stride=1)
207 | #self.fc = nn.Linear(512 * block.expansion, num_classes)
208 |
209 | #self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)
210 |
211 | for m in self.modules():
212 | if isinstance(m, nn.Conv2d):
213 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
214 | m.weight.data.normal_(0, math.sqrt(2. / n))
215 | elif isinstance(m, BatchNorm2d):
216 | m.weight.data.fill_(1)
217 | m.bias.data.zero_()
218 | if self.dcn is not None:
219 | for m in self.modules():
220 | if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
221 | if hasattr(m, 'conv2_offset'):
222 | constant_init(m.conv2_offset, 0)
223 |
224 | def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
225 | downsample = None
226 | if stride != 1 or self.inplanes != planes * block.expansion:
227 | downsample = nn.Sequential(
228 | nn.Conv2d(self.inplanes, planes * block.expansion,
229 | kernel_size=1, stride=stride, bias=False),
230 | BatchNorm2d(planes * block.expansion),
231 | )
232 |
233 | layers = []
234 | layers.append(block(self.inplanes, planes,
235 | stride, downsample, dcn=dcn))
236 | self.inplanes = planes * block.expansion
237 | for i in range(1, blocks):
238 | layers.append(block(self.inplanes, planes, dcn=dcn))
239 |
240 | return nn.Sequential(*layers)
241 |
242 | def forward(self, x):
243 | x = self.conv1(x)
244 | x = self.bn1(x)
245 | x = self.relu(x)
246 | x = self.maxpool(x)
247 |
248 | x2 = self.layer1(x)
249 | x3 = self.layer2(x2)
250 | x4 = self.layer3(x3)
251 | x5 = self.layer4(x4)
252 |
253 | return x2, x3, x4, x5
254 |
255 |
256 | def resnet18(pretrained=True, **kwargs):
257 | """Constructs a ResNet-18 model.
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | """
261 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
262 | if pretrained:
263 | model.load_state_dict(model_zoo.load_url(
264 | model_urls['resnet18']), strict=False)
265 | return model
266 |
267 |
268 | def deformable_resnet18(pretrained=True, **kwargs):
269 |     """Constructs a ResNet-18 model with deformable conv.
270 | Args:
271 | pretrained (bool): If True, returns a model pre-trained on ImageNet
272 | """
273 | model = ResNet(BasicBlock, [2, 2, 2, 2],
274 | dcn=dict(modulated=True,
275 | deformable_groups=1,
276 | fallback_on_stride=False),
277 | stage_with_dcn=[False, True, True, True], **kwargs)
278 | if pretrained:
279 | model.load_state_dict(model_zoo.load_url(
280 | model_urls['resnet18']), strict=False)
281 | return model
282 |
283 |
284 | def resnet34(pretrained=True, **kwargs):
285 | """Constructs a ResNet-34 model.
286 | Args:
287 | pretrained (bool): If True, returns a model pre-trained on ImageNet
288 | """
289 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
290 | if pretrained:
291 | model.load_state_dict(model_zoo.load_url(
292 | model_urls['resnet34']), strict=False)
293 | return model
294 |
295 |
296 | def resnet50(pretrained=True, **kwargs):
297 | """Constructs a ResNet-50 model.
298 | Args:
299 | pretrained (bool): If True, returns a model pre-trained on ImageNet
300 | """
301 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
302 | if pretrained:
303 | model.load_state_dict(model_zoo.load_url(
304 | model_urls['resnet50']), strict=False)
305 | return model
306 |
307 |
308 | def deformable_resnet50(pretrained=True, **kwargs):
309 | """Constructs a ResNet-50 model with deformable conv.
310 | Args:
311 | pretrained (bool): If True, returns a model pre-trained on ImageNet
312 | """
313 | model = ResNet(Bottleneck, [3, 4, 6, 3],
314 | dcn=dict(modulated=True,
315 | deformable_groups=1,
316 | fallback_on_stride=False),
317 | stage_with_dcn=[False, True, True, True],
318 | **kwargs)
319 | if pretrained:
320 | model.load_state_dict(model_zoo.load_url(
321 | model_urls['resnet50']), strict=False)
322 | return model
323 |
324 |
325 | def resnet101(pretrained=True, **kwargs):
326 | """Constructs a ResNet-101 model.
327 | Args:
328 | pretrained (bool): If True, returns a model pre-trained on ImageNet
329 | """
330 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
331 | if pretrained:
332 | model.load_state_dict(model_zoo.load_url(
333 | model_urls['resnet101']), strict=False)
334 | return model
335 |
336 |
337 | def resnet152(pretrained=True, **kwargs):
338 | """Constructs a ResNet-152 model.
339 | Args:
340 | pretrained (bool): If True, returns a model pre-trained on ImageNet
341 | """
342 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
343 | if pretrained:
344 | model.load_state_dict(model_zoo.load_url(
345 | model_urls['resnet152']), strict=False)
346 | return model
347 |
348 |
--------------------------------------------------------------------------------
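
The backbone returns the four feature maps consumed by the FPN/DB head rather than classification logits. A quick shape check (the input size is an arbitrary assumption):

    import torch
    from models.backbone.resnet import resnet18

    model = resnet18(pretrained=False)
    x2, x3, x4, x5 = model(torch.randn(1, 3, 640, 640))
    # BasicBlock (expansion=1) at strides 4/8/16/32:
    # x2: (1, 64, 160, 160), x3: (1, 128, 80, 80),
    # x4: (1, 256, 40, 40),  x5: (1, 512, 20, 20)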
/cal_rescall/script.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from collections import namedtuple
3 | from . import rrc_evaluation_funcs
4 | import Polygon as plg
5 | import numpy as np
6 |
7 |
8 | def default_evaluation_params():
9 | """
10 | default_evaluation_params: Default parameters to use for the validation and evaluation.
11 | """
12 | return {
13 | 'IOU_CONSTRAINT': 0.5,
14 | 'AREA_PRECISION_CONSTRAINT': 0.5,
15 | 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
16 | 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
17 | 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
18 | 'CRLF': False, # Lines are delimited by Windows CRLF format
19 | 'CONFIDENCES': False, # Detections must include confidence value. AP will be calculated
20 | 'PER_SAMPLE_RESULTS': True # Generate per sample results and produce data for visualization
21 | }
22 |
23 |
24 | def validate_data(gtFilePath, submFilePath, evaluationParams):
25 | """
26 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents).
27 | Validates also that there are no missing files in the folder.
28 | If some error detected, the method raises the error
29 | """
30 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
31 |
32 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
33 |
34 | # Validate format of GroundTruth
35 | for k in gt:
36 | rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)
37 |
38 | # Validate format of results
39 | for k in subm:
40 |         if k not in gt:
41 |             raise Exception("The sample %s is not present in GT" % k)
42 |
43 | rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'],
44 | False, evaluationParams['CONFIDENCES'])
45 |
46 |
47 | def evaluate_method(gtFilePath, submFilePath, evaluationParams):
48 | """
49 | Method evaluate_method: evaluate method and returns the results
50 | Results. Dictionary with the following values:
51 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
52 | - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
53 | """
54 |
55 | def polygon_from_points(points):
56 | """
57 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
58 | """
59 | resBoxes = np.empty([1, 8], dtype='int32')
60 | resBoxes[0, 0] = int(points[0])
61 | resBoxes[0, 4] = int(points[1])
62 | resBoxes[0, 1] = int(points[2])
63 | resBoxes[0, 5] = int(points[3])
64 | resBoxes[0, 2] = int(points[4])
65 | resBoxes[0, 6] = int(points[5])
66 | resBoxes[0, 3] = int(points[6])
67 | resBoxes[0, 7] = int(points[7])
68 | pointMat = resBoxes[0].reshape([2, 4]).T
69 | return plg.Polygon(pointMat)
70 |
71 | def rectangle_to_polygon(rect):
72 | resBoxes = np.empty([1, 8], dtype='int32')
73 | resBoxes[0, 0] = int(rect.xmin)
74 | resBoxes[0, 4] = int(rect.ymax)
75 | resBoxes[0, 1] = int(rect.xmin)
76 | resBoxes[0, 5] = int(rect.ymin)
77 | resBoxes[0, 2] = int(rect.xmax)
78 | resBoxes[0, 6] = int(rect.ymin)
79 | resBoxes[0, 3] = int(rect.xmax)
80 | resBoxes[0, 7] = int(rect.ymax)
81 |
82 | pointMat = resBoxes[0].reshape([2, 4]).T
83 |
84 | return plg.Polygon(pointMat)
85 |
86 | def rectangle_to_points(rect):
87 | points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin),
88 | int(rect.xmin), int(rect.ymin)]
89 | return points
90 |
91 | def get_union(pD, pG):
92 |         areaA = pD.area()
93 |         areaB = pG.area()
94 |         return areaA + areaB - get_intersection(pD, pG)
95 |
96 | def get_intersection_over_union(pD, pG):
97 | try:
98 |             return get_intersection(pD, pG) / get_union(pD, pG)
99 | except:
100 | return 0
101 |
102 | def get_intersection(pD, pG):
103 | pInt = pD & pG
104 | if len(pInt) == 0:
105 | return 0
106 | return pInt.area()
107 |
108 | def compute_ap(confList, matchList, numGtCare):
109 | correct = 0
110 | AP = 0
111 | if len(confList) > 0:
112 | confList = np.array(confList)
113 | matchList = np.array(matchList)
114 | sorted_ind = np.argsort(-confList)
115 | confList = confList[sorted_ind]
116 | matchList = matchList[sorted_ind]
117 | for n in range(len(confList)):
118 | match = matchList[n]
119 | if match:
120 | correct += 1
121 | AP += float(correct) / (n + 1)
122 |
123 | if numGtCare > 0:
124 | AP /= numGtCare
125 |
126 | return AP
127 |
128 | perSampleMetrics = {}
129 |
130 | matchedSum = 0
131 |
132 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
133 |
134 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
135 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
136 |
137 |     numGlobalCareGt = 0
138 |     numGlobalCareDet = 0
139 | 
140 |     arrGlobalConfidences = []
141 |     arrGlobalMatches = []
142 |
143 | for resFile in gt:
144 |
145 | gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile])
146 | recall = 0
147 | precision = 0
148 | hmean = 0
149 |
150 | detMatched = 0
151 |
152 | iouMat = np.empty([1, 1])
153 |
154 | gtPols = []
155 | detPols = []
156 |
157 | gtPolPoints = []
158 | detPolPoints = []
159 |
160 | # Indices of ground-truth polygons marked as don't care
161 | gtDontCarePolsNum = []
162 | # Indices of detected polygons matched with a don't care GT
163 | detDontCarePolsNum = []
164 |
165 | pairs = []
166 | detMatchedNums = []
167 |
168 | arrSampleConfidences = []
169 | arrSampleMatch = []
170 | sampleAP = 0
171 |
172 | evaluationLog = ""
173 |
174 | pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
175 |     gtFile,
176 |     evaluationParams['CRLF'],
177 |     evaluationParams['LTRB'],
178 |     True,
179 |     False)
180 | for n in range(len(pointsList)):
181 | points = pointsList[n]
182 | transcription = transcriptionsList[n]
183 | dontCare = transcription == "###"
184 | if evaluationParams['LTRB']:
185 | gtRect = Rectangle(*points)
186 | gtPol = rectangle_to_polygon(gtRect)
187 | else:
188 | gtPol = polygon_from_points(points)
189 | gtPols.append(gtPol)
190 | gtPolPoints.append(points)
191 | if dontCare:
192 | gtDontCarePolsNum.append(len(gtPols) - 1)
193 |
194 | evaluationLog += "GT polygons: " + str(len(gtPols)) + (
195 | " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
196 |
197 | if resFile in subm:
198 |
199 | detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile])
200 |
201 | pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
202 |     detFile,
203 |     evaluationParams['CRLF'],
204 |     evaluationParams['LTRB'],
205 |     False,
206 |     evaluationParams['CONFIDENCES'],
207 | )
208 |
209 | for n in range(len(pointsList)):
210 | points = pointsList[n]
211 |
212 | if evaluationParams['LTRB']:
213 | detRect = Rectangle(*points)
214 | detPol = rectangle_to_polygon(detRect)
215 | else:
216 | detPol = polygon_from_points(points)
217 | detPols.append(detPol)
218 | detPolPoints.append(points)
219 | if len(gtDontCarePolsNum) > 0:
220 | for dontCarePol in gtDontCarePolsNum:
221 | dontCarePol = gtPols[dontCarePol]
222 | intersected_area = get_intersection(dontCarePol, detPol)
223 | pdDimensions = detPol.area()
224 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
225 | if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']):
226 | detDontCarePolsNum.append(len(detPols) - 1)
227 | break
228 |
229 | evaluationLog += "DET polygons: " + str(len(detPols)) + (
230 | " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
231 |
232 | if len(gtPols) > 0 and len(detPols) > 0:
233 | # Calculate the IoU matrix between GT and detected polygons
234 | outputShape = [len(gtPols), len(detPols)]
235 | iouMat = np.empty(outputShape)
236 | gtRectMat = np.zeros(len(gtPols), np.int8)
237 | detRectMat = np.zeros(len(detPols), np.int8)
238 | for gtNum in range(len(gtPols)):
239 | for detNum in range(len(detPols)):
240 | pG = gtPols[gtNum]
241 | pD = detPols[detNum]
242 | iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
243 |
244 | for gtNum in range(len(gtPols)):
245 | for detNum in range(len(detPols)):
246 | if (gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0
247 |         and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum):
248 | if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
249 | gtRectMat[gtNum] = 1
250 | detRectMat[detNum] = 1
251 | detMatched += 1
252 | pairs.append({'gt': gtNum, 'det': detNum})
253 | detMatchedNums.append(detNum)
254 | evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n"
255 |
256 | if evaluationParams['CONFIDENCES']:
257 | for detNum in range(len(detPols)):
258 | if detNum not in detDontCarePolsNum:
259 | # we exclude the don't care detections
260 | match = detNum in detMatchedNums
261 |
262 | arrSampleConfidences.append(confidencesList[detNum])
263 | arrSampleMatch.append(match)
264 |
265 | arrGlobalConfidences.append(confidencesList[detNum])
266 | arrGlobalMatches.append(match)
267 |
268 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
269 | numDetCare = (len(detPols) - len(detDontCarePolsNum))
270 | if numGtCare == 0:
271 | recall = 1.0
272 | precision = 0.0 if numDetCare > 0 else 1.0
273 | sampleAP = precision
274 | else:
275 | recall = float(detMatched) / numGtCare
276 | precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
277 | if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']:
278 | sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
279 |
280 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
281 |
282 | matchedSum += detMatched
283 | numGlobalCareGt += numGtCare
284 | numGlobalCareDet += numDetCare
285 |
286 | if evaluationParams['PER_SAMPLE_RESULTS']:
287 | perSampleMetrics[resFile] = {
288 | 'precision': precision,
289 | 'recall': recall,
290 | 'hmean': hmean,
291 | 'pairs': pairs,
292 | 'AP': sampleAP,
293 | 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
294 | 'gtPolPoints': gtPolPoints,
295 | 'detPolPoints': detPolPoints,
296 | 'gtDontCare': gtDontCarePolsNum,
297 | 'detDontCare': detDontCarePolsNum,
298 | 'evaluationParams': evaluationParams,
299 | 'evaluationLog': evaluationLog
300 | }
301 |
302 | # Compute the overall AP and the global method metrics
303 | AP = 0
304 | if evaluationParams['CONFIDENCES']:
305 | AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
306 |
307 | methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
308 | methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
309 | methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
310 | methodRecall + methodPrecision)
311 |
312 | methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}
313 |
314 | resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}
315 |
316 | return resDict
317 |
318 |
319 | def cal_recall_precison_f1(gt_path, result_path, show_result=False):
320 | p = {'g': gt_path, 's': result_path}
321 | result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method,
322 | show_result)
323 | return result['method']
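    | # Illustrative usage sketch (not part of the original file; the directory
    | # layout is an assumption -- gt_path and result_path each hold per-image
    | # ICDAR-style .txt files):
    | #   from cal_rescall.script import cal_recall_precison_f1
    | #   metrics = cal_recall_precison_f1('./gt', './result')
    | #   print(metrics['precision'], metrics['recall'], metrics['hmean'])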
--------------------------------------------------------------------------------
/models/dcn/src/deform_pool_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017 Microsoft
3 | * Licensed under The MIT License [see LICENSE for details]
4 | * \file deformable_psroi_pooling.cu
5 | * \brief
6 | * \author Yi Li, Guodong Zhang, Jifeng Dai
7 | */
8 | /***************** Adapted by Charles Shang *********************/
9 | // modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
10 |
11 | #include <ATen/ATen.h>
12 | #include <THC/THCAtomics.cuh>
13 | #include <stdio.h>
14 | #include <math.h>
15 | #include <algorithm>
16 |
17 | using namespace at;
18 |
19 | #define CUDA_KERNEL_LOOP(i, n) \
20 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
21 | i < (n); \
22 | i += blockDim.x * gridDim.x)
23 |
24 | const int CUDA_NUM_THREADS = 1024;
25 | inline int GET_BLOCKS(const int N)
26 | {
27 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
28 | }
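   | // e.g. N = 5000 gives GET_BLOCKS(N) = 5: 5 blocks x 1024 threads = 5120
   | // threads, and CUDA_KERNEL_LOOP lets threads with i >= N fall through idle.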
29 |
30 | template <typename scalar_t>
31 | __device__ scalar_t bilinear_interp(
32 | const scalar_t *data,
33 | const scalar_t x,
34 | const scalar_t y,
35 | const int width,
36 | const int height)
37 | {
38 | int x1 = floor(x);
39 | int x2 = ceil(x);
40 | int y1 = floor(y);
41 | int y2 = ceil(y);
42 | scalar_t dist_x = (scalar_t)(x - x1);
43 | scalar_t dist_y = (scalar_t)(y - y1);
44 | scalar_t value11 = data[y1 * width + x1];
45 | scalar_t value12 = data[y2 * width + x1];
46 | scalar_t value21 = data[y1 * width + x2];
47 | scalar_t value22 = data[y2 * width + x2];
48 | scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
49 | return value;
50 | }
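   | // Note: bilinear_interp does no bounds checking; both kernels below clamp
   | // (w, h) into [0, width-1] x [0, height-1] first, keeping all four
   | // neighbour reads in range.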
51 |
52 | template <typename scalar_t>
53 | __global__ void DeformablePSROIPoolForwardKernel(
54 | const int count,
55 | const scalar_t *bottom_data,
56 | const scalar_t spatial_scale,
57 | const int channels,
58 | const int height, const int width,
59 | const int pooled_height, const int pooled_width,
60 | const scalar_t *bottom_rois, const scalar_t *bottom_trans,
61 | const int no_trans,
62 | const scalar_t trans_std,
63 | const int sample_per_part,
64 | const int output_dim,
65 | const int group_size,
66 | const int part_size,
67 | const int num_classes,
68 | const int channels_each_class,
69 | scalar_t *top_data,
70 | scalar_t *top_count)
71 | {
72 | CUDA_KERNEL_LOOP(index, count)
73 | {
74 | // The output is in order (n, ctop, ph, pw)
75 | int pw = index % pooled_width;
76 | int ph = (index / pooled_width) % pooled_height;
77 | int ctop = (index / pooled_width / pooled_height) % output_dim;
78 | int n = index / pooled_width / pooled_height / output_dim;
79 |
80 | // [start, end) interval for spatial sampling
81 | const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
82 | int roi_batch_ind = offset_bottom_rois[0];
83 | scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
84 | scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
85 | scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
86 | scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
87 |
88 | // Force too small ROIs to be 1x1
89 | scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
90 | scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
91 |
92 | // Compute w and h at bottom
93 | scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
94 | scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
95 |
96 | scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
97 | scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
98 |
99 | int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
100 | int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
101 | int class_id = ctop / channels_each_class;
102 | scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
103 | scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
104 |
105 | scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
106 | wstart += trans_x * roi_width;
107 | scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
108 | hstart += trans_y * roi_height;
109 |
110 | scalar_t sum = 0;
111 | int count = 0;
112 | int gw = floor((scalar_t)(pw)*group_size / pooled_width);
113 | int gh = floor((scalar_t)(ph)*group_size / pooled_height);
114 | gw = min(max(gw, 0), group_size - 1);
115 | gh = min(max(gh, 0), group_size - 1);
116 |
117 | const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
118 | for (int ih = 0; ih < sample_per_part; ih++)
119 | {
120 | for (int iw = 0; iw < sample_per_part; iw++)
121 | {
122 | scalar_t w = wstart + iw * sub_bin_size_w;
123 | scalar_t h = hstart + ih * sub_bin_size_h;
124 | // bilinear interpolation
125 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
126 | {
127 | continue;
128 | }
129 | w = min(max(w, 0.), width - 1.);
130 | h = min(max(h, 0.), height - 1.);
131 | int c = (ctop * group_size + gh) * group_size + gw;
132 | scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
133 | sum += val;
134 | count++;
135 | }
136 | }
137 | top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
138 | top_count[index] = count;
139 | }
140 | }
141 |
142 | template <typename scalar_t>
143 | __global__ void DeformablePSROIPoolBackwardAccKernel(
144 | const int count,
145 | const scalar_t *top_diff,
146 | const scalar_t *top_count,
147 | const int num_rois,
148 | const scalar_t spatial_scale,
149 | const int channels,
150 | const int height, const int width,
151 | const int pooled_height, const int pooled_width,
152 | const int output_dim,
153 | scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
154 | const scalar_t *bottom_data,
155 | const scalar_t *bottom_rois,
156 | const scalar_t *bottom_trans,
157 | const int no_trans,
158 | const scalar_t trans_std,
159 | const int sample_per_part,
160 | const int group_size,
161 | const int part_size,
162 | const int num_classes,
163 | const int channels_each_class)
164 | {
165 | CUDA_KERNEL_LOOP(index, count)
166 | {
167 | // The output is in order (n, ctop, ph, pw)
168 | int pw = index % pooled_width;
169 | int ph = (index / pooled_width) % pooled_height;
170 | int ctop = (index / pooled_width / pooled_height) % output_dim;
171 | int n = index / pooled_width / pooled_height / output_dim;
172 |
173 | // [start, end) interval for spatial sampling
174 | const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
175 | int roi_batch_ind = offset_bottom_rois[0];
176 | scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
177 | scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
178 | scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
179 | scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
180 |
181 | // Force too small ROIs to be 1x1
182 | scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
183 | scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
184 |
185 | // Compute w and h at bottom
186 | scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
187 | scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
188 |
189 | scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
190 | scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
191 |
192 | int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
193 | int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
194 | int class_id = ctop / channels_each_class;
195 | scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
196 | scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
197 |
198 | scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
199 | wstart += trans_x * roi_width;
200 | scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
201 | hstart += trans_y * roi_height;
202 |
203 | if (top_count[index] <= 0)
204 | {
205 | continue;
206 | }
207 | scalar_t diff_val = top_diff[index] / top_count[index];
208 | const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
209 | scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
210 | int gw = floor((scalar_t)(pw)*group_size / pooled_width);
211 | int gh = floor((scalar_t)(ph)*group_size / pooled_height);
212 | gw = min(max(gw, 0), group_size - 1);
213 | gh = min(max(gh, 0), group_size - 1);
214 |
215 | for (int ih = 0; ih < sample_per_part; ih++)
216 | {
217 | for (int iw = 0; iw < sample_per_part; iw++)
218 | {
219 | scalar_t w = wstart + iw * sub_bin_size_w;
220 | scalar_t h = hstart + ih * sub_bin_size_h;
221 | // bilinear interpolation
222 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
223 | {
224 | continue;
225 | }
226 | w = min(max(w, 0.), width - 1.);
227 | h = min(max(h, 0.), height - 1.);
228 | int c = (ctop * group_size + gh) * group_size + gw;
229 | // backward on feature
230 | int x0 = floor(w);
231 | int x1 = ceil(w);
232 | int y0 = floor(h);
233 | int y1 = ceil(h);
234 | scalar_t dist_x = w - x0, dist_y = h - y0;
235 | scalar_t q00 = (1 - dist_x) * (1 - dist_y);
236 | scalar_t q01 = (1 - dist_x) * dist_y;
237 | scalar_t q10 = dist_x * (1 - dist_y);
238 | scalar_t q11 = dist_x * dist_y;
239 | int bottom_index_base = c * height * width;
240 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
241 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
242 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
243 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
244 |
245 | if (no_trans)
246 | {
247 | continue;
248 | }
249 | scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
250 | scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
251 | scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
252 | scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
253 | scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
254 | diff_x *= roi_width;
255 | scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
256 | diff_y *= roi_height;
257 |
258 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
259 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
260 | }
261 | }
262 | }
263 | }
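    | // atomicAdd is required above because bins from different ROIs (and
    | // different (ph, pw) positions) may scatter gradients into the same input
    | // pixel or the same trans offset cell concurrently.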
264 |
265 | void DeformablePSROIPoolForward(const at::Tensor data,
266 | const at::Tensor bbox,
267 | const at::Tensor trans,
268 | at::Tensor out,
269 | at::Tensor top_count,
270 | const int batch,
271 | const int channels,
272 | const int height,
273 | const int width,
274 | const int num_bbox,
275 | const int channels_trans,
276 | const int no_trans,
277 | const float spatial_scale,
278 | const int output_dim,
279 | const int group_size,
280 | const int pooled_size,
281 | const int part_size,
282 | const int sample_per_part,
283 | const float trans_std)
284 | {
285 | const int pooled_height = pooled_size;
286 | const int pooled_width = pooled_size;
287 | const int count = num_bbox * output_dim * pooled_height * pooled_width;
288 | const int num_classes = no_trans ? 1 : channels_trans / 2;
289 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
290 |
291 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
292 | data.type(), "deformable_psroi_pool_forward", ([&] {
293 | const scalar_t *bottom_data = data.data<scalar_t>();
294 | const scalar_t *bottom_rois = bbox.data<scalar_t>();
295 | const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
296 | scalar_t *top_data = out.data<scalar_t>();
297 | scalar_t *top_count_data = top_count.data<scalar_t>();
298 |
299 | DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
300 | count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
301 | bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
302 | group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
303 | }));
304 |
305 | cudaError_t err = cudaGetLastError();
306 | if (err != cudaSuccess)
307 | {
308 | printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
309 | }
310 | }
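    | // AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda above for
    | // float, double and half, binding scalar_t to the tensor's dtype at runtime.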
311 |
312 | void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
313 | const at::Tensor data,
314 | const at::Tensor bbox,
315 | const at::Tensor trans,
316 | const at::Tensor top_count,
317 | at::Tensor in_grad,
318 | at::Tensor trans_grad,
319 | const int batch,
320 | const int channels,
321 | const int height,
322 | const int width,
323 | const int num_bbox,
324 | const int channels_trans,
325 | const int no_trans,
326 | const float spatial_scale,
327 | const int output_dim,
328 | const int group_size,
329 | const int pooled_size,
330 | const int part_size,
331 | const int sample_per_part,
332 | const float trans_std)
333 | {
334 | // LOG(INFO) << "DeformablePSROIPoolBackward";
335 | const int num_rois = num_bbox;
336 | const int pooled_height = pooled_size;
337 | const int pooled_width = pooled_size;
338 | const int count = num_bbox * output_dim * pooled_height * pooled_width;
339 | const int num_classes = no_trans ? 1 : channels_trans / 2;
340 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
341 |
342 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
343 | out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] {
344 | const scalar_t *top_diff = out_grad.data<scalar_t>();
345 | const scalar_t *bottom_data = data.data<scalar_t>();
346 | const scalar_t *bottom_rois = bbox.data<scalar_t>();
347 | const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
348 | scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
349 | scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
350 | const scalar_t *top_count_data = top_count.data<scalar_t>();
351 |
352 | DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
353 | count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
354 | pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
355 | bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
356 | group_size, part_size, num_classes, channels_each_class);
357 | }));
358 |
359 | cudaError_t err = cudaGetLastError();
360 | if (err != cudaSuccess)
361 | {
362 | printf("error in DeformablePSROIPoolBackwardAcc: %s\n", cudaGetErrorString(err));
363 | }
364 | }
--------------------------------------------------------------------------------
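A minimal NumPy sketch of the bilinear sampling step shared by both kernels above (illustrative only: the function name and variables are mine, not part of this repo). It mirrors bilinear_interp: clamp the sample point, read the four neighbouring pixels of one flattened channel, and blend them by the fractional offsets.

import numpy as np

def bilinear_interp(data, x, y, width, height):
    """data: flat row-major array of length height * width (one channel)."""
    # clamp the way the CUDA kernels do before sampling
    x = min(max(x, 0.0), width - 1.0)
    y = min(max(y, 0.0), height - 1.0)
    x1, x2 = int(np.floor(x)), int(np.ceil(x))
    y1, y2 = int(np.floor(y)), int(np.ceil(y))
    dist_x, dist_y = x - x1, y - y1
    # four neighbours around (x, y)
    v11 = data[y1 * width + x1]
    v12 = data[y2 * width + x1]
    v21 = data[y1 * width + x2]
    v22 = data[y2 * width + x2]
    return ((1 - dist_x) * (1 - dist_y) * v11
            + (1 - dist_x) * dist_y * v12
            + dist_x * (1 - dist_y) * v21
            + dist_x * dist_y * v22)

# e.g. a 2x2 channel [0, 1, 2, 3] sampled at (0.5, 0.5) blends all four
# pixels equally and returns 1.5.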