├── .gitattributes
├── .gitignore
├── LICENSE.md
├── README.MD
├── base
│   ├── __init__.py
│   ├── base_dataset.py
│   └── base_trainer.py
├── config
│   ├── SynthText.yaml
│   ├── SynthText_resnet18_FPN_DBhead_polyLR.yaml
│   ├── icdar2015.yaml
│   ├── icdar2015_dcn_resnet18_FPN_DBhead_polyLR.yaml
│   ├── icdar2015_resnet18_FPN_DBhead_polyLR.yaml
│   ├── icdar2015_resnet18_FPN_DBhead_polyLR_finetune.yaml
│   ├── icdar2015_resnet50_FPN_DBhead_polyLR.yaml
│   ├── open_dataset.yaml
│   ├── open_dataset_dcn_resnet50_FPN_DBhead_polyLR.yaml
│   ├── open_dataset_resnest50_FPN_DBhead_polyLR.yaml
│   └── open_dataset_resnet18_FPN_DBhead_polyLR.yaml
├── data_loader
│   ├── __init__.py
│   ├── dataset.py
│   └── modules
│       ├── __init__.py
│       ├── augment.py
│       ├── iaa_augment.py
│       ├── make_border_map.py
│       ├── make_shrink_map.py
│       └── random_crop_data.py
├── datasets
│   ├── test.txt
│   ├── test
│   │   ├── gt
│   │   │   └── README.MD
│   │   └── img
│   │       └── README.MD
│   ├── train.txt
│   └── train
│       ├── gt
│       │   └── README.MD
│       └── img
│           └── README.MD
├── environment.yml
├── eval.sh
├── generate_lists.sh
├── imgs
│   └── paper
│       └── db.jpg
├── models
│   ├── __init__.py
│   ├── backbone
│   │   ├── MobilenetV3.py
│   │   ├── __init__.py
│   │   ├── resnest
│   │   │   ├── __init__.py
│   │   │   ├── ablation.py
│   │   │   ├── resnest.py
│   │   │   ├── resnet.py
│   │   │   └── splat.py
│   │   ├── resnet.py
│   │   └── shufflenetv2.py
│   ├── basic.py
│   ├── head
│   │   ├── ConvHead.py
│   │   ├── DBHead.py
│   │   └── __init__.py
│   ├── losses
│   │   ├── DB_loss.py
│   │   ├── __init__.py
│   │   └── basic_loss.py
│   ├── model.py
│   └── neck
│       ├── FPEM_FFM.py
│       ├── FPN.py
│       └── __init__.py
├── multi_gpu_train.sh
├── post_processing
│   ├── __init__.py
│   └── seg_detector_representer.py
├── predict.sh
├── requirement.txt
├── singlel_gpu_train.sh
├── test
│   └── README.MD
├── tools
│   ├── __init__.py
│   ├── eval.py
│   ├── predict.py
│   └── train.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
└── utils
    ├── __init__.py
    ├── cal_recall
    │   ├── __init__.py
    │   ├── rrc_evaluation_funcs.py
    │   └── script.py
    ├── compute_mean_std.py
    ├── make_trainfile.py
    ├── metrics.py
    ├── ocr_metric
    │   ├── __init__.py
    │   └── icdar2015
    │       ├── __init__.py
    │       ├── detection
    │       │   ├── __init__.py
    │       │   ├── deteval.py
    │       │   ├── icdar2013.py
    │       │   ├── iou.py
    │       │   └── mtwi2018.py
    │       └── quad_metric.py
    ├── schedulers.py
    └── util.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.html linguist-language=python
2 | *.ipynb linguist-language=python
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pth
3 | *.pyc
4 | *.pyo
5 | *.log
6 | *.tmp
7 | *.pkl
8 | __pycache__/
9 | .idea/
10 | output/
11 | test/*.jpg
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # Real-time Scene Text Detection with Differentiable Binarization
2 |
3 | **Note**: some code is adapted from [MhLiao/DB](https://github.com/MhLiao/DB)
4 |
5 | [Chinese write-up (Zhihu)](https://zhuanlan.zhihu.com/p/94677957)
6 |
7 | 
8 |
9 | ## update
10 | 2020-06-07: Added grayscale-image training. When training on grayscale images, remove `dataset.args.transforms.Normalize` from the config.
11 |
12 | ## Install Using Conda
13 | ```bash
14 | conda env create -f environment.yml
15 | git clone https://github.com/WenmuZhou/DBNet.pytorch.git
16 | cd DBNet.pytorch/
17 | ```
18 |
19 | or
20 | ## Install Manually
21 | ```bash
22 | conda create -n dbnet python=3.6
23 | conda activate dbnet
24 |
25 | conda install ipython pip
26 |
27 | # python dependencies
28 | pip install -r requirement.txt
29 |
30 | # install PyTorch with cuda-10.1
31 | # Note that you can change the cudatoolkit version to the version you want.
32 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
33 |
34 | # clone repo
35 | git clone https://github.com/WenmuZhou/DBNet.pytorch.git
36 | cd DBNet.pytorch/
37 |
38 | ```
39 |
40 | ## Requirements
41 | * pytorch 1.4+
42 | * torchvision 0.5+
43 | * gcc 4.9+
44 |
45 | ## Download
46 |
47 | TBD
48 |
49 | ## Data Preparation
50 |
51 | Training data: prepare a text file `train.txt` in the following format, using '\t' as the separator:
52 | ```
53 | ./datasets/train/img/001.jpg ./datasets/train/gt/001.txt
54 | ```
55 |
56 | Validation data: prepare a text file `test.txt` in the following format, using '\t' as the separator:
57 | ```
58 | ./datasets/test/img/001.jpg ./datasets/test/gt/001.txt
59 | ```
60 | - Store images in the `img` folder
61 | - Store groundtruth in the `gt` folder
62 |
63 | The ground truth is stored in `.txt` files with the following format:
64 | ```
65 | x1, y1, x2, y2, x3, y3, x4, y4, annotation
66 | ```
67 |
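See also `generate_lists.sh` and [utils/make_trainfile.py](utils/make_trainfile.py) in this repository for building these lists. A minimal sketch of the idea (assuming each image is paired with a label file of the same stem; ICDAR2015 names its labels `gt_<image>.txt`, so adjust the pairing rule to your dataset):

```python
import os

img_dir, gt_dir = './datasets/train/img', './datasets/train/gt'
with open('./datasets/train.txt', 'w', encoding='utf-8') as f:
    for name in sorted(os.listdir(img_dir)):
        stem = os.path.splitext(name)[0]
        gt_path = os.path.join(gt_dir, stem + '.txt')  # e.g. 'gt_' + stem + '.txt' for ICDAR2015
        if os.path.exists(gt_path):
            f.write('{}\t{}\n'.format(os.path.join(img_dir, name), gt_path))
```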
68 |
69 | ## Train
70 | 1. Configure `dataset['train']['dataset']['args']['data_path']` and `dataset['validate']['dataset']['args']['data_path']` in [config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml](config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml) (see the YAML snippet after the launch commands below)
71 | * Single-GPU training:
72 | ```bash
73 | bash singlel_gpu_train.sh
74 | ```
75 | * Multi-GPU training:
76 | ```bash
77 | bash multi_gpu_train.sh
78 | ```
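For reference, the part of the config to edit looks like this (excerpt from [config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml](config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml)):

```yaml
dataset:
  train:
    dataset:
      args:
        data_path:
          - ./datasets/train.txt
  validate:
    dataset:
      args:
        data_path:
          - ./datasets/test.txt
```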
79 | ## Test
80 |
81 | [eval.py](tools/eval.py) is used to evaluate a model on the test dataset.
82 |
83 | 1. Configure `model_path` in [eval.sh](eval.sh)
84 | 2. Run the following script to test:
85 | ```bash
86 | bash eval.sh
87 | ```
88 |
89 | ## Predict
90 | [predict.py](tools/predict.py) can be used to run inference on all images in a folder.
91 | 1. Configure `model_path`, `input_folder` and `output_folder` in [predict.sh](predict.sh)
92 | 2. Run the following script to predict:
93 | ```bash
94 | bash predict.sh
95 | ```
96 | You can change the `model_path` in the `predict.sh` file to your model location.
97 |
98 | Tip: if the results are not good, try adjusting `thre` in [predict.sh](predict.sh)
99 |
100 | The project is still under development.
101 |
102 |
103 |
104 | ### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4)
105 | Models trained only on the ICDAR2015 dataset.
106 |
107 | | Method | image size (short size) | learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS |
108 | |:--------------------------:|:-------:|:--------:|:--------:|:------------:|:---------------:|:-----:|
109 | | SynthText-Deform-ResNet-18 (paper) | 736 | 0.007 | 86.8 | 78.4 | 82.3 | 48 |
110 | | ImageNet-resnet18-FPN-DBHead | 736 | 1e-3 | 87.03 | 75.06 | 80.6 | 43 |
111 | | ImageNet-Deform-Resnet18-FPN-DBHead | 736 | 1e-3 | 88.61 | 73.84 | 80.56 | 36 |
112 | | ImageNet-resnet50-FPN-DBHead | 736 | 1e-3 | 88.06 | 77.14 | 82.24 | 27 |
113 | | ImageNet-resnest50-FPN-DBHead | 736 | 1e-3 | 88.18 | 76.27 | 81.78 | 27 |
114 |
115 |
116 | ### examples
117 | TBD
118 |
119 |
120 | ### todo
121 | - [x] multi-GPU training
122 |
123 | ### reference
124 | 1. https://arxiv.org/pdf/1911.08947.pdf
125 | 2. https://github.com/WenmuZhou/PANet.pytorch
126 | 3. https://github.com/MhLiao/DB
127 |
128 | **If this repository helps you, please star it. Thanks.**
129 |
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_trainer import BaseTrainer
2 | from .base_dataset import BaseDataSet
--------------------------------------------------------------------------------
/base/base_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 13:12
3 | # @Author : zhoujun
4 | import copy
5 | import cv2
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | from data_loader.modules import *
7 |
8 |
9 | class BaseDataSet(Dataset):
10 |
11 | def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None,
12 | target_transform=None):
13 | assert img_mode in ['RGB', 'BGR', 'GRAY']
14 | self.ignore_tags = ignore_tags
15 | self.data_list = self.load_data(data_path)
16 | item_keys = ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags']
17 | for item in item_keys:
18 | assert item in self.data_list[0], 'data_list from load_data must contain {}'.format(item)
19 | self.img_mode = img_mode
20 | self.filter_keys = filter_keys
21 | self.transform = transform
22 | self.target_transform = target_transform
23 | self._init_pre_processes(pre_processes)
24 |
25 | def _init_pre_processes(self, pre_processes):
26 | self.aug = []
27 | if pre_processes is not None:
28 | for aug in pre_processes:
29 | if 'args' not in aug:
30 | args = {}
31 | else:
32 | args = aug['args']
33 | if isinstance(args, dict):
34 | cls = eval(aug['type'])(**args)
35 | else:
36 | cls = eval(aug['type'])(args)
37 | self.aug.append(cls)
38 |
39 | def load_data(self, data_path: str) -> list:
40 | """
41 | Load the data into a list.
42 | :param data_path: file or folder that stores the data
43 | :return: a list of dicts, each containing 'img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags'
44 | """
45 | raise NotImplementedError
46 |
47 | def apply_pre_processes(self, data):
48 | for aug in self.aug:
49 | data = aug(data)
50 | return data
51 |
52 | def __getitem__(self, index):
53 | try:
54 | data = copy.deepcopy(self.data_list[index])
55 | im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
56 | if self.img_mode == 'RGB':
57 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
58 | data['img'] = im
59 | data['shape'] = [im.shape[0], im.shape[1]]
60 | data = self.apply_pre_processes(data)
61 |
62 | if self.transform:
63 | data['img'] = self.transform(data['img'])
64 | data['text_polys'] = data['text_polys'].tolist()
65 | if len(self.filter_keys):
66 | data_dict = {}
67 | for k, v in data.items():
68 | if k not in self.filter_keys:
69 | data_dict[k] = v
70 | return data_dict
71 | else:
72 | return data
73 | except Exception:
74 | return self.__getitem__(np.random.randint(self.__len__()))
75 |
76 | def __len__(self):
77 | return len(self.data_list)
78 |
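# Illustrative sketch (not part of the original file): the minimal contract a
# BaseDataSet subclass must satisfy -- load_data returns a list of dicts
# carrying the keys asserted in __init__. The class name and the
# 'img_path\tgt_path' list format below are assumptions for illustration.
class _ExampleListDataset(BaseDataSet):
    def load_data(self, data_path: str) -> list:
        data_list = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                img_path, gt_path = line.strip().split('\t')
                img_name = img_path.split('/')[-1].rsplit('.', 1)[0]
                data_list.append({
                    'img_path': img_path,
                    'img_name': img_name,
                    'text_polys': np.zeros((0, 4, 2)),  # a real parser would read gt_path here
                    'texts': [],
                    'ignore_tags': [],
                })
        return data_list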
--------------------------------------------------------------------------------
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:50
3 | # @Author : zhoujun
4 |
5 | import os
6 | import pathlib
7 | import shutil
8 | from pprint import pformat
9 |
10 | import anyconfig
11 | import torch
12 |
13 | from utils import setup_logger
14 |
15 |
16 | class BaseTrainer:
17 | def __init__(self, config, model, criterion):
18 | config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent),
19 | config['trainer']['output_dir'])
20 | config['name'] = config['name'] + '_' + model.name
21 | self.save_dir = os.path.join(config['trainer']['output_dir'], config['name'])
22 | self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
23 |
24 | if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '':
25 | shutil.rmtree(self.save_dir, ignore_errors=True)
26 | if not os.path.exists(self.checkpoint_dir):
27 | os.makedirs(self.checkpoint_dir)
28 |
29 | self.global_step = 0
30 | self.start_epoch = 0
31 | self.config = config
32 | self.model = model
33 | self.criterion = criterion
34 | # logger and tensorboard
35 | self.tensorboard_enable = self.config['trainer']['tensorboard']
36 | self.epochs = self.config['trainer']['epochs']
37 | self.log_iter = self.config['trainer']['log_iter']
38 | if config['local_rank'] == 0:
39 | anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
40 | self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
41 | self.logger_info(pformat(self.config))
42 |
43 | # device
44 | torch.manual_seed(self.config['trainer']['seed'])  # set the random seed for the CPU
45 | if torch.cuda.device_count() > 0 and torch.cuda.is_available():
46 | self.with_cuda = True
47 | torch.backends.cudnn.benchmark = True
48 | self.device = torch.device("cuda")
49 | torch.cuda.manual_seed(self.config['trainer']['seed'])  # set the random seed for the current GPU
50 | torch.cuda.manual_seed_all(self.config['trainer']['seed'])  # set the random seed for all GPUs
51 | else:
52 | self.with_cuda = False
53 | self.device = torch.device("cpu")
54 | self.logger_info('train with device {} and pytorch {}'.format(self.device, torch.__version__))
55 | # metrics
56 | self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'), 'best_model_epoch': 0}
57 |
58 | self.optimizer = self._initialize('optimizer', torch.optim, model.parameters())
59 |
60 | # resume or finetune
61 | if self.config['trainer']['resume_checkpoint'] != '':
62 | self._load_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True)
63 | elif self.config['trainer']['finetune_checkpoint'] != '':
64 | self._load_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False)
65 |
66 | if self.config['lr_scheduler']['type'] != 'WarmupPolyLR':
67 | self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer)
68 |
69 | self.model.to(self.device)
70 |
71 | if self.tensorboard_enable and config['local_rank'] == 0:
72 | from torch.utils.tensorboard import SummaryWriter
73 | self.writer = SummaryWriter(self.save_dir)
74 | try:
75 | # add graph
76 | in_channels = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1
77 | dummy_input = torch.zeros(1, in_channels, 640, 640).to(self.device)
78 | self.writer.add_graph(self.model, dummy_input)
79 | torch.cuda.empty_cache()
80 | except Exception:
81 | import traceback
82 | self.logger.error(traceback.format_exc())
83 | self.logger.warning('add graph to tensorboard failed')
84 | # distributed training
85 | if torch.cuda.device_count() > 1:
86 | local_rank = config['local_rank']
87 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False,
88 | find_unused_parameters=True)
89 | # record Normalize mean/std so images can be un-normalized for visualization
90 | self.UN_Normalize = False
91 | for t in self.config['dataset']['train']['dataset']['args']['transforms']:
92 | if t['type'] == 'Normalize':
93 | self.normalize_mean = t['args']['mean']
94 | self.normalize_std = t['args']['std']
95 | self.UN_Normalize = True
96 |
97 | def train(self):
98 | """
99 | Full training logic
100 | """
101 | for epoch in range(self.start_epoch + 1, self.epochs + 1):
102 | if self.config['distributed']:
103 | self.train_loader.sampler.set_epoch(epoch)
104 | self.epoch_result = self._train_epoch(epoch)
105 | if self.config['lr_scheduler']['type'] != 'WarmupPolyLR':
106 | self.scheduler.step()
107 | self._on_epoch_finish()
108 | if self.config['local_rank'] == 0 and self.tensorboard_enable:
109 | self.writer.close()
110 | self._on_train_finish()
111 |
112 | def _train_epoch(self, epoch):
113 | """
114 | Training logic for an epoch
115 |
116 | :param epoch: Current epoch number
117 | """
118 | raise NotImplementedError
119 |
120 | def _eval(self, epoch):
121 | """
122 | eval logic for an epoch
123 |
124 | :param epoch: Current epoch number
125 | """
126 | raise NotImplementedError
127 |
128 | def _on_epoch_finish(self):
129 | raise NotImplementedError
130 |
131 | def _on_train_finish(self):
132 | raise NotImplementedError
133 |
134 | def _save_checkpoint(self, epoch, file_name):
135 | """
136 | Save a checkpoint
137 |
138 | :param epoch: current epoch number
139 | :param file_name: name of the checkpoint file
141 | """
142 | state_dict = self.model.module.state_dict() if self.config['distributed'] else self.model.state_dict()
143 | state = {
144 | 'epoch': epoch,
145 | 'global_step': self.global_step,
146 | 'state_dict': state_dict,
147 | 'optimizer': self.optimizer.state_dict(),
148 | 'scheduler': self.scheduler.state_dict(),
149 | 'config': self.config,
150 | 'metrics': self.metrics
151 | }
152 | filename = os.path.join(self.checkpoint_dir, file_name)
153 | torch.save(state, filename)
154 |
155 | def _load_checkpoint(self, checkpoint_path, resume):
156 | """
157 | Resume from saved checkpoints
158 | :param checkpoint_path: Checkpoint path to be resumed
159 | """
160 | self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
161 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
162 | self.model.load_state_dict(checkpoint['state_dict'], strict=resume)
163 | if resume:
164 | self.global_step = checkpoint['global_step']
165 | self.start_epoch = checkpoint['epoch']
166 | self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
167 | # self.scheduler.load_state_dict(checkpoint['scheduler'])
168 | self.optimizer.load_state_dict(checkpoint['optimizer'])
169 | if 'metrics' in checkpoint:
170 | self.metrics = checkpoint['metrics']
171 | if self.with_cuda:
172 | for state in self.optimizer.state.values():
173 | for k, v in state.items():
174 | if isinstance(v, torch.Tensor):
175 | state[k] = v.to(self.device)
176 | self.logger_info("resume from checkpoint {} (epoch {})".format(checkpoint_path, self.start_epoch))
177 | else:
178 | self.logger_info("finetune from checkpoint {}".format(checkpoint_path))
179 |
180 | def _initialize(self, name, module, *args, **kwargs):
181 | module_name = self.config[name]['type']
182 | module_args = self.config[name]['args']
183 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
184 | module_args.update(kwargs)
185 | return getattr(module, module_name)(*args, **module_args)
186 |
187 | def inverse_normalize(self, batch_img):
188 | if self.UN_Normalize:
189 | batch_img[:, 0, :, :] = batch_img[:, 0, :, :] * self.normalize_std[0] + self.normalize_mean[0]
190 | batch_img[:, 1, :, :] = batch_img[:, 1, :, :] * self.normalize_std[1] + self.normalize_mean[1]
191 | batch_img[:, 2, :, :] = batch_img[:, 2, :, :] * self.normalize_std[2] + self.normalize_mean[2]
192 |
193 | def logger_info(self, s):
194 | if self.config['local_rank'] == 0:
195 | self.logger.info(s)
196 |
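# Illustrative note (not part of the original file): how _initialize resolves a
# config entry via reflection. With the optimizer entry used in this repo's configs,
#   optimizer: {type: Adam, args: {lr: 0.001, weight_decay: 0, amsgrad: true}}
# the call self._initialize('optimizer', torch.optim, model.parameters())
# is equivalent to:
#   torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0, amsgrad=True)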
--------------------------------------------------------------------------------
/config/SynthText.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | dataset:
3 | train:
4 | dataset:
5 | type: SynthTextDataset # dataset class
6 | args:
7 | data_path: '' # SynthTextDataset root directory
8 | pre_processes: # pre-processing pipeline: augmentation and label generation
9 | - type: IaaAugment # transforms via imgaug
10 | args:
11 | - {'type':Fliplr, 'args':{'p':0.5}}
12 | - {'type': Affine, 'args':{'rotate':[-10,10]}}
13 | - {'type':Resize,'args':{'size':[0.5,3]}}
14 | - type: EastRandomCropData
15 | args:
16 | size: [640,640]
17 | max_tries: 50
18 | keep_ratio: true
19 | - type: MakeBorderMap
20 | args:
21 | shrink_ratio: 0.4
22 | - type: MakeShrinkMap
23 | args:
24 | shrink_ratio: 0.4
25 | min_text_size: 8
26 | transforms: # image transforms
27 | - type: ToTensor
28 | args: {}
29 | - type: Normalize
30 | args:
31 | mean: [0.485, 0.456, 0.406]
32 | std: [0.229, 0.224, 0.225]
33 | img_mode: RGB
34 | filter_keys: ['img_path','img_name','text_polys','texts','ignore_tags','shape'] # keys removed from the data dict before it is returned
35 | ignore_tags: ['*', '###']
36 | loader:
37 | batch_size: 1
38 | shuffle: true
39 | pin_memory: false
40 | num_workers: 0
41 | collate_fn: ''
--------------------------------------------------------------------------------
/config/SynthText_resnet18_FPN_DBhead_polyLR.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/SynthText.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: resnet18
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: WarmupPolyLR
39 | args:
40 | warmup_epoch: 3
41 | trainer:
42 | seed: 2
43 | epochs: 1200
44 | log_iter: 10
45 | show_images_iter: 50
46 | resume_checkpoint: ''
47 | finetune_checkpoint: ''
48 | output_dir: output
49 | tensorboard: true
50 | dataset:
51 | train:
52 | dataset:
53 | args:
54 | data_path: ./datasets/SynthText
55 | img_mode: RGB
56 | loader:
57 | batch_size: 2
58 | shuffle: true
59 | pin_memory: true
60 | num_workers: 6
61 | collate_fn: ''
--------------------------------------------------------------------------------
/config/icdar2015.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | dataset:
3 | train:
4 | dataset:
5 | type: ICDAR2015Dataset # dataset class
6 | args:
7 | data_path: # a file of img_path \t gt_path pairs
8 | - ''
9 | pre_processes: # pre-processing pipeline: augmentation and label generation
10 | - type: IaaAugment # transforms via imgaug
11 | args:
12 | - {'type':Fliplr, 'args':{'p':0.5}}
13 | - {'type': Affine, 'args':{'rotate':[-10,10]}}
14 | - {'type':Resize,'args':{'size':[0.5,3]}}
15 | - type: EastRandomCropData
16 | args:
17 | size: [640,640]
18 | max_tries: 50
19 | keep_ratio: true
20 | - type: MakeBorderMap
21 | args:
22 | shrink_ratio: 0.4
23 | thresh_min: 0.3
24 | thresh_max: 0.7
25 | - type: MakeShrinkMap
26 | args:
27 | shrink_ratio: 0.4
28 | min_text_size: 8
29 | transforms: # image transforms
30 | - type: ToTensor
31 | args: {}
32 | - type: Normalize
33 | args:
34 | mean: [0.485, 0.456, 0.406]
35 | std: [0.229, 0.224, 0.225]
36 | img_mode: RGB
37 | filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # keys removed from the data dict before it is returned
38 | ignore_tags: ['*', '###']
39 | loader:
40 | batch_size: 1
41 | shuffle: true
42 | pin_memory: false
43 | num_workers: 0
44 | collate_fn: ''
45 | validate:
46 | dataset:
47 | type: ICDAR2015Dataset
48 | args:
49 | data_path:
50 | - ''
51 | pre_processes:
52 | - type: ResizeShortSize
53 | args:
54 | short_size: 736
55 | resize_text_polys: false
56 | transforms:
57 | - type: ToTensor
58 | args: {}
59 | - type: Normalize
60 | args:
61 | mean: [0.485, 0.456, 0.406]
62 | std: [0.229, 0.224, 0.225]
63 | img_mode: RGB
64 | filter_keys: []
65 | ignore_tags: ['*', '###']
66 | loader:
67 | batch_size: 1
68 | shuffle: true
69 | pin_memory: false
70 | num_workers: 0
71 | collate_fn: ICDARCollectFN
--------------------------------------------------------------------------------
/config/icdar2015_dcn_resnet18_FPN_DBhead_polyLR.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/icdar2015.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: deformable_resnet18
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: WarmupPolyLR
39 | args:
40 | warmup_epoch: 3
41 | trainer:
42 | seed: 2
43 | epochs: 1200
44 | log_iter: 10
45 | show_images_iter: 50
46 | resume_checkpoint: ''
47 | finetune_checkpoint: ''
48 | output_dir: output
49 | tensorboard: true
50 | dataset:
51 | train:
52 | dataset:
53 | args:
54 | data_path:
55 | - ./datasets/train.txt
56 | img_mode: RGB
57 | loader:
58 | batch_size: 1
59 | shuffle: true
60 | pin_memory: true
61 | num_workers: 6
62 | collate_fn: ''
63 | validate:
64 | dataset:
65 | args:
66 | data_path:
67 | - ./datasets/test.txt
68 | pre_processes:
69 | - type: ResizeShortSize
70 | args:
71 | short_size: 736
72 | resize_text_polys: false
73 | img_mode: RGB
74 | loader:
75 | batch_size: 1
76 | shuffle: true
77 | pin_memory: false
78 | num_workers: 6
79 | collate_fn: ICDARCollectFN
--------------------------------------------------------------------------------
/config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/icdar2015.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: resnet18
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: WarmupPolyLR
39 | args:
40 | warmup_epoch: 3
41 | trainer:
42 | seed: 2
43 | epochs: 1200
44 | log_iter: 10
45 | show_images_iter: 50
46 | resume_checkpoint: ''
47 | finetune_checkpoint: ''
48 | output_dir: output
49 | tensorboard: true
50 | dataset:
51 | train:
52 | dataset:
53 | args:
54 | data_path:
55 | - ./datasets/train.txt
56 | img_mode: RGB
57 | loader:
58 | batch_size: 1
59 | shuffle: true
60 | pin_memory: true
61 | num_workers: 6
62 | collate_fn: ''
63 | validate:
64 | dataset:
65 | args:
66 | data_path:
67 | - ./datasets/test.txt
68 | pre_processes:
69 | - type: ResizeShortSize
70 | args:
71 | short_size: 736
72 | resize_text_polys: false
73 | img_mode: RGB
74 | loader:
75 | batch_size: 1
76 | shuffle: true
77 | pin_memory: false
78 | num_workers: 6
79 | collate_fn: ICDARCollectFN
80 |
--------------------------------------------------------------------------------
/config/icdar2015_resnet18_FPN_DBhead_polyLR_finetune.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/icdar2015.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: resnet18
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: StepLR
39 | args:
40 | step_size: 10
41 | gamma: 0.8
42 | trainer:
43 | seed: 2
44 | epochs: 500
45 | log_iter: 10
46 | show_images_iter: 50
47 | resume_checkpoint: ''
48 | finetune_checkpoint: ''
49 | output_dir: output
50 | tensorboard: true
51 | dataset:
52 | train:
53 | dataset:
54 | args:
55 | data_path:
56 | - ./datasets/train.txt
57 | img_mode: RGB
58 | loader:
59 | batch_size: 1
60 | shuffle: true
61 | pin_memory: true
62 | num_workers: 6
63 | collate_fn: ''
64 | validate:
65 | dataset:
66 | args:
67 | data_path:
68 | - ./datasets/test.txt
69 | pre_processes:
70 | - type: ResizeShortSize
71 | args:
72 | short_size: 736
73 | resize_text_polys: false
74 | img_mode: RGB
75 | loader:
76 | batch_size: 1
77 | shuffle: true
78 | pin_memory: false
79 | num_workers: 6
80 | collate_fn: ICDARCollectFN
81 |
--------------------------------------------------------------------------------
/config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/icdar2015.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: resnet50
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: WarmupPolyLR
39 | args:
40 | warmup_epoch: 3
41 | trainer:
42 | seed: 2
43 | epochs: 1200
44 | log_iter: 10
45 | show_images_iter: 50
46 | resume_checkpoint: ''
47 | finetune_checkpoint: ''
48 | output_dir: output
49 | tensorboard: true
50 | dataset:
51 | train:
52 | dataset:
53 | args:
54 | data_path:
55 | - ./datasets/train.txt
56 | img_mode: RGB
57 | loader:
58 | batch_size: 16
59 | shuffle: true
60 | pin_memory: true
61 | num_workers: 6
62 | collate_fn: ''
63 | validate:
64 | dataset:
65 | args:
66 | data_path:
67 | - ./datasets/test.txt
68 | pre_processes:
69 | - type: ResizeShortSize
70 | args:
71 | short_size: 736
72 | resize_text_polys: false
73 | img_mode: RGB
74 | loader:
75 | batch_size: 1
76 | shuffle: true
77 | pin_memory: false
78 | num_workers: 6
79 | collate_fn: ICDARCollectFN
80 |
--------------------------------------------------------------------------------
/config/open_dataset.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | dataset:
3 | train:
4 | dataset:
5 | type: DetDataset # dataset class
6 | args:
7 | data_path: # a file of img_path \t gt_path pairs
8 | - ''
9 | pre_processes: # pre-processing pipeline: augmentation and label generation
10 | - type: IaaAugment # transforms via imgaug
11 | args:
12 | - {'type':Fliplr, 'args':{'p':0.5}}
13 | - {'type': Affine, 'args':{'rotate':[-10,10]}}
14 | - {'type':Resize,'args':{'size':[0.5,3]}}
15 | - type: EastRandomCropData
16 | args:
17 | size: [640,640]
18 | max_tries: 50
19 | keep_ratio: true
20 | - type: MakeBorderMap
21 | args:
22 | shrink_ratio: 0.4
23 | thresh_min: 0.3
24 | thresh_max: 0.7
25 | - type: MakeShrinkMap
26 | args:
27 | shrink_ratio: 0.4
28 | min_text_size: 8
29 | transforms: # image transforms
30 | - type: ToTensor
31 | args: {}
32 | - type: Normalize
33 | args:
34 | mean: [0.485, 0.456, 0.406]
35 | std: [0.229, 0.224, 0.225]
36 | img_mode: RGB
37 | load_char_annotation: false
38 | expand_one_char: false
39 | filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # keys removed from the data dict before it is returned
40 | ignore_tags: ['*', '###']
41 | loader:
42 | batch_size: 1
43 | shuffle: true
44 | pin_memory: false
45 | num_workers: 0
46 | collate_fn: ''
47 | validate:
48 | dataset:
49 | type: DetDataset
50 | args:
51 | data_path:
52 | - ''
53 | pre_processes:
54 | - type: ResizeShortSize
55 | args:
56 | short_size: 736
57 | resize_text_polys: false
58 | transforms:
59 | - type: ToTensor
60 | args: {}
61 | - type: Normalize
62 | args:
63 | mean: [0.485, 0.456, 0.406]
64 | std: [0.229, 0.224, 0.225]
65 | img_mode: RGB
66 | load_char_annotation: false # whether to load character-level annotations
67 | expand_one_char: false # whether to widen single-character boxes: w = w + h after expansion
68 | filter_keys: []
69 | ignore_tags: ['*', '###']
70 | loader:
71 | batch_size: 1
72 | shuffle: true
73 | pin_memory: false
74 | num_workers: 0
75 | collate_fn: ICDARCollectFN
--------------------------------------------------------------------------------
/config/open_dataset_dcn_resnet50_FPN_DBhead_polyLR.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/open_dataset.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: deformable_resnet18
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: WarmupPolyLR
39 | args:
40 | warmup_epoch: 3
41 | trainer:
42 | seed: 2
43 | epochs: 1200
44 | log_iter: 1
45 | show_images_iter: 1
46 | resume_checkpoint: ''
47 | finetune_checkpoint: ''
48 | output_dir: output
49 | tensorboard: true
50 | dataset:
51 | train:
52 | dataset:
53 | args:
54 | data_path:
55 | - ./datasets/train.json
56 | img_mode: RGB
57 | load_char_annotation: false
58 | expand_one_char: false
59 | loader:
60 | batch_size: 2
61 | shuffle: true
62 | pin_memory: true
63 | num_workers: 6
64 | collate_fn: ''
65 | validate:
66 | dataset:
67 | args:
68 | data_path:
69 | - ./datasets/test.json
70 | pre_processes:
71 | - type: ResizeShortSize
72 | args:
73 | short_size: 736
74 | resize_text_polys: false
75 | img_mode: RGB
76 | load_char_annotation: false
77 | expand_one_char: false
78 | loader:
79 | batch_size: 1
80 | shuffle: true
81 | pin_memory: false
82 | num_workers: 6
83 | collate_fn: ICDARCollectFN
84 |
--------------------------------------------------------------------------------
/config/open_dataset_resnest50_FPN_DBhead_polyLR.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/open_dataset.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: resnest50
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: WarmupPolyLR
39 | args:
40 | warmup_epoch: 3
41 | trainer:
42 | seed: 2
43 | epochs: 1200
44 | log_iter: 1
45 | show_images_iter: 1
46 | resume_checkpoint: ''
47 | finetune_checkpoint: ''
48 | output_dir: output
49 | tensorboard: true
50 | dataset:
51 | train:
52 | dataset:
53 | args:
54 | data_path:
55 | - ./datasets/train.json
56 | img_mode: RGB
57 | load_char_annotation: false
58 | expand_one_char: false
59 | loader:
60 | batch_size: 2
61 | shuffle: true
62 | pin_memory: true
63 | num_workers: 6
64 | collate_fn: ''
65 | validate:
66 | dataset:
67 | args:
68 | data_path:
69 | - ./datasets/test.json
70 | pre_processes:
71 | - type: ResizeShortSize
72 | args:
73 | short_size: 736
74 | resize_text_polys: false
75 | img_mode: RGB
76 | load_char_annotation: false
77 | expand_one_char: false
78 | loader:
79 | batch_size: 1
80 | shuffle: true
81 | pin_memory: false
82 | num_workers: 6
83 | collate_fn: ICDARCollectFN
84 |
--------------------------------------------------------------------------------
/config/open_dataset_resnet18_FPN_DBhead_polyLR.yaml:
--------------------------------------------------------------------------------
1 | name: DBNet
2 | base: ['config/open_dataset.yaml']
3 | arch:
4 | type: Model
5 | backbone:
6 | type: resnet18
7 | pretrained: true
8 | neck:
9 | type: FPN
10 | inner_channels: 256
11 | head:
12 | type: DBHead
13 | out_channels: 2
14 | k: 50
15 | post_processing:
16 | type: SegDetectorRepresenter
17 | args:
18 | thresh: 0.3
19 | box_thresh: 0.7
20 | max_candidates: 1000
21 | unclip_ratio: 1.5 # from paper
22 | metric:
23 | type: QuadMetric
24 | args:
25 | is_output_polygon: false
26 | loss:
27 | type: DBLoss
28 | alpha: 1
29 | beta: 10
30 | ohem_ratio: 3
31 | optimizer:
32 | type: Adam
33 | args:
34 | lr: 0.001
35 | weight_decay: 0
36 | amsgrad: true
37 | lr_scheduler:
38 | type: WarmupPolyLR
39 | args:
40 | warmup_epoch: 3
41 | trainer:
42 | seed: 2
43 | epochs: 1200
44 | log_iter: 1
45 | show_images_iter: 1
46 | resume_checkpoint: ''
47 | finetune_checkpoint: ''
48 | output_dir: output
49 | tensorboard: true
50 | dataset:
51 | train:
52 | dataset:
53 | args:
54 | data_path:
55 | - ./datasets/train.json
56 | transforms: # image transforms
57 | - type: ToTensor
58 | args: {}
59 | - type: Normalize
60 | args:
61 | mean: [0.485, 0.456, 0.406]
62 | std: [0.229, 0.224, 0.225]
63 | img_mode: RGB
64 | load_char_annotation: false
65 | expand_one_char: false
66 | loader:
67 | batch_size: 2
68 | shuffle: true
69 | pin_memory: true
70 | num_workers: 6
71 | collate_fn: ''
72 | validate:
73 | dataset:
74 | args:
75 | data_path:
76 | - ./datasets/test.json
77 | pre_processes:
78 | - type: ResizeShortSize
79 | args:
80 | short_size: 736
81 | resize_text_polys: false
82 | img_mode: RGB
83 | load_char_annotation: false
84 | expand_one_char: false
85 | loader:
86 | batch_size: 1
87 | shuffle: true
88 | pin_memory: false
89 | num_workers: 6
90 | collate_fn: ICDARCollectFN
91 |
--------------------------------------------------------------------------------
/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:52
3 | # @Author : zhoujun
4 | import copy
5 |
6 | import PIL
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 |
12 |
13 | def get_dataset(data_path, module_name, transform, dataset_args):
14 | """
15 | Build the training dataset.
16 | :param data_path: list of dataset files, each line formatted as 'path/to/img\tlabel'
17 | :param module_name: name of the custom dataset class to use (defined in data_loader/dataset.py)
18 | :param transform: transforms applied by this dataset
19 | :param dataset_args: arguments for module_name
20 | :return: the corresponding dataset object if the data_path list is not empty, otherwise None
21 | """
22 | from . import dataset
23 | s_dataset = getattr(dataset, module_name)(transform=transform, data_path=data_path,
24 | **dataset_args)
25 | return s_dataset
26 |
27 |
28 | def get_transforms(transforms_config):
29 | tr_list = []
30 | for item in transforms_config:
31 | if 'args' not in item:
32 | args = {}
33 | else:
34 | args = item['args']
35 | cls = getattr(transforms, item['type'])(**args)
36 | tr_list.append(cls)
37 | tr_list = transforms.Compose(tr_list)
38 | return tr_list
39 |
40 |
41 | class ICDARCollectFN:
42 | def __init__(self, *args, **kwargs):
43 | pass
44 |
45 | def __call__(self, batch):
46 | data_dict = {}
47 | to_tensor_keys = []
48 | for sample in batch:
49 | for k, v in sample.items():
50 | if k not in data_dict:
51 | data_dict[k] = []
52 | if isinstance(v, (np.ndarray, torch.Tensor, PIL.Image.Image)):
53 | if k not in to_tensor_keys:
54 | to_tensor_keys.append(k)
55 | data_dict[k].append(v)
56 | for k in to_tensor_keys:
57 | data_dict[k] = torch.stack(data_dict[k], 0)
58 | return data_dict
59 |
60 |
61 | def get_dataloader(module_config, distributed=False):
62 | if module_config is None:
63 | return None
64 | config = copy.deepcopy(module_config)
65 | dataset_args = config['dataset']['args']
66 | if 'transforms' in dataset_args:
67 | img_transforms = get_transforms(dataset_args.pop('transforms'))
68 | else:
69 | img_transforms = None
70 | # build the dataset
71 | dataset_name = config['dataset']['type']
72 | data_path = dataset_args.pop('data_path')
73 | if data_path is None:
74 | return None
75 |
76 | data_path = [x for x in data_path if x is not None]
77 | if len(data_path) == 0:
78 | return None
79 | if 'collate_fn' not in config['loader'] or config['loader']['collate_fn'] is None or len(config['loader']['collate_fn']) == 0:
80 | config['loader']['collate_fn'] = None
81 | else:
82 | config['loader']['collate_fn'] = eval(config['loader']['collate_fn'])()
83 |
84 | _dataset = get_dataset(data_path=data_path, module_name=dataset_name, transform=img_transforms, dataset_args=dataset_args)
85 | sampler = None
86 | if distributed:
87 | from torch.utils.data.distributed import DistributedSampler
88 | # use a DistributedSampler so each process sees a distinct shard of the data
89 | sampler = DistributedSampler(_dataset)
90 | config['loader']['shuffle'] = False
91 | config['loader']['pin_memory'] = True
92 | loader = DataLoader(dataset=_dataset, sampler=sampler, **config['loader'])
93 | return loader
94 |
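# Illustrative sketch (not part of the original file): typical usage with one of
# the shipped configs, following the pattern in data_loader/dataset.py's __main__
# (parse_config is assumed to resolve the 'base' includes):
#   import anyconfig
#   from utils import parse_config
#   config = parse_config(anyconfig.load('config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml'))
#   train_loader = get_dataloader(config['dataset']['train'], distributed=False)
#   batch = next(iter(train_loader))  # dict with 'img', 'shrink_map', 'threshold_map', ...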
--------------------------------------------------------------------------------
/data_loader/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:54
3 | # @Author : zhoujun
4 | import pathlib
5 | import os
6 | import cv2
7 | import numpy as np
8 | import scipy.io as sio
9 | from tqdm.auto import tqdm
10 |
11 | from base import BaseDataSet
12 | from utils import order_points_clockwise, get_datalist, load, expand_polygon
13 |
14 |
15 | class ICDAR2015Dataset(BaseDataSet):
16 | def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None, **kwargs):
17 | super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform)
18 |
19 | def load_data(self, data_path: str) -> list:
20 | data_list = get_datalist(data_path)
21 | t_data_list = []
22 | for img_path, label_path in data_list:
23 | data = self._get_annotation(label_path)
24 | if len(data['text_polys']) > 0:
25 | item = {'img_path': img_path, 'img_name': pathlib.Path(img_path).stem}
26 | item.update(data)
27 | t_data_list.append(item)
28 | else:
29 | print('there is no suitable bbox in {}'.format(label_path))
30 | return t_data_list
31 |
32 | def _get_annotation(self, label_path: str) -> dict:
33 | boxes = []
34 | texts = []
35 | ignores = []
36 | with open(label_path, encoding='utf-8', mode='r') as f:
37 | for line in f.readlines():
38 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',')
39 | try:
40 | box = order_points_clockwise(np.array(list(map(float, params[:8]))).reshape(-1, 2))
41 | if cv2.contourArea(box) > 0:
42 | boxes.append(box)
43 | label = params[8]
44 | texts.append(label)
45 | ignores.append(label in self.ignore_tags)
46 | except Exception:
47 | print('load label failed on {}'.format(label_path))
48 | data = {
49 | 'text_polys': np.array(boxes),
50 | 'texts': texts,
51 | 'ignore_tags': ignores,
52 | }
53 | return data
54 |
55 |
56 | class DetDataset(BaseDataSet):
57 | def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None, **kwargs):
58 | self.load_char_annotation = kwargs['load_char_annotation']
59 | self.expand_one_char = kwargs['expand_one_char']
60 | super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform)
61 |
62 | def load_data(self, data_path: str) -> list:
63 | """
64 | Read text-line polygons and ground truth (and, optionally, character-level ones) from the json files.
65 | :param data_path:
66 | :return:
67 | """
68 | data_list = []
69 | for path in data_path:
70 | content = load(path)
71 | for gt in tqdm(content['data_list'], desc='read file {}'.format(path)):
72 | img_path = os.path.join(content['data_root'], gt['img_name'])
73 | polygons = []
74 | texts = []
75 | illegibility_list = []
76 | language_list = []
77 | for annotation in gt['annotations']:
78 | if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
79 | continue
80 | if len(annotation['text']) > 1 and self.expand_one_char:
81 | annotation['polygon'] = expand_polygon(annotation['polygon'])
82 | polygons.append(annotation['polygon'])
83 | texts.append(annotation['text'])
84 | illegibility_list.append(annotation['illegibility'])
85 | language_list.append(annotation['language'])
86 | if self.load_char_annotation:
87 | for char_annotation in annotation['chars']:
88 | if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
89 | continue
90 | polygons.append(char_annotation['polygon'])
91 | texts.append(char_annotation['char'])
92 | illegibility_list.append(char_annotation['illegibility'])
93 | language_list.append(char_annotation['language'])
94 | data_list.append({'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': np.array(polygons),
95 | 'texts': texts, 'ignore_tags': illegibility_list})
96 | return data_list
97 |
98 |
99 | class SynthTextDataset(BaseDataSet):
100 | def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None, **kwargs):
101 | self.transform = transform
102 | self.dataRoot = pathlib.Path(data_path)
103 | if not self.dataRoot.exists():
104 | raise FileNotFoundError('Dataset folder does not exist.')
105 |
106 | self.targetFilePath = self.dataRoot / 'gt.mat'
107 | if not self.targetFilePath.exists():
108 | raise FileNotFoundError('Target file gt.mat does not exist.')
109 | targets = {}
110 | sio.loadmat(self.targetFilePath, targets, squeeze_me=True, struct_as_record=False,
111 | variable_names=['imnames', 'wordBB', 'txt'])
112 |
113 | self.imageNames = targets['imnames']
114 | self.wordBBoxes = targets['wordBB']
115 | self.transcripts = targets['txt']
116 | super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform)
117 |
118 | def load_data(self, data_path: str) -> list:
119 | t_data_list = []
120 | for imageName, wordBBoxes, texts in zip(self.imageNames, self.wordBBoxes, self.transcripts):
121 | item = {}
122 | wordBBoxes = np.expand_dims(wordBBoxes, axis=2) if (wordBBoxes.ndim == 2) else wordBBoxes
123 | _, _, numOfWords = wordBBoxes.shape
124 | text_polys = wordBBoxes.reshape([8, numOfWords], order='F').T # num_words * 8
125 | text_polys = text_polys.reshape(numOfWords, 4, 2) # num_of_words * 4 * 2
126 | transcripts = [word for line in texts for word in line.split()]
127 | if numOfWords != len(transcripts):
128 | continue
129 | item['img_path'] = str(self.dataRoot / imageName)
130 | item['img_name'] = (self.dataRoot / imageName).stem
131 | item['text_polys'] = text_polys
132 | item['texts'] = transcripts
133 | item['ignore_tags'] = [x in self.ignore_tags for x in transcripts]
134 | t_data_list.append(item)
135 | return t_data_list
136 |
137 |
138 | if __name__ == '__main__':
139 | import torch
140 | import anyconfig
141 | from torch.utils.data import DataLoader
142 | from torchvision import transforms
143 |
144 | from utils import parse_config, show_img, plt, draw_bbox
145 |
146 | config = anyconfig.load('config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml')
147 | config = parse_config(config)
148 | dataset_args = config['dataset']['train']['dataset']['args']
149 | # dataset_args.pop('data_path')
150 | # data_list = [(r'E:/zj/dataset/icdar2015/train/img/img_15.jpg', 'E:/zj/dataset/icdar2015/train/gt/gt_img_15.txt')]
151 | train_data = ICDAR2015Dataset(data_path=dataset_args.pop('data_path'), transform=transforms.ToTensor(),
152 | **dataset_args)
153 | train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=True, num_workers=0)
154 | for i, data in enumerate(tqdm(train_loader)):
155 | # img = data['img']
156 | # shrink_label = data['shrink_map']
157 | # threshold_label = data['threshold_map']
158 | #
159 | # print(threshold_label.shape, threshold_label.shape, img.shape)
160 | # show_img(img[0].numpy().transpose(1, 2, 0), title='img')
161 | # show_img((shrink_label[0].to(torch.float)).numpy(), title='shrink_label')
162 | # show_img((threshold_label[0].to(torch.float)).numpy(), title='threshold_label')
163 | # img = draw_bbox(img[0].numpy().transpose(1, 2, 0),np.array(data['text_polys']))
164 | # show_img(img, title='draw_bbox')
165 | # plt.show()
166 | pass
167 |
--------------------------------------------------------------------------------
/data_loader/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 10:53
3 | # @Author : zhoujun
4 | from .iaa_augment import IaaAugment
5 | from .augment import *
6 | from .random_crop_data import EastRandomCropData, PSERandomCrop
7 | from .make_border_map import MakeBorderMap
8 | from .make_shrink_map import MakeShrinkMap
9 |
--------------------------------------------------------------------------------
/data_loader/modules/augment.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:52
3 | # @Author : zhoujun
4 |
5 | import math
6 | import numbers
7 | import random
8 |
9 | import cv2
10 | import numpy as np
11 | from skimage.util import random_noise
12 |
13 |
14 | class RandomNoise:
15 | def __init__(self, random_rate):
16 | self.random_rate = random_rate
17 |
18 | def __call__(self, data: dict):
19 | """
20 | Add Gaussian noise to the image
21 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
22 | :return:
23 | """
24 | if random.random() > self.random_rate:
25 | return data
26 | data['img'] = (random_noise(data['img'], mode='gaussian', clip=True) * 255).astype(data['img'].dtype)
27 | return data
28 |
29 |
30 | class RandomScale:
31 | def __init__(self, scales, random_rate):
32 | """
33 | :param scales: candidate scale factors
34 | :param random_rate: probability of applying this transform
35 | :return:
36 | """
37 | self.random_rate = random_rate
38 | self.scales = scales
39 |
40 | def __call__(self, data: dict) -> dict:
41 | """
42 | Randomly pick a scale from scales and resize the image and text boxes accordingly
43 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
44 | :return:
45 | """
46 | if random.random() > self.random_rate:
47 | return data
48 | im = data['img']
49 | text_polys = data['text_polys']
50 |
51 | tmp_text_polys = text_polys.copy()
52 | rd_scale = float(np.random.choice(self.scales))
53 | im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
54 | tmp_text_polys *= rd_scale
55 |
56 | data['img'] = im
57 | data['text_polys'] = tmp_text_polys
58 | return data
59 |
60 |
61 | class RandomRotateImgBox:
62 | def __init__(self, degrees, random_rate, same_size=False):
63 | """
64 | :param degrees: rotation angle, a single number or a list/tuple of two values
65 | :param random_rate: probability of applying this transform
66 | :param same_size: whether to keep the output the same size as the input
67 | :return:
68 | """
69 | if isinstance(degrees, numbers.Number):
70 | if degrees < 0:
71 | raise ValueError("If degrees is a single number, it must be positive.")
72 | degrees = (-degrees, degrees)
73 | elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
74 | if len(degrees) != 2:
75 | raise ValueError("If degrees is a sequence, it must be of len 2.")
76 | degrees = degrees
77 | else:
78 | raise Exception('degrees must be a Number, list, tuple or np.ndarray')
79 | self.degrees = degrees
80 | self.same_size = same_size
81 | self.random_rate = random_rate
82 |
83 | def __call__(self, data: dict) -> dict:
84 | """
85 | Randomly pick an angle from degrees and rotate the image and text boxes
86 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
87 | :return:
88 | """
89 | if random.random() > self.random_rate:
90 | return data
91 | im = data['img']
92 | text_polys = data['text_polys']
93 |
94 | # ---------------------- rotate the image ----------------------
95 | w = im.shape[1]
96 | h = im.shape[0]
97 | angle = np.random.uniform(self.degrees[0], self.degrees[1])
98 |
99 | if self.same_size:
100 | nw = w
101 | nh = h
102 | else:
103 | # degrees to radians
104 | rangle = np.deg2rad(angle)
105 | # compute the w, h of the rotated image
106 | nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
107 | nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
108 | # build the rotation (affine) matrix
109 | rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
110 | # offset from the original image center to the new image center
111 | rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
112 | # update the affine matrix with the offset
113 | rot_mat[0, 2] += rot_move[0]
114 | rot_mat[1, 2] += rot_move[1]
115 | # apply the affine transform
116 | rot_img = cv2.warpAffine(im, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
117 |
118 | # ---------------------- correct the bbox coordinates ----------------------
119 | # rot_mat is the final rotation matrix
120 | # map the four corner points of each original bbox into the rotated frame
121 | rot_text_polys = list()
122 | for bbox in text_polys:
123 | point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
124 | point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
125 | point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
126 | point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
127 | rot_text_polys.append([point1, point2, point3, point4])
128 | data['img'] = rot_img
129 | data['text_polys'] = np.array(rot_text_polys)
130 | return data
131 |
132 |
133 | class RandomResize:
134 | def __init__(self, size, random_rate, keep_ratio=False):
135 | """
136 | :param size: target size, a number or a list in the form [w, h]
137 | :param random_rate: probability of applying this transform
138 | :param keep_ratio: whether to keep the aspect ratio (pad before resizing)
139 | :return:
140 | """
141 | if isinstance(size, numbers.Number):
142 | if size < 0:
143 | raise ValueError("If input_size is a single number, it must be positive.")
144 | size = (size, size)
145 | elif isinstance(size, list) or isinstance(size, tuple) or isinstance(size, np.ndarray):
146 | if len(size) != 2:
147 | raise ValueError("If input_size is a sequence, it must be of len 2.")
148 | size = (size[0], size[1])
149 | else:
150 | raise Exception('input_size must be a Number, list, tuple or np.ndarray')
151 | self.size = size
152 | self.keep_ratio = keep_ratio
153 | self.random_rate = random_rate
154 |
155 | def __call__(self, data: dict) -> dict:
156 | """
157 | Resize the image and text boxes to the configured size
158 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
159 | :return:
160 | """
161 | if random.random() > self.random_rate:
162 | return data
163 | im = data['img']
164 | text_polys = data['text_polys']
165 |
166 | if self.keep_ratio:
167 | # pad the image up to the target size so the aspect ratio is preserved
168 | h, w, c = im.shape
169 | max_h = max(h, self.size[0])
170 | max_w = max(w, self.size[1])
171 | im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
172 | im_padded[:h, :w] = im.copy()
173 | im = im_padded
174 | text_polys = text_polys.astype(np.float32)
175 | h, w, _ = im.shape
176 | im = cv2.resize(im, self.size)
177 | w_scale = self.size[0] / float(w)
178 | h_scale = self.size[1] / float(h)
179 | text_polys[:, :, 0] *= w_scale
180 | text_polys[:, :, 1] *= h_scale
181 |
182 | data['img'] = im
183 | data['text_polys'] = text_polys
184 | return data
185 |
186 |
187 | def resize_image(img, short_size):
188 | height, width, _ = img.shape
189 | if height < width:
190 | new_height = short_size
191 | new_width = new_height / height * width
192 | else:
193 | new_width = short_size
194 | new_height = new_width / width * height
195 | new_height = int(round(new_height / 32) * 32)
196 | new_width = int(round(new_width / 32) * 32)
197 | resized_img = cv2.resize(img, (new_width, new_height))
198 | return resized_img, (new_width / width, new_height / height)
199 |
200 |
201 | class ResizeShortSize:
202 | def __init__(self, short_size, resize_text_polys=True):
203 | """
204 | :param short_size: target length of the short edge
205 | :param resize_text_polys: whether to scale the text polygons as well
206 | """
207 | self.short_size = short_size
208 | self.resize_text_polys = resize_text_polys
209 |
210 | def __call__(self, data: dict) -> dict:
211 | """
212 | Resize the image (and optionally the text boxes) so the short edge is at least short_size
213 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
214 | :return:
215 | """
216 | im = data['img']
217 | text_polys = data['text_polys']
218 |
219 | h, w, _ = im.shape
220 | short_edge = min(h, w)
221 | if short_edge < self.short_size:
222 | # ensure the short edge is >= short_size
223 | scale = self.short_size / short_edge
224 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale)
225 | scale = (scale, scale)
226 | # im, scale = resize_image(im, self.short_size)
227 | if self.resize_text_polys:
228 | # text_polys *= scale
229 | text_polys[:, :, 0] *= scale[0]
230 | text_polys[:, :, 1] *= scale[1]
231 |
232 | data['img'] = im
233 | data['text_polys'] = text_polys
234 | return data
235 |
236 |
237 | class HorizontalFlip:
238 | def __init__(self, random_rate):
239 | """
240 |
241 | :param random_rate: probability of applying this transform
242 | """
243 | self.random_rate = random_rate
244 |
245 | def __call__(self, data: dict) -> dict:
246 | """
247 | Horizontally flip the image and text boxes
248 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
249 | :return:
250 | """
251 | if random.random() > self.random_rate:
252 | return data
253 | im = data['img']
254 | text_polys = data['text_polys']
255 |
256 | flip_text_polys = text_polys.copy()
257 | flip_im = cv2.flip(im, 1)
258 | h, w, _ = flip_im.shape
259 | flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
260 |
261 | data['img'] = flip_im
262 | data['text_polys'] = flip_text_polys
263 | return data
264 |
265 |
266 | class VerticallFlip:
267 | def __init__(self, random_rate):
268 | """
269 |
270 | :param random_rate: probability of applying this transform
271 | """
272 | self.random_rate = random_rate
273 |
274 | def __call__(self, data: dict) -> dict:
275 | """
276 | Vertically flip the image and text boxes
277 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
278 | :return:
279 | """
280 | if random.random() > self.random_rate:
281 | return data
282 | im = data['img']
283 | text_polys = data['text_polys']
284 |
285 | flip_text_polys = text_polys.copy()
286 | flip_im = cv2.flip(im, 0)
287 | h, w, _ = flip_im.shape
288 | flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
289 | data['img'] = flip_im
290 | data['text_polys'] = flip_text_polys
291 | return data
292 |
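# Illustrative sketch (not part of the original file): each transform above
# consumes and returns the same data dict, so they can be chained directly:
#   import numpy as np
#   data = {'img': np.zeros((100, 200, 3), dtype=np.uint8),
#           'text_polys': np.array([[[10., 10.], [50., 10.], [50., 30.], [10., 30.]]]),
#           'texts': ['hello'], 'ignore_tags': [False]}
#   for aug in [HorizontalFlip(random_rate=1.0), ResizeShortSize(short_size=736)]:
#       data = aug(data)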
--------------------------------------------------------------------------------
/data_loader/modules/iaa_augment.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 18:06
3 | # @Author : zhoujun
4 | import numpy as np
5 | import imgaug
6 | import imgaug.augmenters as iaa
7 |
8 | class AugmenterBuilder(object):
9 | def __init__(self):
10 | pass
11 |
12 | def build(self, args, root=True):
13 | if args is None or len(args) == 0:
14 | return None
15 | elif isinstance(args, list):
16 | if root:
17 | sequence = [self.build(value, root=False) for value in args]
18 | return iaa.Sequential(sequence)
19 | else:
20 | return getattr(iaa, args[0])(*[self.to_tuple_if_list(a) for a in args[1:]])
21 | elif isinstance(args, dict):
22 | cls = getattr(iaa, args['type'])
23 | return cls(**{k: self.to_tuple_if_list(v) for k, v in args['args'].items()})
24 | else:
25 | raise RuntimeError('unknown augmenter arg: ' + str(args))
26 |
27 | def to_tuple_if_list(self, obj):
28 | if isinstance(obj, list):
29 | return tuple(obj)
30 | return obj
31 |
32 |
33 | class IaaAugment():
34 | def __init__(self, augmenter_args):
35 | self.augmenter_args = augmenter_args
36 | self.augmenter = AugmenterBuilder().build(self.augmenter_args)
37 |
38 | def __call__(self, data):
39 | image = data['img']
40 | shape = image.shape
41 |
42 | if self.augmenter:
43 | aug = self.augmenter.to_deterministic()
44 | data['img'] = aug.augment_image(image)
45 | data = self.may_augment_annotation(aug, data, shape)
46 | return data
47 |
48 | def may_augment_annotation(self, aug, data, shape):
49 | if aug is None:
50 | return data
51 |
52 | line_polys = []
53 | for poly in data['text_polys']:
54 | new_poly = self.may_augment_poly(aug, shape, poly)
55 | line_polys.append(new_poly)
56 | data['text_polys'] = np.array(line_polys)
57 | return data
58 |
59 | def may_augment_poly(self, aug, img_shape, poly):
60 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
61 | keypoints = aug.augment_keypoints(
62 | [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints
63 | poly = [(p.x, p.y) for p in keypoints]
64 | return poly
65 |
--------------------------------------------------------------------------------
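A sketch of the list-style augmenter spec that `AugmenterBuilder.build` accepts; the entries mirror the kind of values found in the YAML configs, but the numbers here are illustrative:

```python
import numpy as np

# each entry is [AugmenterName, *args] or {'type': name, 'args': {...}};
# inner lists become tuples (imgaug value ranges) via to_tuple_if_list
augmenter_args = [
    ['Fliplr', 0.5],                                    # horizontal flip with p=0.5
    {'type': 'Affine', 'args': {'rotate': [-10, 10]}},  # rotation sampled from (-10, 10)
    ['Resize', [0.5, 3.0]],                             # scale sampled from (0.5, 3.0)
]
aug = IaaAugment(augmenter_args)
sample = {'img': np.zeros((100, 100, 3), dtype=np.uint8),
          'text_polys': np.float32([[[10, 10], [60, 10], [60, 30], [10, 30]]])}
sample = aug(sample)  # the polygon keypoints follow the image transform
```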
/data_loader/modules/make_border_map.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | np.seterr(divide='ignore', invalid='ignore')
4 | import pyclipper
5 | from shapely.geometry import Polygon
6 |
7 |
8 | class MakeBorderMap():
9 | def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
10 | self.shrink_ratio = shrink_ratio
11 | self.thresh_min = thresh_min
12 | self.thresh_max = thresh_max
13 |
14 | def __call__(self, data: dict) -> dict:
15 | """
16 |         Compute the threshold (border) map and its mask for the text polygons
17 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
18 | :return:
19 | """
20 | im = data['img']
21 | text_polys = data['text_polys']
22 | ignore_tags = data['ignore_tags']
23 |
24 | canvas = np.zeros(im.shape[:2], dtype=np.float32)
25 | mask = np.zeros(im.shape[:2], dtype=np.float32)
26 |
27 | for i in range(len(text_polys)):
28 | if ignore_tags[i]:
29 | continue
30 | self.draw_border_map(text_polys[i], canvas, mask=mask)
31 | canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
32 |
33 | data['threshold_map'] = canvas
34 | data['threshold_mask'] = mask
35 | return data
36 |
37 | def draw_border_map(self, polygon, canvas, mask):
38 | polygon = np.array(polygon)
39 | assert polygon.ndim == 2
40 | assert polygon.shape[1] == 2
41 |
42 | polygon_shape = Polygon(polygon)
43 | if polygon_shape.area <= 0:
44 | return
45 | distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
46 | subject = [tuple(l) for l in polygon]
47 | padding = pyclipper.PyclipperOffset()
48 | padding.AddPath(subject, pyclipper.JT_ROUND,
49 | pyclipper.ET_CLOSEDPOLYGON)
50 |
51 | padded_polygon = np.array(padding.Execute(distance)[0])
52 | cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
53 |
54 | xmin = padded_polygon[:, 0].min()
55 | xmax = padded_polygon[:, 0].max()
56 | ymin = padded_polygon[:, 1].min()
57 | ymax = padded_polygon[:, 1].max()
58 | width = xmax - xmin + 1
59 | height = ymax - ymin + 1
60 |
61 | polygon[:, 0] = polygon[:, 0] - xmin
62 | polygon[:, 1] = polygon[:, 1] - ymin
63 |
64 | xs = np.broadcast_to(
65 | np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
66 | ys = np.broadcast_to(
67 | np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
68 |
69 | distance_map = np.zeros(
70 | (polygon.shape[0], height, width), dtype=np.float32)
71 | for i in range(polygon.shape[0]):
72 | j = (i + 1) % polygon.shape[0]
73 | absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
74 | distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
75 | distance_map = distance_map.min(axis=0)
76 |
77 | xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
78 | xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
79 | ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
80 | ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
81 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
82 | 1 - distance_map[
83 | ymin_valid - ymin:ymax_valid - ymax + height,
84 | xmin_valid - xmin:xmax_valid - xmax + width],
85 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
86 |
87 | def distance(self, xs, ys, point_1, point_2):
88 | '''
89 |         compute the distance from each grid point to a line segment
90 |         ys: coordinates along the first axis (rows)
91 |         xs: coordinates along the second axis (columns)
92 |         point_1, point_2: (x, y), the endpoints of the segment
93 | '''
94 | height, width = xs.shape[:2]
95 | square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
96 | square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
97 | square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])
98 |
99 | cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
100 | square_sin = 1 - np.square(cosin)
101 | square_sin = np.nan_to_num(square_sin)
102 |
103 | result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
104 | result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0]
105 | # self.extend_line(point_1, point_2, result)
106 | return result
107 |
108 | def extend_line(self, point_1, point_2, result):
109 | ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
110 | int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
111 | cv2.line(result, tuple(ex_point_1), tuple(point_1), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
112 | ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
113 | int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
114 | cv2.line(result, tuple(ex_point_2), tuple(point_2), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
115 | return ex_point_1, ex_point_2
116 |
--------------------------------------------------------------------------------
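The dilation distance used in `draw_border_map` is `D = A * (1 - r^2) / L`, with polygon area `A`, perimeter `L` and shrink ratio `r`, following the DB paper. A quick numeric check on a made-up square:

```python
import numpy as np
from shapely.geometry import Polygon

square = Polygon([(0, 0), (100, 0), (100, 100), (0, 100)])
r = 0.4
D = square.area * (1 - np.power(r, 2)) / square.length
print(D)  # 10000 * 0.84 / 400 = 21.0 pixels of outward offset
```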
/data_loader/modules/make_shrink_map.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | def shrink_polygon_py(polygon, shrink_ratio):
6 | """
7 |     Shrink the polygon towards its centroid; to undo, call again with a ratio of 1/shrink_ratio
8 | """
9 | cx = polygon[:, 0].mean()
10 | cy = polygon[:, 1].mean()
11 | polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
12 | polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
13 | return polygon
14 |
15 |
16 | def shrink_polygon_pyclipper(polygon, shrink_ratio):
17 | from shapely.geometry import Polygon
18 | import pyclipper
19 | polygon_shape = Polygon(polygon)
20 | distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
21 | subject = [tuple(l) for l in polygon]
22 | padding = pyclipper.PyclipperOffset()
23 | padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
24 | shrinked = padding.Execute(-distance)
25 |     if not shrinked:
26 | shrinked = np.array(shrinked)
27 | else:
28 | shrinked = np.array(shrinked[0]).reshape(-1, 2)
29 | return shrinked
30 |
31 |
32 | class MakeShrinkMap():
33 | r'''
34 | Making binary mask from detection data with ICDAR format.
35 | Typically following the process of class `MakeICDARData`.
36 | '''
37 |
38 | def __init__(self, min_text_size=8, shrink_ratio=0.4, shrink_type='pyclipper'):
39 | shrink_func_dict = {'py': shrink_polygon_py, 'pyclipper': shrink_polygon_pyclipper}
40 | self.shrink_func = shrink_func_dict[shrink_type]
41 | self.min_text_size = min_text_size
42 | self.shrink_ratio = shrink_ratio
43 |
44 | def __call__(self, data: dict) -> dict:
45 | """
46 |         Compute the shrink map (gt) and its mask for the text polygons
47 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
48 | :return:
49 | """
50 | image = data['img']
51 | text_polys = data['text_polys']
52 | ignore_tags = data['ignore_tags']
53 |
54 | h, w = image.shape[:2]
55 | text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
56 | gt = np.zeros((h, w), dtype=np.float32)
57 | mask = np.ones((h, w), dtype=np.float32)
58 | for i in range(len(text_polys)):
59 | polygon = text_polys[i]
60 | height = max(polygon[:, 1]) - min(polygon[:, 1])
61 | width = max(polygon[:, 0]) - min(polygon[:, 0])
62 | if ignore_tags[i] or min(height, width) < self.min_text_size:
63 | cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
64 | ignore_tags[i] = True
65 | else:
66 | shrinked = self.shrink_func(polygon, self.shrink_ratio)
67 | if shrinked.size == 0:
68 | cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
69 | ignore_tags[i] = True
70 | continue
71 | cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
72 |
73 | data['shrink_map'] = gt
74 | data['shrink_mask'] = mask
75 | return data
76 |
77 | def validate_polygons(self, polygons, ignore_tags, h, w):
78 | '''
79 | polygons (numpy.array, required): of shape (num_instances, num_points, 2)
80 | '''
81 | if len(polygons) == 0:
82 | return polygons, ignore_tags
83 | assert len(polygons) == len(ignore_tags)
84 | for polygon in polygons:
85 | polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
86 | polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
87 |
88 | for i in range(len(polygons)):
89 | area = self.polygon_area(polygons[i])
90 | if abs(area) < 1:
91 | ignore_tags[i] = True
92 | if area > 0:
93 | polygons[i] = polygons[i][::-1, :]
94 | return polygons, ignore_tags
95 |
96 | def polygon_area(self, polygon):
97 | return cv2.contourArea(polygon)
98 | # edge = 0
99 | # for i in range(polygon.shape[0]):
100 | # next_index = (i + 1) % polygon.shape[0]
101 | # edge += (polygon[next_index, 0] - polygon[i, 0]) * (polygon[next_index, 1] - polygon[i, 1])
102 | #
103 | # return edge / 2.
104 |
105 |
106 | if __name__ == '__main__':
107 | from shapely.geometry import Polygon
108 | import pyclipper
109 |
110 | polygon = np.array([[0, 0], [100, 10], [100, 100], [10, 90]])
111 | a = shrink_polygon_py(polygon, 0.4)
112 | print(a)
113 | print(shrink_polygon_py(a, 1 / 0.4))
114 | b = shrink_polygon_pyclipper(polygon, 0.4)
115 | print(b)
116 | poly = Polygon(b)
117 | distance = poly.area * 1.5 / poly.length
118 | offset = pyclipper.PyclipperOffset()
119 | offset.AddPath(b, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
120 | expanded = np.array(offset.Execute(distance))
121 | bounding_box = cv2.minAreaRect(expanded)
122 | points = cv2.boxPoints(bounding_box)
123 | print(points)
124 |
--------------------------------------------------------------------------------
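`MakeShrinkMap` applies the same `D = A * (1 - r^2) / L` distance, but offsets the polygon inwards. A minimal run on a dummy sample (shapes invented for illustration):

```python
import numpy as np

sample = {
    'img': np.zeros((160, 160, 3), dtype=np.uint8),
    'text_polys': np.float32([[[20, 20], [140, 20], [140, 60], [20, 60]]]),
    'ignore_tags': [False],
}
sample = MakeShrinkMap(shrink_ratio=0.4)(sample)
print(sample['shrink_map'].shape, sample['shrink_map'].max())  # (160, 160) 1.0
print(sample['shrink_mask'].min())  # 1.0 here: no polygon was ignored
```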
/data_loader/modules/random_crop_data.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 | import numpy as np
5 |
6 |
7 | # random crop algorithm similar to https://github.com/argman/EAST
8 | class EastRandomCropData():
9 | def __init__(self, size=(640, 640), max_tries=50, min_crop_side_ratio=0.1, require_original_image=False, keep_ratio=True):
10 | self.size = size
11 | self.max_tries = max_tries
12 | self.min_crop_side_ratio = min_crop_side_ratio
13 | self.require_original_image = require_original_image
14 | self.keep_ratio = keep_ratio
15 |
16 | def __call__(self, data: dict) -> dict:
17 | """
18 |         Randomly crop the image and adjust the text polygons to the crop
19 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
20 | :return:
21 | """
22 | im = data['img']
23 | text_polys = data['text_polys']
24 | ignore_tags = data['ignore_tags']
25 | texts = data['texts']
26 | all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
27 |         # compute the crop region
28 | crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
29 |         # crop the image; keep the aspect ratio and pad to the target size
30 | scale_w = self.size[0] / crop_w
31 | scale_h = self.size[1] / crop_h
32 | scale = min(scale_w, scale_h)
33 | h = int(crop_h * scale)
34 | w = int(crop_w * scale)
35 | if self.keep_ratio:
36 | if len(im.shape) == 3:
37 | padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
38 | else:
39 | padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
40 | padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
41 | img = padimg
42 | else:
43 | img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], tuple(self.size))
44 |         # crop the text polygons
45 | text_polys_crop = []
46 | ignore_tags_crop = []
47 | texts_crop = []
48 | for poly, text, tag in zip(text_polys, texts, ignore_tags):
49 | poly = ((poly - (crop_x, crop_y)) * scale).tolist()
50 | if not self.is_poly_outside_rect(poly, 0, 0, w, h):
51 | text_polys_crop.append(poly)
52 | ignore_tags_crop.append(tag)
53 | texts_crop.append(text)
54 | data['img'] = img
55 | data['text_polys'] = np.float32(text_polys_crop)
56 | data['ignore_tags'] = ignore_tags_crop
57 | data['texts'] = texts_crop
58 | return data
59 |
60 | def is_poly_in_rect(self, poly, x, y, w, h):
61 | poly = np.array(poly)
62 | if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
63 | return False
64 | if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
65 | return False
66 | return True
67 |
68 | def is_poly_outside_rect(self, poly, x, y, w, h):
69 | poly = np.array(poly)
70 | if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
71 | return True
72 | if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
73 | return True
74 | return False
75 |
76 | def split_regions(self, axis):
77 | regions = []
78 | min_axis = 0
79 | for i in range(1, axis.shape[0]):
80 | if axis[i] != axis[i - 1] + 1:
81 | region = axis[min_axis:i]
82 | min_axis = i
83 | regions.append(region)
84 | return regions
85 |
86 | def random_select(self, axis, max_size):
87 | xx = np.random.choice(axis, size=2)
88 | xmin = np.min(xx)
89 | xmax = np.max(xx)
90 | xmin = np.clip(xmin, 0, max_size - 1)
91 | xmax = np.clip(xmax, 0, max_size - 1)
92 | return xmin, xmax
93 |
94 | def region_wise_random_select(self, regions, max_size):
95 | selected_index = list(np.random.choice(len(regions), 2))
96 | selected_values = []
97 | for index in selected_index:
98 | axis = regions[index]
99 | xx = int(np.random.choice(axis, size=1))
100 | selected_values.append(xx)
101 | xmin = min(selected_values)
102 | xmax = max(selected_values)
103 | return xmin, xmax
104 |
105 | def crop_area(self, im, text_polys):
106 | h, w = im.shape[:2]
107 | h_array = np.zeros(h, dtype=np.int32)
108 | w_array = np.zeros(w, dtype=np.int32)
109 | for points in text_polys:
110 | points = np.round(points, decimals=0).astype(np.int32)
111 | minx = np.min(points[:, 0])
112 | maxx = np.max(points[:, 0])
113 | w_array[minx:maxx] = 1
114 | miny = np.min(points[:, 1])
115 | maxy = np.max(points[:, 1])
116 | h_array[miny:maxy] = 1
117 | # ensure the cropped area not across a text
118 | h_axis = np.where(h_array == 0)[0]
119 | w_axis = np.where(w_array == 0)[0]
120 |
121 | if len(h_axis) == 0 or len(w_axis) == 0:
122 | return 0, 0, w, h
123 |
124 | h_regions = self.split_regions(h_axis)
125 | w_regions = self.split_regions(w_axis)
126 |
127 | for i in range(self.max_tries):
128 | if len(w_regions) > 1:
129 | xmin, xmax = self.region_wise_random_select(w_regions, w)
130 | else:
131 | xmin, xmax = self.random_select(w_axis, w)
132 | if len(h_regions) > 1:
133 | ymin, ymax = self.region_wise_random_select(h_regions, h)
134 | else:
135 | ymin, ymax = self.random_select(h_axis, h)
136 |
137 | if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
138 | # area too small
139 | continue
140 | num_poly_in_rect = 0
141 | for poly in text_polys:
142 | if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
143 | num_poly_in_rect += 1
144 | break
145 |
146 | if num_poly_in_rect > 0:
147 | return xmin, ymin, xmax - xmin, ymax - ymin
148 |
149 | return 0, 0, w, h
150 |
151 |
152 | class PSERandomCrop():
153 | def __init__(self, size):
154 | self.size = size
155 |
156 | def __call__(self, data):
157 | imgs = data['imgs']
158 |
159 | h, w = imgs[0].shape[0:2]
160 | th, tw = self.size
161 | if w == tw and h == th:
162 | return imgs
163 |
164 |         # if the label contains text instances, crop around them with some probability, guided by the threshold_label_map
165 | if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
166 |             # top-left corner of the text instances
167 | tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
168 | tl[tl < 0] = 0
169 |             # bottom-right corner of the text instances
170 | br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
171 | br[br < 0] = 0
172 |             # make sure there is enough room to crop when the bottom-right corner is picked
173 | br[0] = min(br[0], h - th)
174 | br[1] = min(br[1], w - tw)
175 |
176 | for _ in range(50000):
177 | i = random.randint(tl[0], br[0])
178 | j = random.randint(tl[1], br[1])
179 |                 # make sure the crop of shrink_label_map contains text
180 | if imgs[1][i:i + th, j:j + tw].sum() <= 0:
181 | continue
182 | else:
183 | break
184 | else:
185 | i = random.randint(0, h - th)
186 | j = random.randint(0, w - tw)
187 |
188 | # return i, j, th, tw
189 | for idx in range(len(imgs)):
190 | if len(imgs[idx].shape) == 3:
191 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
192 | else:
193 | imgs[idx] = imgs[idx][i:i + th, j:j + tw]
194 | data['imgs'] = imgs
195 | return data
196 |
--------------------------------------------------------------------------------
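How `EastRandomCropData` is typically driven; a sketch with dummy inputs (sizes are illustrative):

```python
import numpy as np

cropper = EastRandomCropData(size=(640, 640), max_tries=50, min_crop_side_ratio=0.1)
sample = {
    'img': np.zeros((720, 1280, 3), dtype=np.uint8),
    'text_polys': np.float32([[[100, 100], [300, 100], [300, 200], [100, 200]]]),
    'texts': ['dummy'],
    'ignore_tags': [False],
}
sample = cropper(sample)
print(sample['img'].shape)  # always (640, 640, 3): the crop is resized, then padded
```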
/datasets/test.txt:
--------------------------------------------------------------------------------
1 | ./datasets/test/img/001.jpg ./datasets/test/gt/001.txt
2 | ./datasets/test/img/002.jpg ./datasets/test/gt/002.txt
3 |
--------------------------------------------------------------------------------
/datasets/test/gt/README.MD:
--------------------------------------------------------------------------------
1 | Place the `.txt` ground truth files here, in the format `x1, y1, x2, y2, x3, y3, x4, y4, annotation`
2 |
--------------------------------------------------------------------------------
/datasets/test/img/README.MD:
--------------------------------------------------------------------------------
1 | Place the `.jpg` files here.
2 |
--------------------------------------------------------------------------------
/datasets/train.txt:
--------------------------------------------------------------------------------
1 | ./datasets/train/img/001.jpg ./datasets/train/gt/001.txt
2 | ./datasets/train/img/002.jpg ./datasets/train/gt/002.txt
3 |
--------------------------------------------------------------------------------
/datasets/train/gt/README.MD:
--------------------------------------------------------------------------------
1 | Place the `.txt` ground truth files here, in the format `x1, y1, x2, y2, x3, y3, x4, y4, annotation`
2 |
--------------------------------------------------------------------------------
/datasets/train/img/README.MD:
--------------------------------------------------------------------------------
1 | Place the `.jpg` files here.
2 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: dbnet
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - pytorch==1.4
8 | - torchvision==0.5
9 | - anyconfig==0.9.10
10 | - future==0.18.2
11 | - imgaug==0.4.0
12 | - matplotlib==3.1.2
13 | - numpy==1.17.4
14 | - opencv
15 | - pyclipper
16 | - PyYAML==5.2
17 | - scikit-image==0.16.2
18 | - Shapely==1.6.4
19 | - tensorboard=2
20 | - tqdm==4.40.1
21 | - ipython
22 | - pip
23 | - pytorch
24 | - torchvision
25 | - pip:
26 | - polygon3
27 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py --model_path ''
--------------------------------------------------------------------------------
/generate_lists.sh:
--------------------------------------------------------------------------------
1 | # Only use this script if the image and txt file names are identical
2 | rm ./datasets/train_img.txt
3 | rm ./datasets/train_gt.txt
4 | rm ./datasets/test_img.txt
5 | rm ./datasets/test_gt.txt
6 | rm ./datasets/train.txt
7 | rm ./datasets/test.txt
8 | ls ./datasets/train/img/*.jpg > ./datasets/train_img.txt
9 | ls ./datasets/train/gt/*.txt > ./datasets/train_gt.txt
10 | ls ./datasets/test/img/*.jpg > ./datasets/test_img.txt
11 | ls ./datasets/test/gt/*.txt > ./datasets/test_gt.txt
12 | paste ./datasets/train_img.txt ./datasets/train_gt.txt > ./datasets/train.txt
13 | paste ./datasets/test_img.txt ./datasets/test_gt.txt > ./datasets/test.txt
14 | rm ./datasets/train_img.txt
15 | rm ./datasets/train_gt.txt
16 | rm ./datasets/test_img.txt
17 | rm ./datasets/test_gt.txt
18 |
--------------------------------------------------------------------------------
/imgs/paper/db.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/DBNet.pytorch/e03acf0e6b3b62f7d1dc7e10a6d2587456ac9ea1/imgs/paper/db.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:55
3 | # @Author : zhoujun
4 | import copy
5 | from .model import Model
6 | from .losses import build_loss
7 |
8 | __all__ = ['build_loss', 'build_model']
9 | support_model = ['Model']
10 |
11 |
12 | def build_model(config):
13 | """
14 | get architecture model class
15 | """
16 | copy_config = copy.deepcopy(config)
17 | arch_type = copy_config.pop('type')
18 |     assert arch_type in support_model, f'{arch_type} is not supported yet, only {support_model} are supported now'
19 | arch_model = eval(arch_type)(copy_config)
20 | return arch_model
21 |
--------------------------------------------------------------------------------
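A sketch of the config dict that `build_model` consumes; the nested `backbone`/`neck`/`head` keys mirror the `arch` section of the YAML files under `config/`, and the argument values here are illustrative:

```python
model_config = {
    'type': 'Model',  # must appear in support_model
    'backbone': {'type': 'resnet18', 'pretrained': False},
    'neck': {'type': 'FPN', 'inner_channels': 256},
    'head': {'type': 'DBHead', 'out_channels': 2, 'k': 50},
}
model = build_model(model_config)
```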
/models/backbone/MobilenetV3.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class HSwish(nn.Module):
9 | def forward(self, x):
10 | out = x * F.relu6(x + 3, inplace=True) / 6
11 | return out
12 |
13 |
14 | class HardSigmoid(nn.Module):
15 | def __init__(self, slope=.2, offset=.5):
16 | super().__init__()
17 | self.slope = slope
18 | self.offset = offset
19 |
20 | def forward(self, x):
21 | x = (self.slope * x) + self.offset
22 | x = F.threshold(-x, -1, -1)
23 | x = F.threshold(-x, 0, 0)
24 | return x
25 |
26 |
27 | class ConvBNACT(nn.Module):
28 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None):
29 | super().__init__()
30 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
31 | stride=stride, padding=padding, groups=groups,
32 | bias=False)
33 | self.bn = nn.BatchNorm2d(out_channels)
34 | if act == 'relu':
35 | self.act = nn.ReLU()
36 | elif act == 'hard_swish':
37 | self.act = HSwish()
38 |         else:
39 |             self.act = None  # act is None or an unrecognized name; apply no activation instead of failing in forward
40 |
41 | def forward(self, x):
42 | x = self.conv(x)
43 | x = self.bn(x)
44 | if self.act is not None:
45 | x = self.act(x)
46 | return x
47 |
48 |
49 | class SEBlock(nn.Module):
50 | def __init__(self, in_channels, out_channels, ratio=4):
51 | super().__init__()
52 | num_mid_filter = out_channels // ratio
53 | self.pool = nn.AdaptiveAvgPool2d(1)
54 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_mid_filter, kernel_size=1, bias=True)
55 | self.relu1 = nn.ReLU()
56 | self.conv2 = nn.Conv2d(in_channels=num_mid_filter, kernel_size=1, out_channels=out_channels, bias=True)
57 | self.relu2 = HardSigmoid()
58 |
59 | def forward(self, x):
60 | attn = self.pool(x)
61 | attn = self.conv1(attn)
62 | attn = self.relu1(attn)
63 | attn = self.conv2(attn)
64 | attn = self.relu2(attn)
65 | return x * attn
66 |
67 |
68 | class ResidualUnit(nn.Module):
69 | def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False):
70 | super().__init__()
71 | self.conv0 = ConvBNACT(in_channels=num_in_filter, out_channels=num_mid_filter, kernel_size=1, stride=1,
72 | padding=0, act=act)
73 |
74 | self.conv1 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter, kernel_size=kernel_size,
75 | stride=stride,
76 | padding=int((kernel_size - 1) // 2), act=act, groups=num_mid_filter)
77 | if use_se:
78 | self.se = SEBlock(in_channels=num_mid_filter, out_channels=num_mid_filter)
79 | else:
80 | self.se = None
81 |
82 | self.conv2 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter, kernel_size=1, stride=1,
83 | padding=0)
84 | self.not_add = num_in_filter != num_out_filter or stride != 1
85 |
86 | def forward(self, x):
87 | y = self.conv0(x)
88 | y = self.conv1(y)
89 | if self.se is not None:
90 | y = self.se(y)
91 | y = self.conv2(y)
92 | if not self.not_add:
93 | y = x + y
94 | return y
95 |
96 |
97 | class MobileNetV3(nn.Module):
98 | def __init__(self, in_channels=3, **kwargs):
99 | """
100 |         The MobileNetV3 backbone network for the detection module.
101 |         Args:
102 |             in_channels (int): number of input channels; kwargs may carry 'scale' and 'model_name'
103 | """
104 | super().__init__()
105 | self.scale = kwargs.get('scale', 0.5)
106 | model_name = kwargs.get('model_name', 'large')
107 | self.inplanes = 16
108 | if model_name == "large":
109 | self.cfg = [
110 | # k, exp, c, se, nl, s,
111 | [3, 16, 16, False, 'relu', 1],
112 | [3, 64, 24, False, 'relu', 2],
113 | [3, 72, 24, False, 'relu', 1],
114 | [5, 72, 40, True, 'relu', 2],
115 | [5, 120, 40, True, 'relu', 1],
116 | [5, 120, 40, True, 'relu', 1],
117 | [3, 240, 80, False, 'hard_swish', 2],
118 | [3, 200, 80, False, 'hard_swish', 1],
119 | [3, 184, 80, False, 'hard_swish', 1],
120 | [3, 184, 80, False, 'hard_swish', 1],
121 | [3, 480, 112, True, 'hard_swish', 1],
122 | [3, 672, 112, True, 'hard_swish', 1],
123 | [5, 672, 160, True, 'hard_swish', 2],
124 | [5, 960, 160, True, 'hard_swish', 1],
125 | [5, 960, 160, True, 'hard_swish', 1],
126 | ]
127 | self.cls_ch_squeeze = 960
128 | self.cls_ch_expand = 1280
129 | elif model_name == "small":
130 | self.cfg = [
131 | # k, exp, c, se, nl, s,
132 | [3, 16, 16, True, 'relu', 2],
133 | [3, 72, 24, False, 'relu', 2],
134 | [3, 88, 24, False, 'relu', 1],
135 | [5, 96, 40, True, 'hard_swish', 2],
136 | [5, 240, 40, True, 'hard_swish', 1],
137 | [5, 240, 40, True, 'hard_swish', 1],
138 | [5, 120, 48, True, 'hard_swish', 1],
139 | [5, 144, 48, True, 'hard_swish', 1],
140 | [5, 288, 96, True, 'hard_swish', 2],
141 | [5, 576, 96, True, 'hard_swish', 1],
142 | [5, 576, 96, True, 'hard_swish', 1],
143 | ]
144 | self.cls_ch_squeeze = 576
145 | self.cls_ch_expand = 1280
146 | else:
147 |             raise NotImplementedError("model_name [" + model_name +
148 |                                       "] is not implemented!")
149 |
150 | supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
151 | assert self.scale in supported_scale, \
152 | "supported scale are {} but input scale is {}".format(supported_scale, self.scale)
153 |
154 | scale = self.scale
155 | inplanes = self.inplanes
156 | cfg = self.cfg
157 | cls_ch_squeeze = self.cls_ch_squeeze
158 | # conv1
159 | self.conv1 = ConvBNACT(in_channels=in_channels,
160 | out_channels=self.make_divisible(inplanes * scale),
161 | kernel_size=3,
162 | stride=2,
163 | padding=1,
164 | groups=1,
165 | act='hard_swish')
166 | i = 0
167 | inplanes = self.make_divisible(inplanes * scale)
168 | self.stages = nn.ModuleList()
169 | block_list = []
170 | self.out_channels = []
171 | for layer_cfg in cfg:
172 | if layer_cfg[5] == 2 and i > 2:
173 | self.out_channels.append(inplanes)
174 | self.stages.append(nn.Sequential(*block_list))
175 | block_list = []
176 | block = ResidualUnit(num_in_filter=inplanes,
177 | num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
178 | num_out_filter=self.make_divisible(scale * layer_cfg[2]),
179 | act=layer_cfg[4],
180 | stride=layer_cfg[5],
181 | kernel_size=layer_cfg[0],
182 | use_se=layer_cfg[3])
183 | block_list.append(block)
184 | inplanes = self.make_divisible(scale * layer_cfg[2])
185 | i += 1
186 | self.stages.append(nn.Sequential(*block_list))
187 | self.conv2 = ConvBNACT(
188 | in_channels=inplanes,
189 | out_channels=self.make_divisible(scale * cls_ch_squeeze),
190 | kernel_size=1,
191 | stride=1,
192 | padding=0,
193 | groups=1,
194 | act='hard_swish')
195 | self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze))
196 |
197 | def make_divisible(self, v, divisor=8, min_value=None):
198 | if min_value is None:
199 | min_value = divisor
200 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
201 | if new_v < 0.9 * v:
202 | new_v += divisor
203 | return new_v
204 |
205 | def forward(self, x):
206 | x = self.conv1(x)
207 | out = []
208 | for stage in self.stages:
209 | x = stage(x)
210 | out.append(x)
211 | out[-1] = self.conv2(out[-1])
212 | return out
213 |
--------------------------------------------------------------------------------
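Unlike `resnet.py` below, this file ships without a self-test block; a comparable sketch (input size chosen to match the other backbones' demos):

```python
if __name__ == '__main__':
    import torch

    x = torch.zeros(2, 3, 640, 640)
    net = MobileNetV3(in_channels=3, scale=0.5, model_name='large')
    for feature_map in net(x):
        print(feature_map.shape)  # four maps at strides 4, 8, 16, 32
    print(net.out_channels)
```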
/models/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:54
3 | # @Author : zhoujun
4 |
5 | from .resnet import *
6 | from .resnest import *
7 | from .shufflenetv2 import *
8 | from .MobilenetV3 import MobileNetV3
9 |
10 | __all__ = ['build_backbone']
11 |
12 | support_backbone = ['resnet18', 'deformable_resnet18', 'deformable_resnet50',
13 | 'resnet50', 'resnet34', 'resnet101', 'resnet152',
14 | 'resnest50', 'resnest101', 'resnest200', 'resnest269',
15 | 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
16 | 'MobileNetV3']
17 |
18 |
19 | def build_backbone(backbone_name, **kwargs):
20 |     assert backbone_name in support_backbone, f'only the following backbones are supported: {support_backbone}'
21 | backbone = eval(backbone_name)(**kwargs)
22 | return backbone
23 |
--------------------------------------------------------------------------------
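For reference, a direct call to the factory; keyword arguments are forwarded unchanged to the backbone constructors:

```python
backbone = build_backbone('resnet18', pretrained=False, in_channels=3)
print(backbone.out_channels)  # [64, 128, 256, 512] for resnet18
```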
/models/backbone/resnest/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnest import *
--------------------------------------------------------------------------------
/models/backbone/resnest/ablation.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## Email: zhanghang0704@gmail.com
4 | ## Copyright (c) 2020
5 | ##
6 | ## LICENSE file in the root directory of this source tree
7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 | """ResNeSt ablation study models"""
9 |
10 | import torch
11 | from .resnet import ResNet, Bottleneck
12 |
13 | __all__ = ['resnest50_fast_1s1x64d', 'resnest50_fast_2s1x64d', 'resnest50_fast_4s1x64d',
14 | 'resnest50_fast_1s2x40d', 'resnest50_fast_2s2x40d', 'resnest50_fast_4s2x40d',
15 | 'resnest50_fast_1s4x24d']
16 |
17 | _url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
18 |
19 | _model_sha256 = {name: checksum for checksum, name in [
20 | ('d8fbf808', 'resnest50_fast_1s1x64d'),
21 | ('44938639', 'resnest50_fast_2s1x64d'),
22 | ('f74f3fc3', 'resnest50_fast_4s1x64d'),
23 | ('32830b84', 'resnest50_fast_1s2x40d'),
24 | ('9d126481', 'resnest50_fast_2s2x40d'),
25 | ('41d14ed0', 'resnest50_fast_4s2x40d'),
26 | ('d4a4f76f', 'resnest50_fast_1s4x24d'),
27 | ]}
28 |
29 | def short_hash(name):
30 | if name not in _model_sha256:
31 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
32 | return _model_sha256[name][:8]
33 |
34 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for
35 | name in _model_sha256.keys()
36 | }
37 |
38 | def resnest50_fast_1s1x64d(pretrained=False, root='~/.encoding/models', **kwargs):
39 | model = ResNet(Bottleneck, [3, 4, 6, 3],
40 | radix=1, groups=1, bottleneck_width=64,
41 | deep_stem=True, stem_width=32, avg_down=True,
42 | avd=True, avd_first=True, **kwargs)
43 | if pretrained:
44 | model.load_state_dict(torch.hub.load_state_dict_from_url(
45 | resnest_model_urls['resnest50_fast_1s1x64d'], progress=True, check_hash=True))
46 | return model
47 |
48 | def resnest50_fast_2s1x64d(pretrained=False, root='~/.encoding/models', **kwargs):
49 | model = ResNet(Bottleneck, [3, 4, 6, 3],
50 | radix=2, groups=1, bottleneck_width=64,
51 | deep_stem=True, stem_width=32, avg_down=True,
52 | avd=True, avd_first=True, **kwargs)
53 | if pretrained:
54 | model.load_state_dict(torch.hub.load_state_dict_from_url(
55 | resnest_model_urls['resnest50_fast_2s1x64d'], progress=True, check_hash=True))
56 | return model
57 |
58 | def resnest50_fast_4s1x64d(pretrained=False, root='~/.encoding/models', **kwargs):
59 | model = ResNet(Bottleneck, [3, 4, 6, 3],
60 | radix=4, groups=1, bottleneck_width=64,
61 | deep_stem=True, stem_width=32, avg_down=True,
62 | avd=True, avd_first=True, **kwargs)
63 | if pretrained:
64 | model.load_state_dict(torch.hub.load_state_dict_from_url(
65 | resnest_model_urls['resnest50_fast_4s1x64d'], progress=True, check_hash=True))
66 | return model
67 |
68 | def resnest50_fast_1s2x40d(pretrained=False, root='~/.encoding/models', **kwargs):
69 | model = ResNet(Bottleneck, [3, 4, 6, 3],
70 | radix=1, groups=2, bottleneck_width=40,
71 | deep_stem=True, stem_width=32, avg_down=True,
72 | avd=True, avd_first=True, **kwargs)
73 | if pretrained:
74 | model.load_state_dict(torch.hub.load_state_dict_from_url(
75 | resnest_model_urls['resnest50_fast_1s2x40d'], progress=True, check_hash=True))
76 | return model
77 |
78 | def resnest50_fast_2s2x40d(pretrained=False, root='~/.encoding/models', **kwargs):
79 | model = ResNet(Bottleneck, [3, 4, 6, 3],
80 | radix=2, groups=2, bottleneck_width=40,
81 | deep_stem=True, stem_width=32, avg_down=True,
82 | avd=True, avd_first=True, **kwargs)
83 | if pretrained:
84 | model.load_state_dict(torch.hub.load_state_dict_from_url(
85 | resnest_model_urls['resnest50_fast_2s2x40d'], progress=True, check_hash=True))
86 | return model
87 |
88 | def resnest50_fast_4s2x40d(pretrained=False, root='~/.encoding/models', **kwargs):
89 | model = ResNet(Bottleneck, [3, 4, 6, 3],
90 | radix=4, groups=2, bottleneck_width=40,
91 | deep_stem=True, stem_width=32, avg_down=True,
92 | avd=True, avd_first=True, **kwargs)
93 | if pretrained:
94 | model.load_state_dict(torch.hub.load_state_dict_from_url(
95 | resnest_model_urls['resnest50_fast_4s2x40d'], progress=True, check_hash=True))
96 | return model
97 |
98 | def resnest50_fast_1s4x24d(pretrained=False, root='~/.encoding/models', **kwargs):
99 | model = ResNet(Bottleneck, [3, 4, 6, 3],
100 | radix=1, groups=4, bottleneck_width=24,
101 | deep_stem=True, stem_width=32, avg_down=True,
102 | avd=True, avd_first=True, **kwargs)
103 | if pretrained:
104 | model.load_state_dict(torch.hub.load_state_dict_from_url(
105 | resnest_model_urls['resnest50_fast_1s4x24d'], progress=True, check_hash=True))
106 | return model
107 |
--------------------------------------------------------------------------------
/models/backbone/resnest/resnest.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## Email: zhanghang0704@gmail.com
4 | ## Copyright (c) 2020
5 | ##
6 | ## LICENSE file in the root directory of this source tree
7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 | """ResNeSt models"""
9 |
10 | import torch
11 | from models.backbone.resnest.resnet import ResNet, Bottleneck
12 |
13 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269']
14 |
15 | _url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
16 |
17 | _model_sha256 = {name: checksum for checksum, name in [
18 | ('528c19ca', 'resnest50'),
19 | ('22405ba7', 'resnest101'),
20 | ('75117900', 'resnest200'),
21 | ('0cc87c48', 'resnest269'),
22 | ]}
23 |
24 | def short_hash(name):
25 | if name not in _model_sha256:
26 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
27 | return _model_sha256[name][:8]
28 |
29 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for
30 | name in _model_sha256.keys()
31 | }
32 |
33 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
34 | model = ResNet(Bottleneck, [3, 4, 6, 3],
35 | radix=2, groups=1, bottleneck_width=64,
36 | deep_stem=True, stem_width=32, avg_down=True,
37 | avd=True, avd_first=False, **kwargs)
38 | if pretrained:
39 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
40 | model.load_state_dict(torch.hub.load_state_dict_from_url(
41 | resnest_model_urls['resnest50'], progress=True, check_hash=True))
42 | return model
43 |
44 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
45 | model = ResNet(Bottleneck, [3, 4, 23, 3],
46 | radix=2, groups=1, bottleneck_width=64,
47 | deep_stem=True, stem_width=64, avg_down=True,
48 | avd=True, avd_first=False, **kwargs)
49 | if pretrained:
50 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
51 | model.load_state_dict(torch.hub.load_state_dict_from_url(
52 | resnest_model_urls['resnest101'], progress=True, check_hash=True))
53 | return model
54 |
55 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
56 | model = ResNet(Bottleneck, [3, 24, 36, 3],
57 | radix=2, groups=1, bottleneck_width=64,
58 | deep_stem=True, stem_width=64, avg_down=True,
59 | avd=True, avd_first=False, **kwargs)
60 | if pretrained:
61 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
62 | model.load_state_dict(torch.hub.load_state_dict_from_url(
63 | resnest_model_urls['resnest200'], progress=True, check_hash=True))
64 | return model
65 |
66 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
67 | model = ResNet(Bottleneck, [3, 30, 48, 8],
68 | radix=2, groups=1, bottleneck_width=64,
69 | deep_stem=True, stem_width=64, avg_down=True,
70 | avd=True, avd_first=False, **kwargs)
71 | if pretrained:
72 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
73 | model.load_state_dict(torch.hub.load_state_dict_from_url(
74 | resnest_model_urls['resnest269'], progress=True, check_hash=True))
75 | return model
76 |
77 | if __name__ == '__main__':
78 | x = torch.zeros(2,3,640,640)
79 | net = resnest269(pretrained=False)
80 | y = net(x)
81 | for u in y:
82 | print(u.shape)
83 | print(net.out_channels)
--------------------------------------------------------------------------------
/models/backbone/resnest/splat.py:
--------------------------------------------------------------------------------
1 | """Split-Attention"""
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
7 | from torch.nn.modules.utils import _pair
8 |
9 | __all__ = ['SplAtConv2d']
10 |
11 | class SplAtConv2d(Module):
12 | """Split-Attention Conv2d
13 | """
14 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
15 | dilation=(1, 1), groups=1, bias=True,
16 | radix=2, reduction_factor=4,
17 | rectify=False, rectify_avg=False, norm_layer=None,
18 | dropblock_prob=0.0, **kwargs):
19 | super(SplAtConv2d, self).__init__()
20 | padding = _pair(padding)
21 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
22 | self.rectify_avg = rectify_avg
23 | inter_channels = max(in_channels*radix//reduction_factor, 32)
24 | self.radix = radix
25 | self.cardinality = groups
26 | self.channels = channels
27 | self.dropblock_prob = dropblock_prob
28 | if self.rectify:
29 | from rfconv import RFConv2d
30 | self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
31 | groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
32 | else:
33 | self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
34 | groups=groups*radix, bias=bias, **kwargs)
35 | self.use_bn = norm_layer is not None
36 | if self.use_bn:
37 | self.bn0 = norm_layer(channels*radix)
38 | self.relu = ReLU(inplace=True)
39 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
40 | if self.use_bn:
41 | self.bn1 = norm_layer(inter_channels)
42 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
43 | if dropblock_prob > 0.0:
44 |             self.dropblock = DropBlock2D(dropblock_prob, 3)  # NOTE: DropBlock2D is not defined in this repo; dropblock_prob > 0 requires supplying it
45 | self.rsoftmax = rSoftMax(radix, groups)
46 |
47 | def forward(self, x):
48 | x = self.conv(x)
49 | if self.use_bn:
50 | x = self.bn0(x)
51 | if self.dropblock_prob > 0.0:
52 | x = self.dropblock(x)
53 | x = self.relu(x)
54 |
55 | batch, rchannel = x.shape[:2]
56 | if self.radix > 1:
57 | if torch.__version__ < '1.5':
58 | splited = torch.split(x, int(rchannel//self.radix), dim=1)
59 | else:
60 | splited = torch.split(x, rchannel//self.radix, dim=1)
61 | gap = sum(splited)
62 | else:
63 | gap = x
64 | gap = F.adaptive_avg_pool2d(gap, 1)
65 | gap = self.fc1(gap)
66 |
67 | if self.use_bn:
68 | gap = self.bn1(gap)
69 | gap = self.relu(gap)
70 |
71 | atten = self.fc2(gap)
72 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
73 |
74 | if self.radix > 1:
75 | if torch.__version__ < '1.5':
76 | attens = torch.split(atten, int(rchannel//self.radix), dim=1)
77 | else:
78 | attens = torch.split(atten, rchannel//self.radix, dim=1)
79 | out = sum([att*split for (att, split) in zip(attens, splited)])
80 | else:
81 | out = atten * x
82 | return out.contiguous()
83 |
84 | class rSoftMax(nn.Module):
85 | def __init__(self, radix, cardinality):
86 | super().__init__()
87 | self.radix = radix
88 | self.cardinality = cardinality
89 |
90 | def forward(self, x):
91 | batch = x.size(0)
92 | if self.radix > 1:
93 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
94 | x = F.softmax(x, dim=1)
95 | x = x.reshape(batch, -1)
96 | else:
97 | x = torch.sigmoid(x)
98 | return x
99 |
100 |
--------------------------------------------------------------------------------
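A shape check for the split-attention convolution; the channel counts and input size are arbitrary:

```python
import torch
from torch import nn

conv = SplAtConv2d(in_channels=64, channels=64, kernel_size=3, padding=1,
                   radix=2, groups=1, norm_layer=nn.BatchNorm2d)
x = torch.zeros(2, 64, 32, 32)
print(conv(x).shape)  # torch.Size([2, 64, 32, 32]): the radix splits are fused back together
```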
/models/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | BatchNorm2d = nn.BatchNorm2d
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'deformable_resnet18', 'deformable_resnet50',
8 | 'resnet152']
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def constant_init(module, constant, bias=0):
20 | nn.init.constant_(module.weight, constant)
21 |     if hasattr(module, 'bias') and module.bias is not None:
22 | nn.init.constant_(module.bias, bias)
23 |
24 |
25 | def conv3x3(in_planes, out_planes, stride=1):
26 | """3x3 convolution with padding"""
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28 | padding=1, bias=False)
29 |
30 |
31 | class BasicBlock(nn.Module):
32 | expansion = 1
33 |
34 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
35 | super(BasicBlock, self).__init__()
36 | self.with_dcn = dcn is not None
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = BatchNorm2d(planes)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.with_modulated_dcn = False
41 | if not self.with_dcn:
42 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
43 | else:
44 | from torchvision.ops import DeformConv2d
45 | deformable_groups = dcn.get('deformable_groups', 1)
46 | offset_channels = 18
47 | self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1)
48 | self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, bias=False)
49 | self.bn2 = BatchNorm2d(planes)
50 | self.downsample = downsample
51 | self.stride = stride
52 |
53 | def forward(self, x):
54 | residual = x
55 |
56 | out = self.conv1(x)
57 | out = self.bn1(out)
58 | out = self.relu(out)
59 |
60 | # out = self.conv2(out)
61 | if not self.with_dcn:
62 | out = self.conv2(out)
63 | else:
64 | offset = self.conv2_offset(out)
65 | out = self.conv2(out, offset)
66 | out = self.bn2(out)
67 |
68 | if self.downsample is not None:
69 | residual = self.downsample(x)
70 |
71 | out += residual
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class Bottleneck(nn.Module):
78 | expansion = 4
79 |
80 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
81 | super(Bottleneck, self).__init__()
82 | self.with_dcn = dcn is not None
83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
84 | self.bn1 = BatchNorm2d(planes)
85 | self.with_modulated_dcn = False
86 | if not self.with_dcn:
87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
88 | else:
89 | deformable_groups = dcn.get('deformable_groups', 1)
90 | from torchvision.ops import DeformConv2d
91 | offset_channels = 18
92 | self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, stride=stride, kernel_size=3, padding=1)
93 | self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
94 | self.bn2 = BatchNorm2d(planes)
95 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
96 | self.bn3 = BatchNorm2d(planes * 4)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.downsample = downsample
99 | self.stride = stride
100 | self.dcn = dcn
101 | self.with_dcn = dcn is not None
102 |
103 | def forward(self, x):
104 | residual = x
105 |
106 | out = self.conv1(x)
107 | out = self.bn1(out)
108 | out = self.relu(out)
109 |
110 | # out = self.conv2(out)
111 | if not self.with_dcn:
112 | out = self.conv2(out)
113 | else:
114 | offset = self.conv2_offset(out)
115 | out = self.conv2(out, offset)
116 | out = self.bn2(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv3(out)
120 | out = self.bn3(out)
121 |
122 | if self.downsample is not None:
123 | residual = self.downsample(x)
124 |
125 | out += residual
126 | out = self.relu(out)
127 |
128 | return out
129 |
130 |
131 | class ResNet(nn.Module):
132 | def __init__(self, block, layers, in_channels=3, dcn=None):
133 | self.dcn = dcn
134 | self.inplanes = 64
135 | super(ResNet, self).__init__()
136 | self.out_channels = []
137 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
138 | bias=False)
139 | self.bn1 = BatchNorm2d(64)
140 | self.relu = nn.ReLU(inplace=True)
141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
142 | self.layer1 = self._make_layer(block, 64, layers[0])
143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dcn=dcn)
144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dcn=dcn)
145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dcn=dcn)
146 |
147 | for m in self.modules():
148 | if isinstance(m, nn.Conv2d):
149 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
150 | m.weight.data.normal_(0, math.sqrt(2. / n))
151 | elif isinstance(m, BatchNorm2d):
152 | m.weight.data.fill_(1)
153 | m.bias.data.zero_()
154 | if self.dcn is not None:
155 | for m in self.modules():
156 | if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
157 | if hasattr(m, 'conv2_offset'):
158 | constant_init(m.conv2_offset, 0)
159 |
160 | def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
161 | downsample = None
162 | if stride != 1 or self.inplanes != planes * block.expansion:
163 | downsample = nn.Sequential(
164 | nn.Conv2d(self.inplanes, planes * block.expansion,
165 | kernel_size=1, stride=stride, bias=False),
166 | BatchNorm2d(planes * block.expansion),
167 | )
168 |
169 | layers = []
170 | layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn))
171 | self.inplanes = planes * block.expansion
172 | for i in range(1, blocks):
173 | layers.append(block(self.inplanes, planes, dcn=dcn))
174 | self.out_channels.append(planes * block.expansion)
175 | return nn.Sequential(*layers)
176 |
177 | def forward(self, x):
178 | x = self.conv1(x)
179 | x = self.bn1(x)
180 | x = self.relu(x)
181 | x = self.maxpool(x)
182 |
183 | x2 = self.layer1(x)
184 | x3 = self.layer2(x2)
185 | x4 = self.layer3(x3)
186 | x5 = self.layer4(x4)
187 |
188 | return x2, x3, x4, x5
189 |
190 |
191 | def resnet18(pretrained=True, **kwargs):
192 | """Constructs a ResNet-18 model.
193 | Args:
194 | pretrained (bool): If True, returns a model pre-trained on ImageNet
195 | """
196 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
197 | if pretrained:
198 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
199 | print('load from imagenet')
200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
201 | return model
202 |
203 |
204 | def deformable_resnet18(pretrained=True, **kwargs):
205 | """Constructs a ResNet-18 model.
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | """
209 | model = ResNet(BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1), **kwargs)
210 | if pretrained:
211 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
212 | print('load from imagenet')
213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
214 | return model
215 |
216 |
217 | def resnet34(pretrained=True, **kwargs):
218 | """Constructs a ResNet-34 model.
219 | Args:
220 | pretrained (bool): If True, returns a model pre-trained on ImageNet
221 | """
222 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
223 | if pretrained:
224 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
225 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
226 | return model
227 |
228 |
229 | def resnet50(pretrained=True, **kwargs):
230 | """Constructs a ResNet-50 model.
231 | Args:
232 | pretrained (bool): If True, returns a model pre-trained on ImageNet
233 | """
234 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
235 | if pretrained:
236 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
237 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
238 | return model
239 |
240 |
241 | def deformable_resnet50(pretrained=True, **kwargs):
242 | """Constructs a ResNet-50 model with deformable conv.
243 | Args:
244 | pretrained (bool): If True, returns a model pre-trained on ImageNet
245 | """
246 | model = ResNet(Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs)
247 | if pretrained:
248 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
249 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
250 | return model
251 |
252 |
253 | def resnet101(pretrained=True, **kwargs):
254 | """Constructs a ResNet-101 model.
255 | Args:
256 | pretrained (bool): If True, returns a model pre-trained on ImageNet
257 | """
258 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
259 | if pretrained:
260 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
261 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
262 | return model
263 |
264 |
265 | def resnet152(pretrained=True, **kwargs):
266 | """Constructs a ResNet-152 model.
267 | Args:
268 | pretrained (bool): If True, returns a model pre-trained on ImageNet
269 | """
270 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
271 | if pretrained:
272 |         assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
273 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
274 | return model
275 |
276 |
277 | if __name__ == '__main__':
278 | import torch
279 |
280 | x = torch.zeros(2, 3, 640, 640)
281 | net = deformable_resnet50(pretrained=False)
282 | y = net(x)
283 | for u in y:
284 | print(u.shape)
285 |
286 | print(net.out_channels)
287 |
--------------------------------------------------------------------------------
/models/backbone/shufflenetv2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/11/1 15:31
3 | # @Author : zhoujun
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision.models.utils import load_state_dict_from_url
8 |
9 | __all__ = [
10 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
11 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
12 | ]
13 |
14 | model_urls = {
15 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
16 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
17 | 'shufflenetv2_x1.5': None,
18 | 'shufflenetv2_x2.0': None,
19 | }
20 |
21 |
22 | def channel_shuffle(x, groups):
23 | batchsize, num_channels, height, width = x.data.size()
24 | channels_per_group = num_channels // groups
25 |
26 | # reshape
27 | x = x.view(batchsize, groups,
28 | channels_per_group, height, width)
29 |
30 | x = torch.transpose(x, 1, 2).contiguous()
31 |
32 | # flatten
33 | x = x.view(batchsize, -1, height, width)
34 |
35 | return x
36 |
37 |
38 | class InvertedResidual(nn.Module):
39 | def __init__(self, inp, oup, stride):
40 | super(InvertedResidual, self).__init__()
41 |
42 | if not (1 <= stride <= 3):
43 | raise ValueError('illegal stride value')
44 | self.stride = stride
45 |
46 | branch_features = oup // 2
47 | assert (self.stride != 1) or (inp == branch_features << 1)
48 |
49 | if self.stride > 1:
50 | self.branch1 = nn.Sequential(
51 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
52 | nn.BatchNorm2d(inp),
53 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
54 | nn.BatchNorm2d(branch_features),
55 | nn.ReLU(inplace=True),
56 | )
57 |
58 | self.branch2 = nn.Sequential(
59 | nn.Conv2d(inp if (self.stride > 1) else branch_features,
60 | branch_features, kernel_size=1, stride=1, padding=0, bias=False),
61 | nn.BatchNorm2d(branch_features),
62 | nn.ReLU(inplace=True),
63 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
64 | nn.BatchNorm2d(branch_features),
65 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
66 | nn.BatchNorm2d(branch_features),
67 | nn.ReLU(inplace=True),
68 | )
69 |
70 | @staticmethod
71 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
72 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
73 |
74 | def forward(self, x):
75 | if self.stride == 1:
76 | x1, x2 = x.chunk(2, dim=1)
77 | out = torch.cat((x1, self.branch2(x2)), dim=1)
78 | else:
79 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
80 |
81 | out = channel_shuffle(out, 2)
82 |
83 | return out
84 |
85 |
86 | class ShuffleNetV2(nn.Module):
87 | def __init__(self, stages_repeats, stages_out_channels, in_channels=3, **kwargs):
88 | super(ShuffleNetV2, self).__init__()
89 | self.out_channels = []
90 | if len(stages_repeats) != 3:
91 | raise ValueError('expected stages_repeats as list of 3 positive ints')
92 | if len(stages_out_channels) != 5:
93 | raise ValueError('expected stages_out_channels as list of 5 positive ints')
94 | self._stage_out_channels = stages_out_channels
95 |
96 | output_channels = self._stage_out_channels[0]
97 | self.conv1 = nn.Sequential(
98 | nn.Conv2d(in_channels, output_channels, 3, 2, 1, bias=False),
99 | nn.BatchNorm2d(output_channels),
100 | nn.ReLU(inplace=True),
101 | )
102 | input_channels = output_channels
103 |
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
105 | self.out_channels.append(input_channels)
106 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
107 | for name, repeats, output_channels in zip(
108 | stage_names, stages_repeats, self._stage_out_channels[1:]):
109 | seq = [InvertedResidual(input_channels, output_channels, 2)]
110 | for i in range(repeats - 1):
111 | seq.append(InvertedResidual(output_channels, output_channels, 1))
112 | setattr(self, name, nn.Sequential(*seq))
113 | input_channels = output_channels
114 | self.out_channels.append(input_channels)
115 | output_channels = self._stage_out_channels[-1]
116 | self.conv5 = nn.Sequential(
117 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
118 | nn.BatchNorm2d(output_channels),
119 | nn.ReLU(inplace=True),
120 | )
121 |
122 | def forward(self, x):
123 | x = self.conv1(x)
124 | c2 = self.maxpool(x)
125 | c3 = self.stage2(c2)
126 | c4 = self.stage3(c3)
127 | c5 = self.stage4(c4)
128 | # c5 = self.conv5(c5)
129 | return c2, c3, c4, c5
130 |
131 |
132 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
133 | model = ShuffleNetV2(*args, **kwargs)
134 |
135 | if pretrained:
136 | model_url = model_urls[arch]
137 | if model_url is None:
138 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
139 | else:
140 |             assert kwargs['in_channels'] == 3, 'in_channels must be 3 when pretrained is True'
141 | state_dict = load_state_dict_from_url(model_url, progress=progress)
142 | model.load_state_dict(state_dict, strict=False)
143 |
144 | return model
145 |
146 |
147 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
148 | """
149 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in
150 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
151 |     <https://arxiv.org/abs/1807.11164>`_.
152 |
153 | Args:
154 | pretrained (bool): If True, returns a model pre-trained on ImageNet
155 | progress (bool): If True, displays a progress bar of the download to stderr
156 | """
157 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
158 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
159 |
160 |
161 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
162 | """
163 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in
164 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
165 |     <https://arxiv.org/abs/1807.11164>`_.
166 |
167 | Args:
168 | pretrained (bool): If True, returns a model pre-trained on ImageNet
169 | progress (bool): If True, displays a progress bar of the download to stderr
170 | """
171 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
172 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
173 |
174 |
175 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
176 | """
177 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in
178 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
179 |     <https://arxiv.org/abs/1807.11164>`_.
180 |
181 | Args:
182 | pretrained (bool): If True, returns a model pre-trained on ImageNet
183 | progress (bool): If True, displays a progress bar of the download to stderr
184 | """
185 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
186 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
187 |
188 |
189 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
190 | """
191 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in
192 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
193 |     <https://arxiv.org/abs/1807.11164>`_.
194 |
195 | Args:
196 | pretrained (bool): If True, returns a model pre-trained on ImageNet
197 | progress (bool): If True, displays a progress bar of the download to stderr
198 | """
199 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
200 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
201 |
--------------------------------------------------------------------------------
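
A minimal usage sketch (not part of the repo) for the backbone above: unlike the torchvision classifier, this variant returns the `c2..c5` feature pyramid and records per-stage channel counts in `out_channels` for the neck. The 640x640 input size is an illustrative assumption.

```python
# Sketch only: check the feature pyramid produced by the modified ShuffleNetV2.
import torch
from models.backbone.shufflenetv2 import shufflenet_v2_x1_0

backbone = shufflenet_v2_x1_0(pretrained=False, in_channels=3)
x = torch.randn(1, 3, 640, 640)
c2, c3, c4, c5 = backbone(x)
print([t.shape[-1] for t in (c2, c3, c4, c5)])  # [160, 80, 40, 20] -> strides 4/8/16/32
print(backbone.out_channels)                    # [24, 116, 232, 464] for the 1.0x variant
```
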
/models/basic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/6 11:19
3 | # @Author : zhoujun
4 | from torch import nn
5 |
6 |
7 | class ConvBnRelu(nn.Module):
8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', inplace=True):
9 | super().__init__()
10 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
11 | groups=groups, bias=bias, padding_mode=padding_mode)
12 | self.bn = nn.BatchNorm2d(out_channels)
13 | self.relu = nn.ReLU(inplace=inplace)
14 |
15 | def forward(self, x):
16 | x = self.conv(x)
17 | x = self.bn(x)
18 | x = self.relu(x)
19 | return x
20 |
--------------------------------------------------------------------------------
/models/head/ConvHead.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 14:54
3 | # @Author : zhoujun
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class ConvHead(nn.Module):
9 | def __init__(self, in_channels, out_channels,**kwargs):
10 | super().__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
13 | nn.Sigmoid()
14 | )
15 |
16 | def forward(self, x):
17 | return self.conv(x)
--------------------------------------------------------------------------------
/models/head/DBHead.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 14:54
3 | # @Author : zhoujun
4 | import torch
5 | from torch import nn
6 |
7 | class DBHead(nn.Module):
8 | def __init__(self, in_channels, out_channels, k = 50):
9 | super().__init__()
10 | self.k = k
11 | self.binarize = nn.Sequential(
12 | nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
13 | nn.BatchNorm2d(in_channels // 4),
14 | nn.ReLU(inplace=True),
15 | nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
16 | nn.BatchNorm2d(in_channels // 4),
17 | nn.ReLU(inplace=True),
18 | nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
19 | nn.Sigmoid())
20 | self.binarize.apply(self.weights_init)
21 |
22 | self.thresh = self._init_thresh(in_channels)
23 | self.thresh.apply(self.weights_init)
24 |
25 | def forward(self, x):
26 | shrink_maps = self.binarize(x)
27 | threshold_maps = self.thresh(x)
28 | if self.training:
29 | binary_maps = self.step_function(shrink_maps, threshold_maps)
30 | y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
31 | else:
32 | y = torch.cat((shrink_maps, threshold_maps), dim=1)
33 | return y
34 |
35 | def weights_init(self, m):
36 | classname = m.__class__.__name__
37 | if classname.find('Conv') != -1:
38 | nn.init.kaiming_normal_(m.weight.data)
39 | elif classname.find('BatchNorm') != -1:
40 | m.weight.data.fill_(1.)
41 | m.bias.data.fill_(1e-4)
42 |
43 | def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
44 | in_channels = inner_channels
45 | if serial:
46 | in_channels += 1
47 | self.thresh = nn.Sequential(
48 | nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
49 | nn.BatchNorm2d(inner_channels // 4),
50 | nn.ReLU(inplace=True),
51 | self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
52 | nn.BatchNorm2d(inner_channels // 4),
53 | nn.ReLU(inplace=True),
54 | self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
55 | nn.Sigmoid())
56 | return self.thresh
57 |
58 | def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
59 | if smooth:
60 | inter_out_channels = out_channels
61 | if out_channels == 1:
62 | inter_out_channels = in_channels
63 | module_list = [
64 | nn.Upsample(scale_factor=2, mode='nearest'),
65 | nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
66 | if out_channels == 1:
67 | module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
68 |             return nn.Sequential(*module_list)
69 | else:
70 | return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
71 |
72 | def step_function(self, x, y):
73 | return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
74 |
--------------------------------------------------------------------------------
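
The numbers below are a hedged illustration (not repo code) of `DBHead.step_function`: with the default `k = 50`, the differentiable binarization `B = 1 / (1 + exp(-k * (P - T)))` is a steep sigmoid around the per-pixel threshold, so it behaves almost like a hard step while staying differentiable for training.

```python
import torch

k = 50
P = torch.tensor([0.30, 0.49, 0.51, 0.70])  # shrink (probability) map values
T = torch.full_like(P, 0.5)                 # learned threshold map values
B = torch.reciprocal(1 + torch.exp(-k * (P - T)))
print(B)  # ~[0.0000, 0.3775, 0.6225, 1.0000]; near-binary away from T
```
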
/models/head/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/6/5 11:35
3 | # @Author : zhoujun
4 | from .DBHead import DBHead
5 | from .ConvHead import ConvHead
6 |
7 | __all__ = ['build_head']
8 | support_head = ['ConvHead', 'DBHead']
9 |
10 |
11 | def build_head(head_name, **kwargs):
12 | assert head_name in support_head, f'all support head is {support_head}'
13 | head = eval(head_name)(**kwargs)
14 | return head
--------------------------------------------------------------------------------
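
A quick usage sketch of the factory above; the argument values are illustrative assumptions (the real values come from the YAML configs):

```python
from models.head import build_head

# `in_channels` normally comes from the neck's out_channels (e.g. 256 for FPN).
head = build_head('DBHead', in_channels=256, out_channels=2, k=50)
```
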
/models/losses/DB_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:56
3 | # @Author : zhoujun
4 | from torch import nn
5 |
6 | from models.losses.basic_loss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss
7 |
8 |
9 | class DBLoss(nn.Module):
10 | def __init__(self, alpha=1.0, beta=10, ohem_ratio=3, reduction='mean', eps=1e-6):
11 | """
12 |         Implement DB Loss.
13 |         :param alpha: coefficient of the shrink map loss
14 |         :param beta: coefficient of the threshold map loss
15 |         :param ohem_ratio: negative:positive sampling ratio for OHEM
16 |         :param reduction: 'mean' or 'sum', reduce the batch loss by mean or sum
17 | """
18 | super().__init__()
19 | assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
20 | self.alpha = alpha
21 | self.beta = beta
22 | self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
23 | self.dice_loss = DiceLoss(eps=eps)
24 | self.l1_loss = MaskL1Loss(eps=eps)
25 | self.ohem_ratio = ohem_ratio
26 | self.reduction = reduction
27 |
28 | def forward(self, pred, batch):
29 | shrink_maps = pred[:, 0, :, :]
30 | threshold_maps = pred[:, 1, :, :]
31 | binary_maps = pred[:, 2, :, :]
32 |
33 | loss_shrink_maps = self.bce_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])
34 | loss_threshold_maps = self.l1_loss(threshold_maps, batch['threshold_map'], batch['threshold_mask'])
35 | metrics = dict(loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps)
36 | if pred.size()[1] > 2:
37 | loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'], batch['shrink_mask'])
38 | metrics['loss_binary_maps'] = loss_binary_maps
39 | loss_all = self.alpha * loss_shrink_maps + self.beta * loss_threshold_maps + loss_binary_maps
40 | metrics['loss'] = loss_all
41 | else:
42 | metrics['loss'] = loss_shrink_maps
43 | return metrics
44 |
--------------------------------------------------------------------------------
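
A toy sanity check (an assumption, not repo code) of the combined objective above, `L = alpha * L_shrink + beta * L_thresh + L_binary`; in real training the maps and masks come from the data loader instead of `torch.rand`:

```python
import torch
from models.losses.DB_loss import DBLoss

criterion = DBLoss(alpha=1.0, beta=10)
pred = torch.rand(2, 3, 32, 32)  # channels: shrink / threshold / binary maps
batch = {
    'shrink_map': torch.empty(2, 32, 32).random_(2),
    'shrink_mask': torch.ones(2, 32, 32),
    'threshold_map': torch.rand(2, 32, 32),
    'threshold_mask': torch.ones(2, 32, 32),
}
metrics = criterion(pred, batch)
print({k: float(v) for k, v in metrics.items()})  # total loss and its three components
```
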
/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/6/5 11:36
3 | # @Author : zhoujun
4 | import copy
5 | from .DB_loss import DBLoss
6 |
7 | __all__ = ['build_loss']
8 | support_loss = ['DBLoss']
9 |
10 | def build_loss(config):
11 | copy_config = copy.deepcopy(config)
12 | loss_type = copy_config.pop('type')
13 | assert loss_type in support_loss, f'all support loss is {support_loss}'
14 | criterion = eval(loss_type)(**copy_config)
15 | return criterion
16 |
--------------------------------------------------------------------------------
/models/losses/basic_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 14:39
3 | # @Author : zhoujun
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class BalanceCrossEntropyLoss(nn.Module):
9 | '''
10 | Balanced cross entropy loss.
11 | Shape:
12 | - Input: :math:`(N, 1, H, W)`
13 | - GT: :math:`(N, 1, H, W)`, same shape as the input
14 | - Mask: :math:`(N, H, W)`, same spatial shape as the input
15 | - Output: scalar.
16 |
17 | Examples::
18 |
19 |         >>> loss_fn = BalanceCrossEntropyLoss()
20 |         >>> pred = torch.rand(2, 32, 32, requires_grad=True)
21 |         >>> gt = torch.empty(2, 32, 32).random_(2)
22 |         >>> mask = torch.ones(2, 32, 32)
23 |         >>> output = loss_fn(pred, gt, mask)
24 |         >>> output.backward()
25 | '''
26 |
27 | def __init__(self, negative_ratio=3.0, eps=1e-6):
28 | super(BalanceCrossEntropyLoss, self).__init__()
29 | self.negative_ratio = negative_ratio
30 | self.eps = eps
31 |
32 | def forward(self,
33 | pred: torch.Tensor,
34 | gt: torch.Tensor,
35 | mask: torch.Tensor,
36 | return_origin=False):
37 | '''
38 | Args:
39 | pred: shape :math:`(N, 1, H, W)`, the prediction of network
40 | gt: shape :math:`(N, 1, H, W)`, the target
41 | mask: shape :math:`(N, H, W)`, the mask indicates positive regions
42 | '''
43 | positive = (gt * mask).byte()
44 | negative = ((1 - gt) * mask).byte()
45 | positive_count = int(positive.float().sum())
46 | negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))
47 | loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
48 | positive_loss = loss * positive.float()
49 | negative_loss = loss * negative.float()
50 | # negative_loss, _ = torch.topk(negative_loss.view(-1).contiguous(), negative_count)
51 | negative_loss, _ = negative_loss.view(-1).topk(negative_count)
52 |
53 | balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + self.eps)
54 |
55 | if return_origin:
56 | return balance_loss, loss
57 | return balance_loss
58 |
59 |
60 | class DiceLoss(nn.Module):
61 | '''
62 | Loss function from https://arxiv.org/abs/1707.03237,
63 |     where the IoU computation is applied in a heatmap manner to measure the
64 |     overlap between two heatmaps.
65 | '''
66 |
67 | def __init__(self, eps=1e-6):
68 | super(DiceLoss, self).__init__()
69 | self.eps = eps
70 |
71 | def forward(self, pred: torch.Tensor, gt, mask, weights=None):
72 | '''
73 | pred: one or two heatmaps of shape (N, 1, H, W),
74 |             the losses of two heatmaps are added together.
75 | gt: (N, 1, H, W)
76 | mask: (N, H, W)
77 | '''
78 | return self._compute(pred, gt, mask, weights)
79 |
80 | def _compute(self, pred, gt, mask, weights):
81 | if pred.dim() == 4:
82 | pred = pred[:, 0, :, :]
83 | gt = gt[:, 0, :, :]
84 | assert pred.shape == gt.shape
85 | assert pred.shape == mask.shape
86 | if weights is not None:
87 | assert weights.shape == mask.shape
88 | mask = weights * mask
89 | intersection = (pred * gt * mask).sum()
90 |
91 | union = (pred * mask).sum() + (gt * mask).sum() + self.eps
92 | loss = 1 - 2.0 * intersection / union
93 | assert loss <= 1
94 | return loss
95 |
96 |
97 | class MaskL1Loss(nn.Module):
98 | def __init__(self, eps=1e-6):
99 | super(MaskL1Loss, self).__init__()
100 | self.eps = eps
101 |
102 | def forward(self, pred: torch.Tensor, gt, mask):
103 | loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
104 | return loss
105 |
--------------------------------------------------------------------------------
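
The hard-negative mining in `BalanceCrossEntropyLoss` can be illustrated standalone (a sketch with assumed toy values): with `negative_ratio = 3` and one positive pixel, only the three highest-loss negatives enter the balanced sum.

```python
import torch

negative_ratio = 3.0
gt = torch.tensor([1., 0., 0., 0., 0., 0., 0., 0.])
pred = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.2, 0.2, 0.1, 0.1])
loss = torch.nn.functional.binary_cross_entropy(pred, gt, reduction='none')
positive_count = int(gt.sum())
negative_count = min(int((1 - gt).sum()), int(positive_count * negative_ratio))
hard_negatives, _ = (loss * (1 - gt)).topk(negative_count)  # the 3 hardest negatives
balanced = (loss[gt == 1].sum() + hard_negatives.sum()) / (positive_count + negative_count)
print(balanced)
```
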
/models/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:57
3 | # @Author : zhoujun
4 | from addict import Dict
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from models.backbone import build_backbone
9 | from models.neck import build_neck
10 | from models.head import build_head
11 |
12 |
13 | class Model(nn.Module):
14 | def __init__(self, model_config: dict):
15 | """
16 |         Build the detection model (backbone + neck + head) from a config dict
17 |         :param model_config: model configuration
18 | """
19 | super().__init__()
20 | model_config = Dict(model_config)
21 | backbone_type = model_config.backbone.pop('type')
22 | neck_type = model_config.neck.pop('type')
23 | head_type = model_config.head.pop('type')
24 | self.backbone = build_backbone(backbone_type, **model_config.backbone)
25 | self.neck = build_neck(neck_type, in_channels=self.backbone.out_channels, **model_config.neck)
26 | self.head = build_head(head_type, in_channels=self.neck.out_channels, **model_config.head)
27 | self.name = f'{backbone_type}_{neck_type}_{head_type}'
28 |
29 | def forward(self, x):
30 | _, _, H, W = x.size()
31 | backbone_out = self.backbone(x)
32 | neck_out = self.neck(backbone_out)
33 | y = self.head(neck_out)
34 | y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)
35 | return y
36 |
37 |
38 | if __name__ == '__main__':
39 | import torch
40 |
41 | device = torch.device('cpu')
42 | x = torch.zeros(2, 3, 640, 640).to(device)
43 |
44 | model_config = {
45 | 'backbone': {'type': 'resnest50', 'pretrained': True, "in_channels": 3},
46 |         'neck': {'type': 'FPN', 'inner_channels': 256},  # neck: FPN or FPEM_FFM
47 | 'head': {'type': 'DBHead', 'out_channels': 2, 'k': 50},
48 | }
49 | model = Model(model_config=model_config).to(device)
50 | import time
51 |
52 | tic = time.time()
53 | y = model(x)
54 | print(time.time() - tic)
55 | print(y.shape)
56 | print(model.name)
57 | print(model)
58 | # torch.save(model.state_dict(), 'PAN.pth')
59 |
--------------------------------------------------------------------------------
/models/neck/FPEM_FFM.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/9/13 10:29
3 | # @Author : zhoujun
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from models.basic import ConvBnRelu
9 |
10 |
11 | class FPEM_FFM(nn.Module):
12 | def __init__(self, in_channels, inner_channels=128, fpem_repeat=2, **kwargs):
13 | """
14 |         FPEM_FFM neck from PANnet
15 |         :param in_channels: output channels of the backbone feature maps
16 | """
17 | super().__init__()
18 | self.conv_out = inner_channels
19 | inplace = True
20 | # reduce layers
21 | self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
22 | self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
23 | self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
24 | self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
25 | self.fpems = nn.ModuleList()
26 | for i in range(fpem_repeat):
27 | self.fpems.append(FPEM(self.conv_out))
28 | self.out_channels = self.conv_out * 4
29 |
30 | def forward(self, x):
31 | c2, c3, c4, c5 = x
32 | # reduce channel
33 | c2 = self.reduce_conv_c2(c2)
34 | c3 = self.reduce_conv_c3(c3)
35 | c4 = self.reduce_conv_c4(c4)
36 | c5 = self.reduce_conv_c5(c5)
37 |
38 | # FPEM
39 | for i, fpem in enumerate(self.fpems):
40 | c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
41 | if i == 0:
42 | c2_ffm = c2
43 | c3_ffm = c3
44 | c4_ffm = c4
45 | c5_ffm = c5
46 | else:
47 | c2_ffm += c2
48 | c3_ffm += c3
49 | c4_ffm += c4
50 | c5_ffm += c5
51 |
52 | # FFM
53 | c5 = F.interpolate(c5_ffm, c2_ffm.size()[-2:])
54 | c4 = F.interpolate(c4_ffm, c2_ffm.size()[-2:])
55 | c3 = F.interpolate(c3_ffm, c2_ffm.size()[-2:])
56 | Fy = torch.cat([c2_ffm, c3, c4, c5], dim=1)
57 | return Fy
58 |
59 |
60 | class FPEM(nn.Module):
61 | def __init__(self, in_channels=128):
62 | super().__init__()
63 | self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
64 | self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
65 | self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
66 | self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
67 | self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
68 | self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
69 |
70 | def forward(self, c2, c3, c4, c5):
71 |         # up-scale enhancement stage
72 | c4 = self.up_add1(self._upsample_add(c5, c4))
73 | c3 = self.up_add2(self._upsample_add(c4, c3))
74 | c2 = self.up_add3(self._upsample_add(c3, c2))
75 |
76 |         # down-scale enhancement stage
77 | c3 = self.down_add1(self._upsample_add(c3, c2))
78 | c4 = self.down_add2(self._upsample_add(c4, c3))
79 | c5 = self.down_add3(self._upsample_add(c5, c4))
80 | return c2, c3, c4, c5
81 |
82 | def _upsample_add(self, x, y):
83 | return F.interpolate(x, size=y.size()[2:]) + y
84 |
85 |
86 | class SeparableConv2d(nn.Module):
87 | def __init__(self, in_channels, out_channels, stride=1):
88 | super(SeparableConv2d, self).__init__()
89 |
90 | self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1,
91 | stride=stride, groups=in_channels)
92 | self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
93 | self.bn = nn.BatchNorm2d(out_channels)
94 | self.relu = nn.ReLU()
95 |
96 | def forward(self, x):
97 | x = self.depthwise_conv(x)
98 | x = self.pointwise_conv(x)
99 | x = self.bn(x)
100 | x = self.relu(x)
101 | return x
102 |
--------------------------------------------------------------------------------
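
A shape sanity check (an assumption, not repo code): each FPEM pass enhances the four reduced maps, and the FFM step upsamples and concatenates them, so the neck's output has `inner_channels * 4` channels at `c2` resolution.

```python
import torch
from models.neck.FPEM_FFM import FPEM_FFM

in_channels = [24, 116, 232, 464]  # e.g. shufflenet_v2_x1_0 out_channels
neck = FPEM_FFM(in_channels, inner_channels=128, fpem_repeat=2)
feats = [torch.randn(1, c, 160 // s, 160 // s) for c, s in zip(in_channels, [1, 2, 4, 8])]
print(neck(feats).shape)  # torch.Size([1, 512, 160, 160])
```
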
/models/neck/FPN.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/9/13 10:29
3 | # @Author : zhoujun
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from models.basic import ConvBnRelu
9 |
10 |
11 | class FPN(nn.Module):
12 | def __init__(self, in_channels, inner_channels=256, **kwargs):
13 | """
14 |         :param in_channels: output channels of the backbone feature maps
15 | :param kwargs:
16 | """
17 | super().__init__()
18 | inplace = True
19 | self.conv_out = inner_channels
20 | inner_channels = inner_channels // 4
21 | # reduce layers
22 | self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
23 | self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
24 | self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
25 | self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
26 | # Smooth layers
27 | self.smooth_p4 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
28 | self.smooth_p3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
29 | self.smooth_p2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
30 |
31 | self.conv = nn.Sequential(
32 | nn.Conv2d(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
33 | nn.BatchNorm2d(self.conv_out),
34 | nn.ReLU(inplace=inplace)
35 | )
36 | self.out_channels = self.conv_out
37 |
38 | def forward(self, x):
39 | c2, c3, c4, c5 = x
40 | # Top-down
41 | p5 = self.reduce_conv_c5(c5)
42 | p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
43 | p4 = self.smooth_p4(p4)
44 | p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
45 | p3 = self.smooth_p3(p3)
46 | p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
47 | p2 = self.smooth_p2(p2)
48 |
49 | x = self._upsample_cat(p2, p3, p4, p5)
50 | x = self.conv(x)
51 | return x
52 |
53 | def _upsample_add(self, x, y):
54 | return F.interpolate(x, size=y.size()[2:]) + y
55 |
56 | def _upsample_cat(self, p2, p3, p4, p5):
57 | h, w = p2.size()[2:]
58 | p3 = F.interpolate(p3, size=(h, w))
59 | p4 = F.interpolate(p4, size=(h, w))
60 | p5 = F.interpolate(p5, size=(h, w))
61 | return torch.cat([p2, p3, p4, p5], dim=1)
62 |
--------------------------------------------------------------------------------
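
The same kind of check for the FPN neck (again a sketch with assumed channel counts): the top-down pathway merges `c5..c2`, and `_upsample_cat` fuses the four levels into one `inner_channels`-wide map at `c2` resolution.

```python
import torch
from models.neck.FPN import FPN

in_channels = [24, 116, 232, 464]
neck = FPN(in_channels, inner_channels=256)
feats = [torch.randn(1, c, 160 // s, 160 // s) for c, s in zip(in_channels, [1, 2, 4, 8])]
print(neck(feats).shape)  # torch.Size([1, 256, 160, 160])
```
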
/models/neck/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/6/5 11:34
3 | # @Author : zhoujun
4 | from .FPN import FPN
5 | from .FPEM_FFM import FPEM_FFM
6 |
7 | __all__ = ['build_neck']
8 | support_neck = ['FPN', 'FPEM_FFM']
9 |
10 |
11 | def build_neck(neck_name, **kwargs):
12 | assert neck_name in support_neck, f'all support neck is {support_neck}'
13 | neck = eval(neck_name)(**kwargs)
14 | return neck
15 |
--------------------------------------------------------------------------------
/multi_gpu_train.sh:
--------------------------------------------------------------------------------
1 | # export NCCL_P2P_DISABLE=1
2 | export NGPUS=4
3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train.py --config_file "config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml"
--------------------------------------------------------------------------------
/post_processing/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/5 15:17
3 | # @Author : zhoujun
4 |
5 | from .seg_detector_representer import SegDetectorRepresenter
6 |
7 |
8 | def get_post_processing(config):
9 | try:
10 | cls = eval(config['type'])(**config['args'])
11 | return cls
12 |     except Exception:
13 | return None
--------------------------------------------------------------------------------
/post_processing/seg_detector_representer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import pyclipper
4 | from shapely.geometry import Polygon
5 |
6 |
7 | class SegDetectorRepresenter():
8 | def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5):
9 | self.min_size = 3
10 | self.thresh = thresh
11 | self.box_thresh = box_thresh
12 | self.max_candidates = max_candidates
13 | self.unclip_ratio = unclip_ratio
14 |
15 | def __call__(self, batch, pred, is_output_polygon=False):
16 | '''
17 |         batch: a dict produced by dataloaders, containing
18 |             (image, polygons, ignore_tags, shape, filename):
19 | image: tensor of shape (N, C, H, W).
20 | polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
21 | ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
22 | shape: the original shape of images.
23 | filename: the original filenames of images.
24 | pred:
25 | binary: text region segmentation map, with shape (N, H, W)
26 |             thresh: [if exists] threshold prediction with shape (N, H, W)
27 |             thresh_binary: [if exists] binarized with the threshold, (N, H, W)
28 | '''
29 | pred = pred[:, 0, :, :]
30 | segmentation = self.binarize(pred)
31 | boxes_batch = []
32 | scores_batch = []
33 | for batch_index in range(pred.size(0)):
34 | height, width = batch['shape'][batch_index]
35 | if is_output_polygon:
36 | boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
37 | else:
38 | boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
39 | boxes_batch.append(boxes)
40 | scores_batch.append(scores)
41 | return boxes_batch, scores_batch
42 |
43 | def binarize(self, pred):
44 | return pred > self.thresh
45 |
46 | def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
47 | '''
48 | _bitmap: single map with shape (H, W),
49 | whose values are binarized as {0, 1}
50 | '''
51 |
52 | assert len(_bitmap.shape) == 2
53 | bitmap = _bitmap.cpu().numpy() # The first channel
54 | pred = pred.cpu().detach().numpy()
55 | height, width = bitmap.shape
56 | boxes = []
57 | scores = []
58 |
59 | contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
60 |
61 | for contour in contours[:self.max_candidates]:
62 | epsilon = 0.005 * cv2.arcLength(contour, True)
63 | approx = cv2.approxPolyDP(contour, epsilon, True)
64 | points = approx.reshape((-1, 2))
65 | if points.shape[0] < 4:
66 | continue
67 | # _, sside = self.get_mini_boxes(contour)
68 | # if sside < self.min_size:
69 | # continue
70 | score = self.box_score_fast(pred, contour.squeeze(1))
71 | if self.box_thresh > score:
72 | continue
73 |
74 | if points.shape[0] > 2:
75 | box = self.unclip(points, unclip_ratio=self.unclip_ratio)
76 | if len(box) > 1:
77 | continue
78 | else:
79 | continue
80 | box = box.reshape(-1, 2)
81 | _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
82 | if sside < self.min_size + 2:
83 | continue
84 |
85 | if not isinstance(dest_width, int):
86 | dest_width = dest_width.item()
87 | dest_height = dest_height.item()
88 |
89 | box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
90 | box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
91 | boxes.append(box)
92 | scores.append(score)
93 | return boxes, scores
94 |
95 | def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
96 | '''
97 | _bitmap: single map with shape (H, W),
98 | whose values are binarized as {0, 1}
99 | '''
100 |
101 | assert len(_bitmap.shape) == 2
102 | bitmap = _bitmap.cpu().numpy() # The first channel
103 | pred = pred.cpu().detach().numpy()
104 | height, width = bitmap.shape
105 | contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
106 | num_contours = min(len(contours), self.max_candidates)
107 | boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
108 | scores = np.zeros((num_contours,), dtype=np.float32)
109 |
110 | for index in range(num_contours):
111 | contour = contours[index].squeeze(1)
112 | points, sside = self.get_mini_boxes(contour)
113 | if sside < self.min_size:
114 | continue
115 | points = np.array(points)
116 | score = self.box_score_fast(pred, contour)
117 | if self.box_thresh > score:
118 | continue
119 |
120 | box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
121 | box, sside = self.get_mini_boxes(box)
122 | if sside < self.min_size + 2:
123 | continue
124 | box = np.array(box)
125 | if not isinstance(dest_width, int):
126 | dest_width = dest_width.item()
127 | dest_height = dest_height.item()
128 |
129 | box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
130 | box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
131 | boxes[index, :, :] = box.astype(np.int16)
132 | scores[index] = score
133 | return boxes, scores
134 |
135 | def unclip(self, box, unclip_ratio=1.5):
136 | poly = Polygon(box)
137 | distance = poly.area * unclip_ratio / poly.length
138 | offset = pyclipper.PyclipperOffset()
139 | offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
140 | expanded = np.array(offset.Execute(distance))
141 | return expanded
142 |
143 | def get_mini_boxes(self, contour):
144 | bounding_box = cv2.minAreaRect(contour)
145 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
146 |
147 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3
148 | if points[1][1] > points[0][1]:
149 | index_1 = 0
150 | index_4 = 1
151 | else:
152 | index_1 = 1
153 | index_4 = 0
154 | if points[3][1] > points[2][1]:
155 | index_2 = 2
156 | index_3 = 3
157 | else:
158 | index_2 = 3
159 | index_3 = 2
160 |
161 | box = [points[index_1], points[index_2], points[index_3], points[index_4]]
162 | return box, min(bounding_box[1])
163 |
164 | def box_score_fast(self, bitmap, _box):
165 | h, w = bitmap.shape[:2]
166 | box = _box.copy()
167 | xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
168 | xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
169 | ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
170 | ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
171 |
172 | mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
173 | box[:, 0] = box[:, 0] - xmin
174 | box[:, 1] = box[:, 1] - ymin
175 | cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
176 | return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
177 |
--------------------------------------------------------------------------------
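
The `unclip` step above inverts the label-time shrinking: each detected polygon is dilated by `D = A * r / L` (area times `unclip_ratio` over perimeter, as in the DB paper). A hedged numeric example with an assumed 40x10 box:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

box = np.array([[0, 0], [40, 0], [40, 10], [0, 10]])
poly = Polygon(box)
distance = poly.area * 1.5 / poly.length  # 400 * 1.5 / 100 = 6.0
offset = pyclipper.PyclipperOffset()
offset.AddPath(box.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance)[0])
print(distance, expanded.min(axis=0), expanded.max(axis=0))  # box grown by ~6 px per side
```
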
/predict.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python tools/predict.py --model_path model_best.pth --input_folder ./input --output_folder ./output --thre 0.7 --polygon --show --save_result
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | anyconfig==0.9.10
2 | future==0.18.2
3 | imgaug==0.4.0
4 | matplotlib==3.1.2
5 | numpy==1.17.4
6 | opencv-python==4.1.2.30
7 | Polygon3==3.0.8
8 | pyclipper==1.1.0.post3
9 | PyYAML==5.2
10 | scikit-image==0.16.2
11 | Shapely==1.6.4.post2
12 | tensorboard==2.1.0
13 | tqdm==4.40.1
14 | torch==1.4
15 | torchvision==0.5
16 |
--------------------------------------------------------------------------------
/singlel_gpu_train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python3 tools/train.py --config_file "config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml"
--------------------------------------------------------------------------------
/test/README.MD:
--------------------------------------------------------------------------------
1 | Place the images that you want to detect here. It is best to name them like this:
2 | img_10.jpg
3 | img_11.jpg
4 | img_{img_id}.jpg
5 |
6 | To predict a single image, change `img_path` in `/tools/predict.py` to your image's number.
7 |
8 | The results will be saved in the output_folder (default: test/output) that you specify in predict.sh.
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/8 13:14
3 | # @Author : zhoujun
--------------------------------------------------------------------------------
/tools/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/6/11 15:54
3 | # @Author : zhoujun
4 | import os
5 | import sys
6 | import pathlib
7 | __dir__ = pathlib.Path(os.path.abspath(__file__))
8 | sys.path.append(str(__dir__))
9 | sys.path.append(str(__dir__.parent.parent))
10 | # project = 'DBNet.pytorch' # project root directory
11 | # sys.path.append(os.getcwd().split(project)[0] + project)
12 |
13 | import argparse
14 | import time
15 | import torch
16 | from tqdm.auto import tqdm
17 |
18 |
19 | class EVAL():
20 | def __init__(self, model_path, gpu_id=0):
21 | from models import build_model
22 | from data_loader import get_dataloader
23 | from post_processing import get_post_processing
24 | from utils import get_metric
25 | self.gpu_id = gpu_id
26 | if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
27 | self.device = torch.device("cuda:%s" % self.gpu_id)
28 | torch.backends.cudnn.benchmark = True
29 | else:
30 | self.device = torch.device("cpu")
31 | checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
32 | config = checkpoint['config']
33 | config['arch']['backbone']['pretrained'] = False
34 |
35 | self.validate_loader = get_dataloader(config['dataset']['validate'], config['distributed'])
36 |
37 | self.model = build_model(config['arch'])
38 | self.model.load_state_dict(checkpoint['state_dict'])
39 | self.model.to(self.device)
40 |
41 | self.post_process = get_post_processing(config['post_processing'])
42 | self.metric_cls = get_metric(config['metric'])
43 |
44 | def eval(self):
45 | self.model.eval()
46 | # torch.cuda.empty_cache() # speed up evaluating after training finished
47 | raw_metrics = []
48 | total_frame = 0.0
49 | total_time = 0.0
50 | for i, batch in tqdm(enumerate(self.validate_loader), total=len(self.validate_loader), desc='test model'):
51 | with torch.no_grad():
52 |                 # convert data and move it to the device
53 | for key, value in batch.items():
54 | if value is not None:
55 | if isinstance(value, torch.Tensor):
56 | batch[key] = value.to(self.device)
57 | start = time.time()
58 | preds = self.model(batch['img'])
59 | boxes, scores = self.post_process(batch, preds,is_output_polygon=self.metric_cls.is_output_polygon)
60 | total_frame += batch['img'].size()[0]
61 | total_time += time.time() - start
62 | raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
63 | raw_metrics.append(raw_metric)
64 | metrics = self.metric_cls.gather_measure(raw_metrics)
65 | print('FPS:{}'.format(total_frame / total_time))
66 | return metrics['recall'].avg, metrics['precision'].avg, metrics['fmeasure'].avg
67 |
68 |
69 | def init_args():
70 | parser = argparse.ArgumentParser(description='DBNet.pytorch')
71 | parser.add_argument('--model_path', required=False,default='output/DBNet_resnet18_FPN_DBHead/checkpoint/1.pth', type=str)
72 | args = parser.parse_args()
73 | return args
74 |
75 |
76 | if __name__ == '__main__':
77 | args = init_args()
78 |     evaluator = EVAL(args.model_path)
79 |     result = evaluator.eval()
80 | print(result)
81 |
--------------------------------------------------------------------------------
/tools/predict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/24 12:06
3 | # @Author : zhoujun
4 |
5 | import os
6 | import sys
7 | import pathlib
8 | __dir__ = pathlib.Path(os.path.abspath(__file__))
9 | sys.path.append(str(__dir__))
10 | sys.path.append(str(__dir__.parent.parent))
11 |
12 | # project = 'DBNet.pytorch' # project root directory
13 | # sys.path.append(os.getcwd().split(project)[0] + project)
14 | import time
15 | import cv2
16 | import torch
17 |
18 | from data_loader import get_transforms
19 | from models import build_model
20 | from post_processing import get_post_processing
21 |
22 |
23 | def resize_image(img, short_size):
24 | height, width, _ = img.shape
25 | if height < width:
26 | new_height = short_size
27 | new_width = new_height / height * width
28 | else:
29 | new_width = short_size
30 | new_height = new_width / width * height
31 | new_height = int(round(new_height / 32) * 32)
32 | new_width = int(round(new_width / 32) * 32)
33 | resized_img = cv2.resize(img, (new_width, new_height))
34 | return resized_img
35 |
36 |
37 | class Pytorch_model:
38 | def __init__(self, model_path, post_p_thre=0.7, gpu_id=None):
39 | '''
40 |         Initialize the PyTorch model
41 |         :param model_path: model path (a state dict, or a file with weights and graph saved together)
42 |         :param gpu_id: which GPU to run on
43 | '''
44 | self.gpu_id = gpu_id
45 |
46 | if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
47 | self.device = torch.device("cuda:%s" % self.gpu_id)
48 | else:
49 | self.device = torch.device("cpu")
50 | print('device:', self.device)
51 | checkpoint = torch.load(model_path, map_location=self.device)
52 |
53 | config = checkpoint['config']
54 | config['arch']['backbone']['pretrained'] = False
55 | self.model = build_model(config['arch'])
56 | self.post_process = get_post_processing(config['post_processing'])
57 | self.post_process.box_thresh = post_p_thre
58 | self.img_mode = config['dataset']['train']['dataset']['args']['img_mode']
59 | self.model.load_state_dict(checkpoint['state_dict'])
60 | self.model.to(self.device)
61 | self.model.eval()
62 |
63 | self.transform = []
64 | for t in config['dataset']['train']['dataset']['args']['transforms']:
65 | if t['type'] in ['ToTensor', 'Normalize']:
66 | self.transform.append(t)
67 | self.transform = get_transforms(self.transform)
68 |
69 | def predict(self, img_path: str, is_output_polygon=False, short_size: int = 1024):
70 | '''
71 |         Run prediction on an input image. Takes an image path; the image is read with OpenCV, which is a bit slow.
72 |         :param img_path: image path
73 |         :param is_output_polygon: whether to return polygons instead of quadrangle boxes
74 |         :return: prediction map, box list, score list, elapsed time
75 | '''
76 |         assert os.path.exists(img_path), 'file does not exist'
77 | img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0)
78 | if self.img_mode == 'RGB':
79 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
80 | h, w = img.shape[:2]
81 | img = resize_image(img, short_size)
82 |         # convert the image from (w, h) to (1, img_channel, h, w)
83 | tensor = self.transform(img)
84 | tensor = tensor.unsqueeze_(0)
85 |
86 | tensor = tensor.to(self.device)
87 | batch = {'shape': [(h, w)]}
88 | with torch.no_grad():
89 | if str(self.device).__contains__('cuda'):
90 | torch.cuda.synchronize(self.device)
91 | start = time.time()
92 | preds = self.model(tensor)
93 | if str(self.device).__contains__('cuda'):
94 | torch.cuda.synchronize(self.device)
95 | box_list, score_list = self.post_process(batch, preds, is_output_polygon=is_output_polygon)
96 | box_list, score_list = box_list[0], score_list[0]
97 | if len(box_list) > 0:
98 | if is_output_polygon:
99 | idx = [x.sum() > 0 for x in box_list]
100 | box_list = [box_list[i] for i, v in enumerate(idx) if v]
101 | score_list = [score_list[i] for i, v in enumerate(idx) if v]
102 | else:
103 |                     idx = box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0 # drop all-zero boxes
104 | box_list, score_list = box_list[idx], score_list[idx]
105 | else:
106 | box_list, score_list = [], []
107 | t = time.time() - start
108 | return preds[0, 0, :, :].detach().cpu().numpy(), box_list, score_list, t
109 |
110 |
111 | def save_depoly(model, input, save_path):
112 | traced_script_model = torch.jit.trace(model, input)
113 | traced_script_model.save(save_path)
114 |
115 |
116 | def init_args():
117 | import argparse
118 | parser = argparse.ArgumentParser(description='DBNet.pytorch')
119 | parser.add_argument('--model_path', default=r'model_best.pth', type=str)
120 | parser.add_argument('--input_folder', default='./test/input', type=str, help='img path for predict')
121 | parser.add_argument('--output_folder', default='./test/output', type=str, help='img path for output')
122 | parser.add_argument('--thre', default=0.3,type=float, help='the thresh of post_processing')
123 | parser.add_argument('--polygon', action='store_true', help='output polygon or box')
124 | parser.add_argument('--show', action='store_true', help='show result')
125 |     parser.add_argument('--save_result', action='store_true', help='save box and score to txt file')
126 | args = parser.parse_args()
127 | return args
128 |
129 |
130 | if __name__ == '__main__':
131 | import pathlib
132 | from tqdm import tqdm
133 | import matplotlib.pyplot as plt
134 | from utils.util import show_img, draw_bbox, save_result, get_file_list
135 |
136 | args = init_args()
137 | print(args)
138 | os.environ['CUDA_VISIBLE_DEVICES'] = str('0')
139 |     # initialize the network
140 | model = Pytorch_model(args.model_path, post_p_thre=args.thre, gpu_id=0)
141 | img_folder = pathlib.Path(args.input_folder)
142 | for img_path in tqdm(get_file_list(args.input_folder, p_postfix=['.jpg'])):
143 | preds, boxes_list, score_list, t = model.predict(img_path, is_output_polygon=args.polygon)
144 | img = draw_bbox(cv2.imread(img_path)[:, :, ::-1], boxes_list)
145 | if args.show:
146 | show_img(preds)
147 | show_img(img, title=os.path.basename(img_path))
148 | plt.show()
149 |         # save results to the output folder
150 | os.makedirs(args.output_folder, exist_ok=True)
151 | img_path = pathlib.Path(img_path)
152 | output_path = os.path.join(args.output_folder, img_path.stem + '_result.jpg')
153 | pred_path = os.path.join(args.output_folder, img_path.stem + '_pred.jpg')
154 | cv2.imwrite(output_path, img[:, :, ::-1])
155 | cv2.imwrite(pred_path, preds * 255)
156 | save_result(output_path.replace('_result.jpg', '.txt'), boxes_list, score_list, args.polygon)
157 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 22:00
3 | # @Author : zhoujun
4 |
5 | from __future__ import print_function
6 |
7 | import argparse
8 | import os
9 |
10 | import anyconfig
11 |
12 |
13 | def init_args():
14 | parser = argparse.ArgumentParser(description='DBNet.pytorch')
15 | parser.add_argument('--config_file', default='config/open_dataset_resnet18_FPN_DBhead_polyLR.yaml', type=str)
16 | parser.add_argument('--local_rank', dest='local_rank', default=0, type=int, help='Use distributed training')
17 |
18 | args = parser.parse_args()
19 | return args
20 |
21 |
22 | def main(config):
23 | import torch
24 | from models import build_model, build_loss
25 | from data_loader import get_dataloader
26 | from trainer import Trainer
27 | from post_processing import get_post_processing
28 | from utils import get_metric
29 | if torch.cuda.device_count() > 1:
30 | torch.cuda.set_device(args.local_rank)
31 | torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=torch.cuda.device_count(), rank=args.local_rank)
32 | config['distributed'] = True
33 | else:
34 | config['distributed'] = False
35 | config['local_rank'] = args.local_rank
36 |
37 | train_loader = get_dataloader(config['dataset']['train'], config['distributed'])
38 | assert train_loader is not None
39 | if 'validate' in config['dataset']:
40 | validate_loader = get_dataloader(config['dataset']['validate'], False)
41 | else:
42 | validate_loader = None
43 |
44 | criterion = build_loss(config['loss']).cuda()
45 |
46 | config['arch']['backbone']['in_channels'] = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1
47 | model = build_model(config['arch'])
48 |
49 | post_p = get_post_processing(config['post_processing'])
50 | metric = get_metric(config['metric'])
51 |
52 | trainer = Trainer(config=config,
53 | model=model,
54 | criterion=criterion,
55 | train_loader=train_loader,
56 | post_process=post_p,
57 | metric_cls=metric,
58 | validate_loader=validate_loader)
59 | trainer.train()
60 |
61 |
62 | if __name__ == '__main__':
63 | import sys
64 | import pathlib
65 | __dir__ = pathlib.Path(os.path.abspath(__file__))
66 | sys.path.append(str(__dir__))
67 | sys.path.append(str(__dir__.parent.parent))
68 |     # project = 'DBNet.pytorch' # project root directory
69 | # sys.path.append(os.getcwd().split(project)[0] + project)
70 |
71 | from utils import parse_config
72 |
73 | args = init_args()
74 | assert os.path.exists(args.config_file)
75 | config = anyconfig.load(open(args.config_file, 'rb'))
76 | if 'base' in config:
77 | config = parse_config(config)
78 | main(config)
79 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:58
3 | # @Author : zhoujun
4 | from .trainer import Trainer
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:58
3 | # @Author : zhoujun
4 | import time
5 |
6 | import torch
7 | import torchvision.utils as vutils
8 | from tqdm import tqdm
9 |
10 | from base import BaseTrainer
11 | from utils import WarmupPolyLR, runningScore, cal_text_score
12 |
13 |
14 | class Trainer(BaseTrainer):
15 | def __init__(self, config, model, criterion, train_loader, validate_loader, metric_cls, post_process=None):
16 | super(Trainer, self).__init__(config, model, criterion)
17 | self.show_images_iter = self.config['trainer']['show_images_iter']
18 | self.train_loader = train_loader
19 | if validate_loader is not None:
20 | assert post_process is not None and metric_cls is not None
21 | self.validate_loader = validate_loader
22 | self.post_process = post_process
23 | self.metric_cls = metric_cls
24 | self.train_loader_len = len(train_loader)
25 | if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
26 | warmup_iters = config['lr_scheduler']['args']['warmup_epoch'] * self.train_loader_len
27 | if self.start_epoch > 1:
28 | self.config['lr_scheduler']['args']['last_epoch'] = (self.start_epoch - 1) * self.train_loader_len
29 | self.scheduler = WarmupPolyLR(self.optimizer, max_iters=self.epochs * self.train_loader_len,
30 | warmup_iters=warmup_iters, **config['lr_scheduler']['args'])
31 | if self.validate_loader is not None:
32 | self.logger_info(
33 | 'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'.format(
34 | len(self.train_loader.dataset), self.train_loader_len, len(self.validate_loader.dataset), len(self.validate_loader)))
35 | else:
36 | self.logger_info('train dataset has {} samples,{} in dataloader'.format(len(self.train_loader.dataset), self.train_loader_len))
37 |
38 | def _train_epoch(self, epoch):
39 | self.model.train()
40 | epoch_start = time.time()
41 | batch_start = time.time()
42 | train_loss = 0.
43 | running_metric_text = runningScore(2)
44 | lr = self.optimizer.param_groups[0]['lr']
45 |
46 | for i, batch in enumerate(self.train_loader):
47 | if i >= self.train_loader_len:
48 | break
49 | self.global_step += 1
50 | lr = self.optimizer.param_groups[0]['lr']
51 |
52 |             # convert data and move it to the device
53 | for key, value in batch.items():
54 | if value is not None:
55 | if isinstance(value, torch.Tensor):
56 | batch[key] = value.to(self.device)
57 | cur_batch_size = batch['img'].size()[0]
58 |
59 | preds = self.model(batch['img'])
60 | loss_dict = self.criterion(preds, batch)
61 | # backward
62 | self.optimizer.zero_grad()
63 | loss_dict['loss'].backward()
64 | self.optimizer.step()
65 | if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
66 | self.scheduler.step()
67 | # acc iou
68 | score_shrink_map = cal_text_score(preds[:, 0, :, :], batch['shrink_map'], batch['shrink_mask'], running_metric_text,
69 | thred=self.config['post_processing']['args']['thresh'])
70 |
71 |             # log loss and acc
72 | loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
73 | for idx, (key, value) in enumerate(loss_dict.items()):
74 | loss_dict[key] = value.item()
75 | if key == 'loss':
76 | continue
77 | loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
78 | if idx < len(loss_dict) - 1:
79 | loss_str += ', '
80 |
81 | train_loss += loss_dict['loss']
82 | acc = score_shrink_map['Mean Acc']
83 | iou_shrink_map = score_shrink_map['Mean IoU']
84 |
85 | if self.global_step % self.log_iter == 0:
86 | batch_time = time.time() - batch_start
87 | self.logger_info(
88 | '[{}/{}], [{}/{}], global_step: {}, speed: {:.1f} samples/sec, acc: {:.4f}, iou_shrink_map: {:.4f}, {}, lr:{:.6}, time:{:.2f}'.format(
89 | epoch, self.epochs, i + 1, self.train_loader_len, self.global_step, self.log_iter * cur_batch_size / batch_time, acc,
90 | iou_shrink_map, loss_str, lr, batch_time))
91 | batch_start = time.time()
92 |
93 | if self.tensorboard_enable and self.config['local_rank'] == 0:
94 | # write tensorboard
95 | for key, value in loss_dict.items():
96 | self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value, self.global_step)
97 | self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc, self.global_step)
98 | self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map', iou_shrink_map, self.global_step)
99 | self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
100 | if self.global_step % self.show_images_iter == 0:
101 | # show images on tensorboard
102 | self.inverse_normalize(batch['img'])
103 | self.writer.add_images('TRAIN/imgs', batch['img'], self.global_step)
104 | # shrink_labels and threshold_labels
105 | shrink_labels = batch['shrink_map']
106 | threshold_labels = batch['threshold_map']
107 | shrink_labels[shrink_labels <= 0.5] = 0
108 | shrink_labels[shrink_labels > 0.5] = 1
109 | show_label = torch.cat([shrink_labels, threshold_labels])
110 | show_label = vutils.make_grid(show_label.unsqueeze(1), nrow=cur_batch_size, normalize=False, padding=20, pad_value=1)
111 | self.writer.add_image('TRAIN/gt', show_label, self.global_step)
112 | # model output
113 | show_pred = []
114 | for kk in range(preds.shape[1]):
115 | show_pred.append(preds[:, kk, :, :])
116 | show_pred = torch.cat(show_pred)
117 | show_pred = vutils.make_grid(show_pred.unsqueeze(1), nrow=cur_batch_size, normalize=False, padding=20, pad_value=1)
118 | self.writer.add_image('TRAIN/preds', show_pred, self.global_step)
119 | return {'train_loss': train_loss / self.train_loader_len, 'lr': lr, 'time': time.time() - epoch_start,
120 | 'epoch': epoch}
121 |
122 | def _eval(self, epoch):
123 | self.model.eval()
124 | # torch.cuda.empty_cache() # speed up evaluating after training finished
125 | raw_metrics = []
126 | total_frame = 0.0
127 | total_time = 0.0
128 | for i, batch in tqdm(enumerate(self.validate_loader), total=len(self.validate_loader), desc='test model'):
129 | with torch.no_grad():
130 |                 # convert data and move it to the device
131 | for key, value in batch.items():
132 | if value is not None:
133 | if isinstance(value, torch.Tensor):
134 | batch[key] = value.to(self.device)
135 | start = time.time()
136 | preds = self.model(batch['img'])
137 | boxes, scores = self.post_process(batch, preds,is_output_polygon=self.metric_cls.is_output_polygon)
138 | total_frame += batch['img'].size()[0]
139 | total_time += time.time() - start
140 | raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
141 | raw_metrics.append(raw_metric)
142 | metrics = self.metric_cls.gather_measure(raw_metrics)
143 | self.logger_info('FPS:{}'.format(total_frame / total_time))
144 | return metrics['recall'].avg, metrics['precision'].avg, metrics['fmeasure'].avg
145 |
146 | def _on_epoch_finish(self):
147 | self.logger_info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
148 | self.epoch_result['epoch'], self.epochs, self.epoch_result['train_loss'], self.epoch_result['time'],
149 | self.epoch_result['lr']))
150 | net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)
151 | net_save_path_best = '{}/model_best.pth'.format(self.checkpoint_dir)
152 |
153 | if self.config['local_rank'] == 0:
154 | self._save_checkpoint(self.epoch_result['epoch'], net_save_path)
155 | save_best = False
156 |             if self.validate_loader is not None and self.metric_cls is not None:  # use F1 as the best-model metric
157 | recall, precision, hmean = self._eval(self.epoch_result['epoch'])
158 |
159 | if self.tensorboard_enable:
160 | self.writer.add_scalar('EVAL/recall', recall, self.global_step)
161 | self.writer.add_scalar('EVAL/precision', precision, self.global_step)
162 | self.writer.add_scalar('EVAL/hmean', hmean, self.global_step)
163 | self.logger_info('test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.format(recall, precision, hmean))
164 |
165 | if hmean >= self.metrics['hmean']:
166 | save_best = True
167 | self.metrics['train_loss'] = self.epoch_result['train_loss']
168 | self.metrics['hmean'] = hmean
169 | self.metrics['precision'] = precision
170 | self.metrics['recall'] = recall
171 | self.metrics['best_model_epoch'] = self.epoch_result['epoch']
172 | else:
173 | if self.epoch_result['train_loss'] <= self.metrics['train_loss']:
174 | save_best = True
175 | self.metrics['train_loss'] = self.epoch_result['train_loss']
176 | self.metrics['best_model_epoch'] = self.epoch_result['epoch']
177 | best_str = 'current best, '
178 | for k, v in self.metrics.items():
179 | best_str += '{}: {:.6f}, '.format(k, v)
180 | self.logger_info(best_str)
181 | if save_best:
182 | import shutil
183 | shutil.copy(net_save_path, net_save_path_best)
184 | self.logger_info("Saving current best: {}".format(net_save_path_best))
185 | else:
186 | self.logger_info("Saving checkpoint: {}".format(net_save_path))
187 |
188 |
189 | def _on_train_finish(self):
190 | for k, v in self.metrics.items():
191 | self.logger_info('{}:{}'.format(k, v))
192 | self.logger_info('finish train')
193 |
--------------------------------------------------------------------------------
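
For reference, a hedged sketch (not the repo's `utils.schedulers.WarmupPolyLR` implementation) of the schedule shape the trainer configures above: a linear warmup over `warmup_epoch * len(train_loader)` iterations followed by polynomial decay towards zero at `max_iters`. The constants are illustrative assumptions.

```python
def warmup_poly_lr(base_lr, step, max_iters, warmup_iters, power=0.9, warmup_factor=1 / 3):
    """Sketch of a warmup + poly schedule; constants are illustrative assumptions."""
    if step < warmup_iters:
        alpha = step / warmup_iters
        return base_lr * (warmup_factor * (1 - alpha) + alpha)  # linear warmup
    progress = (step - warmup_iters) / (max_iters - warmup_iters)
    return base_lr * (1 - progress) ** power  # polynomial decay

print([round(warmup_poly_lr(0.001, s, 1200, 100), 6) for s in (0, 50, 100, 600, 1199)])
```
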
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:58
3 | # @Author : zhoujun
4 | from .util import *
5 | from .metrics import *
6 | from .schedulers import *
7 | from .cal_recall.script import cal_recall_precison_f1
8 | from .ocr_metric import get_metric
--------------------------------------------------------------------------------
/utils/cal_recall/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/16/19 6:40 AM
3 | # @Author : zhoujun
4 | from .script import cal_recall_precison_f1
5 | __all__ = ['cal_recall_precison_f1']
--------------------------------------------------------------------------------
/utils/compute_mean_std.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/7 14:46
3 | # @Author : zhoujun
4 |
5 | import numpy as np
6 | import cv2
7 | import os
8 | import random
9 | from tqdm import tqdm
10 | # calculate means and std
11 | train_txt_path = './train_val_list.txt'
12 |
13 | CNum = 10000  # how many images to sample for the statistics
14 |
15 | img_h, img_w = 640, 640
16 | imgs = np.zeros([img_w, img_h, 3, 1])
17 | means, stdevs = [], []
18 |
19 | with open(train_txt_path, 'r') as f:
20 | lines = f.readlines()
21 |     random.shuffle(lines)  # shuffle to sample images randomly
22 |
23 | for i in tqdm(range(CNum)):
24 | img_path = lines[i].split('\t')[0]
25 |
26 | img = cv2.imread(img_path)
27 | img = cv2.resize(img, (img_h, img_w))
28 | img = img[:, :, :, np.newaxis]
29 |
30 | imgs = np.concatenate((imgs, img), axis=3)
31 | # print(i)
32 |
33 | imgs = imgs.astype(np.float32) / 255.
34 |
35 | for i in tqdm(range(3)):
36 |     pixels = imgs[:, :, i, :].ravel()  # flatten to 1-D
37 | means.append(np.mean(pixels))
38 | stdevs.append(np.std(pixels))
39 |
40 | # cv2 loads images as BGR; PIL/skimage load RGB, which would need no reversal
41 | means.reverse() # BGR --> RGB
42 | stdevs.reverse()
43 |
44 | print("normMean = {}".format(means))
45 | print("normStd = {}".format(stdevs))
46 | print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
--------------------------------------------------------------------------------
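> Note: the script above concatenates every sampled image into one giant array, which gets memory-hungry for large `CNum`. Below is a minimal memory-light sketch of the same computation using per-channel accumulators; `channel_mean_std` is an illustrative helper, not part of the repo.

```python
import cv2
import numpy as np

def channel_mean_std(img_paths, size=(640, 640)):
    # accumulate per-channel sums instead of stacking all pixels in memory
    n = 0
    s = np.zeros(3, dtype=np.float64)   # sum of pixel values per channel
    sq = np.zeros(3, dtype=np.float64)  # sum of squared pixel values per channel
    for p in img_paths:
        img = cv2.resize(cv2.imread(p), size).astype(np.float64) / 255.
        n += img.shape[0] * img.shape[1]
        s += img.sum(axis=(0, 1))
        sq += (img ** 2).sum(axis=(0, 1))
    mean = s / n
    std = np.sqrt(sq / n - mean ** 2)   # population std, matching np.std above
    return mean[::-1], std[::-1]        # reverse BGR -> RGB, as the script does
```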
/utils/make_trainfile.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/24 12:06
3 | # @Author : zhoujun
4 | import os
5 | import glob
6 | import pathlib
7 |
8 | data_path = r'test'
9 | # data_path/img holds the images
10 | # data_path/gt holds the label files
11 |
12 | f_w = open(os.path.join(data_path, 'test.txt'), 'w', encoding='utf8')
13 | for img_path in glob.glob(data_path + '/img/*.jpg', recursive=True):
14 | d = pathlib.Path(img_path)
15 | label_path = os.path.join(data_path, 'gt', ('gt_' + str(d.stem) + '.txt'))
16 |     if os.path.exists(img_path) and os.path.exists(label_path):
17 |         print(img_path, label_path)
18 |         f_w.write('{}\t{}\n'.format(img_path, label_path))  # only write pairs that actually exist
19 |     else:
20 |         print('missing:', img_path, label_path)
21 | f_w.close()
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Adapted from score written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 |
4 | import numpy as np
5 |
6 |
7 | class runningScore(object):
8 |
9 | def __init__(self, n_classes):
10 | self.n_classes = n_classes
11 | self.confusion_matrix = np.zeros((n_classes, n_classes))
12 |
13 | def _fast_hist(self, label_true, label_pred, n_class):
14 | mask = (label_true >= 0) & (label_true < n_class)
15 |
16 | if np.sum((label_pred[mask] < 0)) > 0:
17 | print(label_pred[label_pred < 0])
18 | hist = np.bincount(n_class * label_true[mask].astype(int) +
19 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
20 | return hist
21 |
22 | def update(self, label_trues, label_preds):
23 | # print label_trues.dtype, label_preds.dtype
24 | for lt, lp in zip(label_trues, label_preds):
25 | try:
26 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
27 |             except Exception:
28 |                 pass  # skip malformed label/prediction pairs instead of crashing
29 |
30 | def get_scores(self):
31 | """Returns accuracy score evaluation result.
32 | - overall accuracy
33 | - mean accuracy
34 | - mean IU
35 | - fwavacc
36 | """
37 | hist = self.confusion_matrix
38 | acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
39 | acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
40 | acc_cls = np.nanmean(acc_cls)
41 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001)
42 | mean_iu = np.nanmean(iu)
43 | freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
44 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
45 | cls_iu = dict(zip(range(self.n_classes), iu))
46 |
47 | return {'Overall Acc': acc,
48 | 'Mean Acc': acc_cls,
49 | 'FreqW Acc': fwavacc,
50 | 'Mean IoU': mean_iu, }, cls_iu
51 |
52 | def reset(self):
53 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
54 |
--------------------------------------------------------------------------------
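A quick sanity check of `runningScore` on a toy 2-class segmentation (illustrative values only; assumes the repository root is on `PYTHONPATH`):

```python
import numpy as np
from utils.metrics import runningScore

score = runningScore(n_classes=2)
gt   = np.array([[0, 0, 1, 1]] * 4)   # 4x4 ground-truth mask
pred = np.array([[0, 1, 1, 1]] * 4)   # 12 of 16 pixels correct
score.update(gt[np.newaxis], pred[np.newaxis])
metrics, cls_iu = score.get_scores()
print(metrics['Overall Acc'])  # ~0.75
print(cls_iu)                  # per-class IoU, e.g. {0: ~0.5, 1: ~0.67}
```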
/utils/ocr_metric/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/5 15:36
3 | # @Author : zhoujun
4 | from .icdar2015 import QuadMetric
5 |
6 |
7 | def get_metric(config):
8 | try:
9 | if 'args' not in config:
10 | args = {}
11 | else:
12 | args = config['args']
13 | if isinstance(args, dict):
14 | cls = eval(config['type'])(**args)
15 | else:
16 | cls = eval(config['type'])(args)
17 | return cls
18 |     except Exception:
19 |         return None  # fall back to no metric if the config is malformed
--------------------------------------------------------------------------------
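How this factory is typically driven, mirroring the `metric` block of the yaml configs (the exact dict keys shown here are an assumption based on the factory's code):

```python
from utils.ocr_metric import get_metric

metric = get_metric({'type': 'QuadMetric', 'args': {'is_output_polygon': False}})
# equivalent to QuadMetric(is_output_polygon=False); returns None on a malformed config
```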
/utils/ocr_metric/icdar2015/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/5 15:36
3 | # @Author : zhoujun
4 |
5 | from .quad_metric import QuadMetric
--------------------------------------------------------------------------------
/utils/ocr_metric/icdar2015/detection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/DBNet.pytorch/e03acf0e6b3b62f7d1dc7e10a6d2587456ac9ea1/utils/ocr_metric/icdar2015/detection/__init__.py
--------------------------------------------------------------------------------
/utils/ocr_metric/icdar2015/detection/iou.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from collections import namedtuple
4 | import numpy as np
5 | from shapely.geometry import Polygon
6 | import cv2
7 |
8 |
9 | def iou_rotate(box_a, box_b, method='union'):
10 | rect_a = cv2.minAreaRect(box_a)
11 | rect_b = cv2.minAreaRect(box_b)
12 | r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b)
13 | if r1[0] == 0:
14 | return 0
15 | else:
16 | inter_area = cv2.contourArea(r1[1])
17 | area_a = cv2.contourArea(box_a)
18 | area_b = cv2.contourArea(box_b)
19 | union_area = area_a + area_b - inter_area
20 | if union_area == 0 or inter_area == 0:
21 | return 0
22 | if method == 'union':
23 | iou = inter_area / union_area
24 | elif method == 'intersection':
25 | iou = inter_area / min(area_a, area_b)
26 | else:
27 | raise NotImplementedError
28 | return iou
29 |
30 |
31 | class DetectionIoUEvaluator(object):
32 | def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5):
33 | self.is_output_polygon = is_output_polygon
34 | self.iou_constraint = iou_constraint
35 | self.area_precision_constraint = area_precision_constraint
36 |
37 | def evaluate_image(self, gt, pred):
38 |
39 | def get_union(pD, pG):
40 | return Polygon(pD).union(Polygon(pG)).area
41 |
42 | def get_intersection_over_union(pD, pG):
43 | return get_intersection(pD, pG) / get_union(pD, pG)
44 |
45 | def get_intersection(pD, pG):
46 | return Polygon(pD).intersection(Polygon(pG)).area
47 |
48 | def compute_ap(confList, matchList, numGtCare):
49 | correct = 0
50 | AP = 0
51 | if len(confList) > 0:
52 | confList = np.array(confList)
53 | matchList = np.array(matchList)
54 | sorted_ind = np.argsort(-confList)
55 | confList = confList[sorted_ind]
56 | matchList = matchList[sorted_ind]
57 | for n in range(len(confList)):
58 | match = matchList[n]
59 | if match:
60 | correct += 1
61 | AP += float(correct) / (n + 1)
62 |
63 | if numGtCare > 0:
64 | AP /= numGtCare
65 |
66 | return AP
67 |
68 | perSampleMetrics = {}
69 |
70 | matchedSum = 0
71 |
72 |         Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')  # leftover from the original ICDAR script; unused here
73 |
74 | numGlobalCareGt = 0
75 | numGlobalCareDet = 0
76 |
77 | arrGlobalConfidences = []
78 | arrGlobalMatches = []
79 |
80 | recall = 0
81 | precision = 0
82 | hmean = 0
83 |
84 | detMatched = 0
85 |
86 | iouMat = np.empty([1, 1])
87 |
88 | gtPols = []
89 | detPols = []
90 |
91 | gtPolPoints = []
92 | detPolPoints = []
93 |
94 | # Array of Ground Truth Polygons' keys marked as don't Care
95 | gtDontCarePolsNum = []
96 | # Array of Detected Polygons' matched with a don't Care GT
97 | detDontCarePolsNum = []
98 |
99 | pairs = []
100 | detMatchedNums = []
101 |
102 | arrSampleConfidences = []
103 | arrSampleMatch = []
104 |
105 | evaluationLog = ""
106 |
107 | for n in range(len(gt)):
108 | points = gt[n]['points']
109 | # transcription = gt[n]['text']
110 | dontCare = gt[n]['ignore']
111 |
112 | if not Polygon(points).is_valid or not Polygon(points).is_simple:
113 | continue
114 |
115 | gtPol = points
116 | gtPols.append(gtPol)
117 | gtPolPoints.append(points)
118 | if dontCare:
119 | gtDontCarePolsNum.append(len(gtPols) - 1)
120 |
121 | evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(
122 | gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
123 |
124 | for n in range(len(pred)):
125 | points = pred[n]['points']
126 | if not Polygon(points).is_valid or not Polygon(points).is_simple:
127 | continue
128 |
129 | detPol = points
130 | detPols.append(detPol)
131 | detPolPoints.append(points)
132 | if len(gtDontCarePolsNum) > 0:
133 | for dontCarePol in gtDontCarePolsNum:
134 | dontCarePol = gtPols[dontCarePol]
135 | intersected_area = get_intersection(dontCarePol, detPol)
136 | pdDimensions = Polygon(detPol).area
137 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
138 | if (precision > self.area_precision_constraint):
139 | detDontCarePolsNum.append(len(detPols) - 1)
140 | break
141 |
142 | evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(
143 | detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
144 |
145 | if len(gtPols) > 0 and len(detPols) > 0:
146 |             # Calculate the IoU matrix between GT and detected polygons
147 | outputShape = [len(gtPols), len(detPols)]
148 | iouMat = np.empty(outputShape)
149 | gtRectMat = np.zeros(len(gtPols), np.int8)
150 | detRectMat = np.zeros(len(detPols), np.int8)
151 | if self.is_output_polygon:
152 | for gtNum in range(len(gtPols)):
153 | for detNum in range(len(detPols)):
154 | pG = gtPols[gtNum]
155 | pD = detPols[detNum]
156 | iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
157 | else:
158 | # gtPols = np.float32(gtPols)
159 | # detPols = np.float32(detPols)
160 | for gtNum in range(len(gtPols)):
161 | for detNum in range(len(detPols)):
162 | pG = np.float32(gtPols[gtNum])
163 | pD = np.float32(detPols[detNum])
164 | iouMat[gtNum, detNum] = iou_rotate(pD, pG)
165 | for gtNum in range(len(gtPols)):
166 | for detNum in range(len(detPols)):
167 | if gtRectMat[gtNum] == 0 and detRectMat[
168 | detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
169 | if iouMat[gtNum, detNum] > self.iou_constraint:
170 | gtRectMat[gtNum] = 1
171 | detRectMat[detNum] = 1
172 | detMatched += 1
173 | pairs.append({'gt': gtNum, 'det': detNum})
174 | detMatchedNums.append(detNum)
175 | evaluationLog += "Match GT #" + \
176 | str(gtNum) + " with Det #" + str(detNum) + "\n"
177 |
178 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
179 | numDetCare = (len(detPols) - len(detDontCarePolsNum))
180 | if numGtCare == 0:
181 | recall = float(1)
182 | precision = float(0) if numDetCare > 0 else float(1)
183 | else:
184 | recall = float(detMatched) / numGtCare
185 | precision = 0 if numDetCare == 0 else float(
186 | detMatched) / numDetCare
187 |
188 | hmean = 0 if (precision + recall) == 0 else 2.0 * \
189 | precision * recall / (precision + recall)
190 |
191 | matchedSum += detMatched
192 | numGlobalCareGt += numGtCare
193 | numGlobalCareDet += numDetCare
194 |
195 | perSampleMetrics = {
196 | 'precision': precision,
197 | 'recall': recall,
198 | 'hmean': hmean,
199 | 'pairs': pairs,
200 | 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
201 | 'gtPolPoints': gtPolPoints,
202 | 'detPolPoints': detPolPoints,
203 | 'gtCare': numGtCare,
204 | 'detCare': numDetCare,
205 | 'gtDontCare': gtDontCarePolsNum,
206 | 'detDontCare': detDontCarePolsNum,
207 | 'detMatched': detMatched,
208 | 'evaluationLog': evaluationLog
209 | }
210 |
211 | return perSampleMetrics
212 |
213 | def combine_results(self, results):
214 | numGlobalCareGt = 0
215 | numGlobalCareDet = 0
216 | matchedSum = 0
217 | for result in results:
218 | numGlobalCareGt += result['gtCare']
219 | numGlobalCareDet += result['detCare']
220 | matchedSum += result['detMatched']
221 |
222 | methodRecall = 0 if numGlobalCareGt == 0 else float(
223 | matchedSum) / numGlobalCareGt
224 | methodPrecision = 0 if numGlobalCareDet == 0 else float(
225 | matchedSum) / numGlobalCareDet
226 | methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
227 | methodRecall * methodPrecision / (
228 | methodRecall + methodPrecision)
229 |
230 | methodMetrics = {'precision': methodPrecision,
231 | 'recall': methodRecall, 'hmean': methodHmean}
232 |
233 | return methodMetrics
234 |
235 |
236 | if __name__ == '__main__':
237 | evaluator = DetectionIoUEvaluator()
238 | preds = [[{
239 | 'points': [(0.1, 0.1), (0.5, 0), (0.5, 1), (0, 1)],
240 | 'text': 1234,
241 | 'ignore': False,
242 | }, {
243 | 'points': [(0.5, 0.1), (1, 0), (1, 1), (0.5, 1)],
244 | 'text': 5678,
245 | 'ignore': False,
246 | }]]
247 | gts = [[{
248 | 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
249 | 'text': 123,
250 | 'ignore': False,
251 | }]]
252 | results = []
253 | for gt, pred in zip(gts, preds):
254 | results.append(evaluator.evaluate_image(gt, pred))
255 | metrics = evaluator.combine_results(results)
256 | print(metrics)
257 |
--------------------------------------------------------------------------------
/utils/ocr_metric/icdar2015/quad_metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .detection.iou import DetectionIoUEvaluator
4 |
5 |
6 | class AverageMeter(object):
7 | """Computes and stores the average and current value"""
8 |
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 | return self
24 |
25 |
26 | class QuadMetric():
27 | def __init__(self, is_output_polygon=False):
28 | self.is_output_polygon = is_output_polygon
29 | self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)
30 |
31 | def measure(self, batch, output, box_thresh=0.6):
32 |         '''
33 |         batch: a dict produced by the dataloader, containing:
34 |             image: tensor of shape (N, C, H, W).
35 |             text_polys: tensor of shape (N, K, 4, 2), the polygons of objective regions.
36 |             ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
37 |             shape: the original shapes of the images.
38 |             filename: the original filenames of the images.
39 |         output: a (polygons, scores) pair produced by post-processing.
40 |         box_thresh: minimum score for a predicted box to be kept for evaluation.
41 |         '''
42 | results = []
43 |         gt_polygons_batch = batch['text_polys']
44 | ignore_tags_batch = batch['ignore_tags']
45 | pred_polygons_batch = np.array(output[0])
46 | pred_scores_batch = np.array(output[1])
47 |         for polygons, pred_polygons, pred_scores, ignore_tags in zip(gt_polygons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
48 | gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i]) for i in range(len(polygons))]
49 | if self.is_output_polygon:
50 | pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))]
51 | else:
52 | pred = []
53 | # print(pred_polygons.shape)
54 | for i in range(pred_polygons.shape[0]):
55 | if pred_scores[i] >= box_thresh:
56 | # print(pred_polygons[i,:,:].tolist())
57 |                         pred.append(dict(points=pred_polygons[i, :, :].astype(np.int32)))  # np.int is removed in recent NumPy
58 | # pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])]
59 | results.append(self.evaluator.evaluate_image(gt, pred))
60 | return results
61 |
62 | def validate_measure(self, batch, output, box_thresh=0.6):
63 | return self.measure(batch, output, box_thresh)
64 |
65 | def evaluate_measure(self, batch, output):
66 | return self.measure(batch, output), np.linspace(0, batch['image'].shape[0]).tolist()
67 |
68 | def gather_measure(self, raw_metrics):
69 | raw_metrics = [image_metrics
70 | for batch_metrics in raw_metrics
71 | for image_metrics in batch_metrics]
72 |
73 | result = self.evaluator.combine_results(raw_metrics)
74 |
75 | precision = AverageMeter()
76 | recall = AverageMeter()
77 | fmeasure = AverageMeter()
78 |
79 | precision.update(result['precision'], n=len(raw_metrics))
80 | recall.update(result['recall'], n=len(raw_metrics))
81 | fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
82 | fmeasure.update(fmeasure_score)
83 |
84 | return {
85 | 'precision': precision,
86 | 'recall': recall,
87 | 'fmeasure': fmeasure
88 | }
89 |
--------------------------------------------------------------------------------
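A minimal end-to-end check of `QuadMetric` with one image and one perfectly matching box (toy data; assumes the repository root is on `PYTHONPATH` and shapely is installed):

```python
import numpy as np
from utils.ocr_metric.icdar2015 import QuadMetric

metric = QuadMetric(is_output_polygon=False)
batch = {
    'text_polys': [np.array([[[0, 0], [10, 0], [10, 5], [0, 5]]], dtype=np.float32)],
    'ignore_tags': [[False]],
}
# output = (predicted polygons per image, scores per image)
output = ([np.array([[[0, 0], [10, 0], [10, 5], [0, 5]]])], [np.array([0.9])])
raw = metric.validate_measure(batch, output)
res = metric.gather_measure([raw])
print(res['precision'].avg, res['recall'].avg, res['fmeasure'].avg)  # -> 1.0 1.0 ~1.0
```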
/utils/schedulers.py:
--------------------------------------------------------------------------------
1 | """Popular Learning Rate Schedulers"""
2 | from __future__ import division
3 |
4 | import math
5 | from bisect import bisect_right
6 |
7 | import torch
8 |
9 | __all__ = ['LRScheduler', 'WarmupMultiStepLR', 'WarmupPolyLR']
10 |
11 |
12 | class LRScheduler(object):
13 | r"""Learning Rate Scheduler
14 | Parameters
15 | ----------
16 | mode : str
17 | Modes for learning rate scheduler.
18 | Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'.
19 | base_lr : float
20 | Base learning rate, i.e. the starting learning rate.
21 | target_lr : float
22 | Target learning rate, i.e. the ending learning rate.
23 | With constant mode target_lr is ignored.
24 | niters : int
25 | Number of iterations to be scheduled.
26 | nepochs : int
27 | Number of epochs to be scheduled.
28 | iters_per_epoch : int
29 | Number of iterations in each epoch.
30 | offset : int
31 | Number of iterations before this scheduler.
32 | power : float
33 | Power parameter of poly scheduler.
34 | step_iter : list
35 | A list of iterations to decay the learning rate.
36 | step_epoch : list
37 | A list of epochs to decay the learning rate.
38 | step_factor : float
39 | Learning rate decay factor.
40 | """
41 |
42 | def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0,
43 | offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0):
44 | super(LRScheduler, self).__init__()
45 | assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine'])
46 |
47 | if mode == 'step':
48 | assert (step_iter is not None or step_epoch is not None)
49 | self.niters = niters
50 | self.step = step_iter
51 | epoch_iters = nepochs * iters_per_epoch
52 | if epoch_iters > 0:
53 | self.niters = epoch_iters
54 | if step_epoch is not None:
55 | self.step = [s * iters_per_epoch for s in step_epoch]
56 |
57 | self.step_factor = step_factor
58 | self.base_lr = base_lr
59 | self.target_lr = base_lr if mode == 'constant' else target_lr
60 | self.offset = offset
61 | self.power = power
62 | self.warmup_iters = warmup_epochs * iters_per_epoch
63 | self.mode = mode
64 |
65 | def __call__(self, optimizer, num_update):
66 | self.update(num_update)
67 | assert self.learning_rate >= 0
68 | self._adjust_learning_rate(optimizer, self.learning_rate)
69 |
70 | def update(self, num_update):
71 | N = self.niters - 1
72 | T = num_update - self.offset
73 | T = min(max(0, T), N)
74 |
75 | if self.mode == 'constant':
76 | factor = 0
77 | elif self.mode == 'linear':
78 | factor = 1 - T / N
79 | elif self.mode == 'poly':
80 | factor = pow(1 - T / N, self.power)
81 | elif self.mode == 'cosine':
82 | factor = (1 + math.cos(math.pi * T / N)) / 2
83 | elif self.mode == 'step':
84 | if self.step is not None:
85 | count = sum([1 for s in self.step if s <= T])
86 | factor = pow(self.step_factor, count)
87 | else:
88 | factor = 1
89 | else:
90 | raise NotImplementedError
91 |
92 | # warm up lr schedule
93 | if self.warmup_iters > 0 and T < self.warmup_iters:
94 | factor = factor * 1.0 * T / self.warmup_iters
95 |
96 | if self.mode == 'step':
97 | self.learning_rate = self.base_lr * factor
98 | else:
99 | self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor
100 |
101 | def _adjust_learning_rate(self, optimizer, lr):
102 | optimizer.param_groups[0]['lr'] = lr
103 | # enlarge the lr at the head
104 | for i in range(1, len(optimizer.param_groups)):
105 | optimizer.param_groups[i]['lr'] = lr * 10
106 |
107 |
108 | # Ideally warmup would be a separate scheduler combined with MultiStepLR,
109 | # but the current _LRScheduler design doesn't allow composing them
110 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py
111 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
112 |     def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
113 |                  warmup_iters=500, warmup_method="linear", last_epoch=-1, **kwargs):
114 |         if not list(milestones) == sorted(milestones):
115 |             raise ValueError(
116 |                 "Milestones should be a list of increasing integers. Got {}".format(milestones))
117 |         if warmup_method not in ("constant", "linear"):
118 |             raise ValueError(
119 |                 "Only 'constant' or 'linear' warmup_method accepted got {}".format(warmup_method))
120 |
121 |         self.milestones = milestones
122 |         self.gamma = gamma
123 |         self.warmup_factor = warmup_factor
124 |         self.warmup_iters = warmup_iters
125 |         self.warmup_method = warmup_method
126 |         # super() must come last: _LRScheduler.__init__ calls get_lr(), which reads the attributes above (WarmupPolyLR already does this)
127 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
128 | def get_lr(self):
129 | warmup_factor = 1
130 | if self.last_epoch < self.warmup_iters:
131 | if self.warmup_method == 'constant':
132 | warmup_factor = self.warmup_factor
133 |             elif self.warmup_method == 'linear':  # bug fix: previously compared self.warmup_factor (a float) to 'linear'
134 | alpha = float(self.last_epoch) / self.warmup_iters
135 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
136 | return [base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
137 | for base_lr in self.base_lrs]
138 |
139 |
140 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
141 | def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3,
142 | warmup_iters=500, warmup_method='linear', last_epoch=-1, **kwargs):
143 | if warmup_method not in ("constant", "linear"):
144 | raise ValueError(
145 | "Only 'constant' or 'linear' warmup_method accepted "
146 | "got {}".format(warmup_method))
147 |
148 | self.target_lr = target_lr
149 | self.max_iters = max_iters
150 | self.power = power
151 | self.warmup_factor = warmup_factor
152 | self.warmup_iters = warmup_iters
153 | self.warmup_method = warmup_method
154 |
155 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch)
156 |
157 | def get_lr(self):
158 | N = self.max_iters - self.warmup_iters
159 | T = self.last_epoch - self.warmup_iters
160 | if self.last_epoch < self.warmup_iters:
161 | if self.warmup_method == 'constant':
162 | warmup_factor = self.warmup_factor
163 | elif self.warmup_method == 'linear':
164 | alpha = float(self.last_epoch) / self.warmup_iters
165 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
166 | else:
167 | raise ValueError("Unknown warmup type.")
168 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs]
169 | factor = pow(1 - T / N, self.power)
170 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs]
171 |
172 |
173 | if __name__ == '__main__':
174 | import torch
175 | from torchvision.models import resnet18
176 |
177 | max_iter = 600 * 63
178 | model = resnet18()
179 | op = torch.optim.SGD(model.parameters(), 1e-3)
180 | sc = WarmupPolyLR(op, max_iters=max_iter, power=0.9, warmup_iters=3 * 63, warmup_method='constant')
181 | lr = []
182 | for i in range(max_iter):
183 | sc.step()
184 | print(i, sc.last_epoch, sc.get_lr()[0])
185 | lr.append(sc.get_lr()[0])
186 | from matplotlib import pyplot as plt
187 |
188 | plt.plot(list(range(max_iter)), lr)
189 | plt.show()
190 |
--------------------------------------------------------------------------------
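For reference, after warmup `WarmupPolyLR` follows `lr(t) = target_lr + (base_lr - target_lr) * (1 - (t - warmup_iters) / (max_iters - warmup_iters)) ** power`, decaying to exactly `target_lr` at `max_iters`; during warmup the factor ramps from `warmup_factor` up to 1 in 'linear' mode, or stays at `warmup_factor` in 'constant' mode.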
/utils/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:59
3 | # @Author : zhoujun
4 | import json
5 | import pathlib
6 | import time
7 | import os
8 | import glob
9 | from natsort import natsorted
10 | import cv2
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 |
14 |
15 | def get_file_list(folder_path: str, p_postfix: list = None, sub_dir: bool = True) -> list:
16 | """
17 | 获取所给文件目录里的指定后缀的文件,读取文件列表目前使用的是 os.walk 和 os.listdir ,这两个目前比 pathlib 快很多
18 | :param filder_path: 文件夹名称
19 | :param p_postfix: 文件后缀,如果为 [.*]将返回全部文件
20 | :param sub_dir: 是否搜索子文件夹
21 | :return: 获取到的指定类型的文件列表
22 | """
23 | assert os.path.exists(folder_path) and os.path.isdir(folder_path)
24 | if p_postfix is None:
25 | p_postfix = ['.jpg']
26 | if isinstance(p_postfix, str):
27 | p_postfix = [p_postfix]
28 | file_list = [x for x in glob.glob(folder_path + '/**/*.*', recursive=True) if
29 | os.path.splitext(x)[-1] in p_postfix or '.*' in p_postfix]
30 | return natsorted(file_list)
31 |
32 |
33 | def setup_logger(log_file_path: str = None):
34 | import logging
35 | logging._warn_preinit_stderr = 0
36 | logger = logging.getLogger('DBNet.pytorch')
37 | formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')
38 | ch = logging.StreamHandler()
39 | ch.setFormatter(formatter)
40 | logger.addHandler(ch)
41 | if log_file_path is not None:
42 | file_handle = logging.FileHandler(log_file_path)
43 | file_handle.setFormatter(formatter)
44 | logger.addHandler(file_handle)
45 | logger.setLevel(logging.DEBUG)
46 | return logger
47 |
48 |
49 | # --exeTime
50 | def exe_time(func):
51 | def newFunc(*args, **args2):
52 | t0 = time.time()
53 | back = func(*args, **args2)
54 | print("{} cost {:.3f}s".format(func.__name__, time.time() - t0))
55 | return back
56 |
57 | return newFunc
58 |
59 |
60 | def load(file_path: str):
61 | file_path = pathlib.Path(file_path)
62 | func_dict = {'.txt': _load_txt, '.json': _load_json, '.list': _load_txt}
63 | assert file_path.suffix in func_dict
64 | return func_dict[file_path.suffix](file_path)
65 |
66 |
67 | def _load_txt(file_path: str):
68 | with open(file_path, 'r', encoding='utf8') as f:
69 | content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()]
70 | return content
71 |
72 |
73 | def _load_json(file_path: str):
74 | with open(file_path, 'r', encoding='utf8') as f:
75 | content = json.load(f)
76 | return content
77 |
78 |
79 | def save(data, file_path):
80 | file_path = pathlib.Path(file_path)
81 | func_dict = {'.txt': _save_txt, '.json': _save_json}
82 | assert file_path.suffix in func_dict
83 | return func_dict[file_path.suffix](data, file_path)
84 |
85 |
86 | def _save_txt(data, file_path):
87 | """
88 | 将一个list的数组写入txt文件里
89 | :param data:
90 | :param file_path:
91 | :return:
92 | """
93 | if not isinstance(data, list):
94 | data = [data]
95 | with open(file_path, mode='w', encoding='utf8') as f:
96 | f.write('\n'.join(data))
97 |
98 |
99 | def _save_json(data, file_path):
100 | with open(file_path, 'w', encoding='utf-8') as json_file:
101 | json.dump(data, json_file, ensure_ascii=False, indent=4)
102 |
103 |
104 | def show_img(imgs: np.ndarray, title='img'):
105 | color = (len(imgs.shape) == 3 and imgs.shape[-1] == 3)
106 | imgs = np.expand_dims(imgs, axis=0)
107 | for i, img in enumerate(imgs):
108 | plt.figure()
109 | plt.title('{}_{}'.format(title, i))
110 | plt.imshow(img, cmap=None if color else 'gray')
111 | plt.show()
112 |
113 |
114 | def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
115 | if isinstance(img_path, str):
116 | img_path = cv2.imread(img_path)
117 | # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
118 | img_path = img_path.copy()
119 | for point in result:
120 | point = point.astype(int)
121 | cv2.polylines(img_path, [point], True, color, thickness)
122 | return img_path
123 |
124 |
125 | def cal_text_score(texts, gt_texts, training_masks, running_metric_text, thred=0.5):
126 | training_masks = training_masks.data.cpu().numpy()
127 | pred_text = texts.data.cpu().numpy() * training_masks
128 | pred_text[pred_text <= thred] = 0
129 | pred_text[pred_text > thred] = 1
130 | pred_text = pred_text.astype(np.int32)
131 | gt_text = gt_texts.data.cpu().numpy() * training_masks
132 | gt_text = gt_text.astype(np.int32)
133 | running_metric_text.update(gt_text, pred_text)
134 | score_text, _ = running_metric_text.get_scores()
135 | return score_text
136 |
137 |
138 | def order_points_clockwise(pts):
139 | rect = np.zeros((4, 2), dtype="float32")
140 | s = pts.sum(axis=1)
141 | rect[0] = pts[np.argmin(s)]
142 | rect[2] = pts[np.argmax(s)]
143 | diff = np.diff(pts, axis=1)
144 | rect[1] = pts[np.argmin(diff)]
145 | rect[3] = pts[np.argmax(diff)]
146 | return rect
147 |
148 |
149 | def order_points_clockwise_list(pts):
150 | pts = pts.tolist()
151 | pts.sort(key=lambda x: (x[1], x[0]))
152 | pts[:2] = sorted(pts[:2], key=lambda x: x[0])
153 | pts[2:] = sorted(pts[2:], key=lambda x: -x[0])
154 | pts = np.array(pts)
155 | return pts
156 |
157 |
158 | def get_datalist(train_data_path):
159 | """
160 | 获取训练和验证的数据list
161 | :param train_data_path: 训练的dataset文件列表,每个文件内以如下格式存储 ‘path/to/img\tlabel’
162 | :return:
163 | """
164 | train_data = []
165 | for p in train_data_path:
166 | with open(p, 'r', encoding='utf-8') as f:
167 | for line in f.readlines():
168 | line = line.strip('\n').replace('.jpg ', '.jpg\t').split('\t')
169 | if len(line) > 1:
170 | img_path = pathlib.Path(line[0].strip(' '))
171 | label_path = pathlib.Path(line[1].strip(' '))
172 | if img_path.exists() and img_path.stat().st_size > 0 and label_path.exists() and label_path.stat().st_size > 0:
173 | train_data.append((str(img_path), str(label_path)))
174 | return train_data
175 |
176 |
177 | def parse_config(config: dict) -> dict:
178 | import anyconfig
179 | base_file_list = config.pop('base')
180 | base_config = {}
181 | for base_file in base_file_list:
182 | tmp_config = anyconfig.load(open(base_file, 'rb'))
183 | if 'base' in tmp_config:
184 | tmp_config = parse_config(tmp_config)
185 | anyconfig.merge(tmp_config, base_config)
186 | base_config = tmp_config
187 | anyconfig.merge(base_config, config)
188 | return base_config
189 |
190 |
191 | def save_result(result_path, box_list, score_list, is_output_polygon):
192 |     if is_output_polygon:  # note: both branches currently write the identical format
193 | with open(result_path, 'wt') as res:
194 | for i, box in enumerate(box_list):
195 | box = box.reshape(-1).tolist()
196 | result = ",".join([str(int(x)) for x in box])
197 | score = score_list[i]
198 | res.write(result + ',' + str(score) + "\n")
199 | else:
200 | with open(result_path, 'wt') as res:
201 | for i, box in enumerate(box_list):
202 | score = score_list[i]
203 | box = box.reshape(-1).tolist()
204 | result = ",".join([str(int(x)) for x in box])
205 | res.write(result + ',' + str(score) + "\n")
206 |
207 |
208 | def expand_polygon(polygon):
209 | """
210 | 对只有一个字符的框进行扩充
211 | """
212 | (x, y), (w, h), angle = cv2.minAreaRect(np.float32(polygon))
213 | if angle < -45:
214 | w, h = h, w
215 | angle += 90
216 | new_w = w + h
217 | box = ((x, y), (new_w, h), angle)
218 | points = cv2.boxPoints(box)
219 | return order_points_clockwise(points)
220 |
221 |
222 | if __name__ == '__main__':
223 | img = np.zeros((1, 3, 640, 640))
224 | show_img(img[0][0])
225 | plt.show()
226 |
--------------------------------------------------------------------------------
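A quick check of `order_points_clockwise` on a scrambled quadrilateral (toy values; assumes the repository root is on `PYTHONPATH`):

```python
import numpy as np
from utils.util import order_points_clockwise

pts = np.array([[10, 0], [0, 0], [0, 5], [10, 5]], dtype=np.float32)
print(order_points_clockwise(pts))
# -> [[ 0.  0.] [10.  0.] [10.  5.] [ 0.  5.]]  (tl, tr, br, bl)
```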