├── .gitattributes ├── .gitignore ├── LICENSE.md ├── README.MD ├── base ├── __init__.py ├── base_dataset.py └── base_trainer.py ├── config ├── SynthText.yaml ├── SynthText_resnet18_FPN_DBhead_polyLR.yaml ├── icdar2015.yaml ├── icdar2015_dcn_resnet18_FPN_DBhead_polyLR.yaml ├── icdar2015_resnet18_FPN_DBhead_polyLR.yaml ├── icdar2015_resnet18_FPN_DBhead_polyLR_finetune.yaml ├── icdar2015_resnet50_FPN_DBhead_polyLR.yaml ├── open_dataset.yaml ├── open_dataset_dcn_resnet50_FPN_DBhead_polyLR.yaml ├── open_dataset_resnest50_FPN_DBhead_polyLR.yaml └── open_dataset_resnet18_FPN_DBhead_polyLR.yaml ├── data_loader ├── __init__.py ├── dataset.py └── modules │ ├── __init__.py │ ├── augment.py │ ├── iaa_augment.py │ ├── make_border_map.py │ ├── make_shrink_map.py │ └── random_crop_data.py ├── datasets ├── test.txt ├── test │ ├── gt │ │ └── README.MD │ └── img │ │ └── README.MD ├── train.txt └── train │ ├── gt │ └── README.MD │ └── img │ └── README.MD ├── environment.yml ├── eval.sh ├── generate_lists.sh ├── imgs └── paper │ └── db.jpg ├── models ├── __init__.py ├── backbone │ ├── MobilenetV3.py │ ├── __init__.py │ ├── resnest │ │ ├── __init__.py │ │ ├── ablation.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ └── splat.py │ ├── resnet.py │ └── shufflenetv2.py ├── basic.py ├── head │ ├── ConvHead.py │ ├── DBHead.py │ └── __init__.py ├── losses │ ├── DB_loss.py │ ├── __init__.py │ └── basic_loss.py ├── model.py └── neck │ ├── FPEM_FFM.py │ ├── FPN.py │ └── __init__.py ├── multi_gpu_train.sh ├── post_processing ├── __init__.py └── seg_detector_representer.py ├── predict.sh ├── requirement.txt ├── singlel_gpu_train.sh ├── test └── README.MD ├── tools ├── __init__.py ├── eval.py ├── predict.py └── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── cal_recall ├── __init__.py ├── rrc_evaluation_funcs.py └── script.py ├── compute_mean_std.py ├── make_trainfile.py ├── metrics.py ├── ocr_metric ├── __init__.py └── icdar2015 │ ├── __init__.py │ ├── detection │ ├── __init__.py │ ├── deteval.py │ ├── icdar2013.py │ ├── iou.py │ └── mtwi2018.py │ └── quad_metric.py ├── schedulers.py └── util.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.html linguist-language=python 2 | *.ipynb linguist-language=python -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pth 3 | *.pyc 4 | *.pyo 5 | *.log 6 | *.tmp 7 | *.pkl 8 | __pycache__/ 9 | .idea/ 10 | output/ 11 | test/*.jpg -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # Real-time Scene Text Detection with Differentiable Binarization 2 | 3 | **note**: some code is inherited from [MhLiao/DB](https://github.com/MhLiao/DB) 4 | 5 | [中文解读](https://zhuanlan.zhihu.com/p/94677957) 6 | 7 | ![network](imgs/paper/db.jpg) 8 | 9 | ## update 10 | 2020-06-07: 添加灰度图训练,训练灰度图时需要在配置里移除`dataset.args.transforms.Normalize` 11 | 12 | ## Install Using Conda 13 | ``` 14 | conda env create -f environment.yml 15 | git clone https://github.com/WenmuZhou/DBNet.pytorch.git 16 | cd DBNet.pytorch/ 17 | ``` 18 | 19 | or 20 | ## Install Manually 21 | ```bash 22 | conda create -n dbnet python=3.6 23 | conda activate dbnet 24 | 25 | conda install ipython pip 26 | 27 | # python dependencies 28 | pip install -r requirement.txt 29 | 30 | # install PyTorch with 
# cuda 10.1
# Note that you can change the cudatoolkit version to the version you want.
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# clone repo
git clone https://github.com/WenmuZhou/DBNet.pytorch.git
cd DBNet.pytorch/

```

## Requirements
* pytorch 1.4+
* torchvision 0.5+
* gcc 4.9+

## Download

TBD

## Data Preparation

Training data: prepare a text file `train.txt` in the following format, using '\t' as the separator:
```
./datasets/train/img/001.jpg	./datasets/train/gt/001.txt
```

Validation data: prepare a text file `test.txt` in the same format, using '\t' as the separator:
```
./datasets/test/img/001.jpg	./datasets/test/gt/001.txt
```
- Store images in the `img` folder
- Store ground truth in the `gt` folder

The ground truth files are `.txt` files with one text instance per line, in the following format:
```
x1, y1, x2, y2, x3, y3, x4, y4, annotation
```

## Train
1. Configure `dataset['train']['dataset']['args']['data_path']` and `dataset['validate']['dataset']['args']['data_path']` in [config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml](config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml)
2. Single-GPU training:
```bash
bash singlel_gpu_train.sh
```
3. Multi-GPU training:
```bash
bash multi_gpu_train.sh
```

## Test

[eval.py](tools/eval.py) is used to test the model on the test dataset.

1. Configure `model_path` in [eval.sh](eval.sh)
2. Run the following script to test:
```bash
bash eval.sh
```

## Predict
[predict.py](tools/predict.py) can be used to run inference on all images in a folder.
1. Configure `model_path`, `input_folder` and `output_folder` in [predict.sh](predict.sh)
2. Run the following script to predict:
```
bash predict.sh
```
You can change `model_path` in `predict.sh` to point at your own model.

Tip: if the results are not good, try adjusting `thre` in [predict.sh](predict.sh).

The project is still under development.

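To make the ground-truth format concrete, here is a minimal sketch of how one line maps onto a box, a transcription, and an ignore flag (`parse_gt_line` is an illustrative helper, not part of this repo; the real parsing lives in `ICDAR2015Dataset._get_annotation` in data_loader/dataset.py):

```python
import numpy as np

def parse_gt_line(line, ignore_tags=('*', '###')):
    # eight comma-separated coordinates, then the annotation text
    params = line.strip().split(',')
    box = np.array(list(map(float, params[:8]))).reshape(4, 2)
    text = params[8]
    return box, text, text in ignore_tags  # '###' marks text to be ignored

box, text, ignored = parse_gt_line('377,117,463,117,465,130,378,130,Genaxis Theatre')
print(box.shape, text, ignored)  # (4, 2) Genaxis Theatre False
```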

## Performance

### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4)
Trained only on the ICDAR 2015 dataset.

| Method | image size (short side) | learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS |
|:--------------------------:|:-------:|:--------:|:--------:|:------------:|:---------------:|:-----:|
| SynthText-Deform-ResNet-18 (paper) | 736 | 0.007 | 86.8 | 78.4 | 82.3 | 48 |
| ImageNet-resnet18-FPN-DBHead | 736 | 1e-3 | 87.03 | 75.06 | 80.6 | 43 |
| ImageNet-Deform-Resnet18-FPN-DBHead | 736 | 1e-3 | 88.61 | 73.84 | 80.56 | 36 |
| ImageNet-resnet50-FPN-DBHead | 736 | 1e-3 | 88.06 | 77.14 | 82.24 | 27 |
| ImageNet-resnest50-FPN-DBHead | 736 | 1e-3 | 88.18 | 76.27 | 81.78 | 27 |

### examples
TBD

### todo
- [x] multi-gpu training

### reference
1. https://arxiv.org/pdf/1911.08947.pdf
2. https://github.com/WenmuZhou/PANet.pytorch
3. https://github.com/MhLiao/DB

**If this repository helps you, please star it. Thanks.**
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
from .base_trainer import BaseTrainer
from .base_dataset import BaseDataSet
--------------------------------------------------------------------------------
/base/base_dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2019/12/4 13:12
# @Author : zhoujun
import copy
import cv2
import numpy as np
from torch.utils.data import Dataset
from data_loader.modules import *


class BaseDataSet(Dataset):

    def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None,
                 target_transform=None):
        assert img_mode in ['RGB', 'BGR', 'GRAY']
        self.ignore_tags = ignore_tags
        self.data_list = self.load_data(data_path)
        item_keys = ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags']
        for item in item_keys:
            assert item in self.data_list[0], 'data_list from load_data must contain {}'.format(item_keys)
        self.img_mode = img_mode
        self.filter_keys = filter_keys
        self.transform = transform
        self.target_transform = target_transform
        self._init_pre_processes(pre_processes)

    def _init_pre_processes(self, pre_processes):
        self.aug = []
        if pre_processes is not None:
            for aug in pre_processes:
                if 'args' not in aug:
                    args = {}
                else:
                    args = aug['args']
                if isinstance(args, dict):
                    cls = eval(aug['type'])(**args)
                else:
                    cls = eval(aug['type'])(args)
                self.aug.append(cls)

    def load_data(self, data_path: str) -> list:
        """
        Load the data into a list:
        :param data_path: the folder or file that stores the data
        :return: a list of dicts, each containing 'img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags'
        """
        raise NotImplementedError

    def apply_pre_processes(self, data):
        for aug in self.aug:
            data = aug(data)
        return data

    def __getitem__(self, index):
        try:
            data = copy.deepcopy(self.data_list[index])
            im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
            if self.img_mode == 'RGB':
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            data['img'] = im
            data['shape'] = [im.shape[0], im.shape[1]]
            data = self.apply_pre_processes(data)

            if self.transform:
                data['img'] = self.transform(data['img'])
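            # text_polys has been a numpy array throughout pre-processing; hand back plain nested lists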
data['text_polys'] = data['text_polys'].tolist() 65 | if len(self.filter_keys): 66 | data_dict = {} 67 | for k, v in data.items(): 68 | if k not in self.filter_keys: 69 | data_dict[k] = v 70 | return data_dict 71 | else: 72 | return data 73 | except: 74 | return self.__getitem__(np.random.randint(self.__len__())) 75 | 76 | def __len__(self): 77 | return len(self.data_list) 78 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:50 3 | # @Author : zhoujun 4 | 5 | import os 6 | import pathlib 7 | import shutil 8 | from pprint import pformat 9 | 10 | import anyconfig 11 | import torch 12 | 13 | from utils import setup_logger 14 | 15 | 16 | class BaseTrainer: 17 | def __init__(self, config, model, criterion): 18 | config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent), 19 | config['trainer']['output_dir']) 20 | config['name'] = config['name'] + '_' + model.name 21 | self.save_dir = os.path.join(config['trainer']['output_dir'], config['name']) 22 | self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') 23 | 24 | if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '': 25 | shutil.rmtree(self.save_dir, ignore_errors=True) 26 | if not os.path.exists(self.checkpoint_dir): 27 | os.makedirs(self.checkpoint_dir) 28 | 29 | self.global_step = 0 30 | self.start_epoch = 0 31 | self.config = config 32 | self.model = model 33 | self.criterion = criterion 34 | # logger and tensorboard 35 | self.tensorboard_enable = self.config['trainer']['tensorboard'] 36 | self.epochs = self.config['trainer']['epochs'] 37 | self.log_iter = self.config['trainer']['log_iter'] 38 | if config['local_rank'] == 0: 39 | anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml')) 40 | self.logger = setup_logger(os.path.join(self.save_dir, 'train.log')) 41 | self.logger_info(pformat(self.config)) 42 | 43 | # device 44 | torch.manual_seed(self.config['trainer']['seed']) # 为CPU设置随机种子 45 | if torch.cuda.device_count() > 0 and torch.cuda.is_available(): 46 | self.with_cuda = True 47 | torch.backends.cudnn.benchmark = True 48 | self.device = torch.device("cuda") 49 | torch.cuda.manual_seed(self.config['trainer']['seed']) # 为当前GPU设置随机种子 50 | torch.cuda.manual_seed_all(self.config['trainer']['seed']) # 为所有GPU设置随机种子 51 | else: 52 | self.with_cuda = False 53 | self.device = torch.device("cpu") 54 | self.logger_info('train with device {} and pytorch {}'.format(self.device, torch.__version__)) 55 | # metrics 56 | self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'),'best_model_epoch':0} 57 | 58 | self.optimizer = self._initialize('optimizer', torch.optim, model.parameters()) 59 | 60 | # resume or finetune 61 | if self.config['trainer']['resume_checkpoint'] != '': 62 | self._load_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True) 63 | elif self.config['trainer']['finetune_checkpoint'] != '': 64 | self._load_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False) 65 | 66 | if self.config['lr_scheduler']['type'] != 'WarmupPolyLR': 67 | self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer) 68 | 69 | self.model.to(self.device) 70 | 71 | if self.tensorboard_enable and config['local_rank'] == 0: 72 | from torch.utils.tensorboard import SummaryWriter 73 | 
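            # event files are written to save_dir, alongside config.yaml and train.log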
self.writer = SummaryWriter(self.save_dir) 74 | try: 75 | # add graph 76 | in_channels = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1 77 | dummy_input = torch.zeros(1, in_channels, 640, 640).to(self.device) 78 | self.writer.add_graph(self.model, dummy_input) 79 | torch.cuda.empty_cache() 80 | except: 81 | import traceback 82 | self.logger.error(traceback.format_exc()) 83 | self.logger.warn('add graph to tensorboard failed') 84 | # 分布式训练 85 | if torch.cuda.device_count() > 1: 86 | local_rank = config['local_rank'] 87 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False, 88 | find_unused_parameters=True) 89 | # make inverse Normalize 90 | self.UN_Normalize = False 91 | for t in self.config['dataset']['train']['dataset']['args']['transforms']: 92 | if t['type'] == 'Normalize': 93 | self.normalize_mean = t['args']['mean'] 94 | self.normalize_std = t['args']['std'] 95 | self.UN_Normalize = True 96 | 97 | def train(self): 98 | """ 99 | Full training logic 100 | """ 101 | for epoch in range(self.start_epoch + 1, self.epochs + 1): 102 | if self.config['distributed']: 103 | self.train_loader.sampler.set_epoch(epoch) 104 | self.epoch_result = self._train_epoch(epoch) 105 | if self.config['lr_scheduler']['type'] != 'WarmupPolyLR': 106 | self.scheduler.step() 107 | self._on_epoch_finish() 108 | if self.config['local_rank'] == 0 and self.tensorboard_enable: 109 | self.writer.close() 110 | self._on_train_finish() 111 | 112 | def _train_epoch(self, epoch): 113 | """ 114 | Training logic for an epoch 115 | 116 | :param epoch: Current epoch number 117 | """ 118 | raise NotImplementedError 119 | 120 | def _eval(self, epoch): 121 | """ 122 | eval logic for an epoch 123 | 124 | :param epoch: Current epoch number 125 | """ 126 | raise NotImplementedError 127 | 128 | def _on_epoch_finish(self): 129 | raise NotImplementedError 130 | 131 | def _on_train_finish(self): 132 | raise NotImplementedError 133 | 134 | def _save_checkpoint(self, epoch, file_name): 135 | """ 136 | Saving checkpoints 137 | 138 | :param epoch: current epoch number 139 | :param log: logging information of the epoch 140 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar' 141 | """ 142 | state_dict = self.model.module.state_dict() if self.config['distributed'] else self.model.state_dict() 143 | state = { 144 | 'epoch': epoch, 145 | 'global_step': self.global_step, 146 | 'state_dict': state_dict, 147 | 'optimizer': self.optimizer.state_dict(), 148 | 'scheduler': self.scheduler.state_dict(), 149 | 'config': self.config, 150 | 'metrics': self.metrics 151 | } 152 | filename = os.path.join(self.checkpoint_dir, file_name) 153 | torch.save(state, filename) 154 | 155 | def _load_checkpoint(self, checkpoint_path, resume): 156 | """ 157 | Resume from saved checkpoints 158 | :param checkpoint_path: Checkpoint path to be resumed 159 | """ 160 | self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path)) 161 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) 162 | self.model.load_state_dict(checkpoint['state_dict'], strict=resume) 163 | if resume: 164 | self.global_step = checkpoint['global_step'] 165 | self.start_epoch = checkpoint['epoch'] 166 | self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch 167 | # self.scheduler.load_state_dict(checkpoint['scheduler']) 168 | self.optimizer.load_state_dict(checkpoint['optimizer']) 169 | if 'metrics' in 
checkpoint: 170 | self.metrics = checkpoint['metrics'] 171 | if self.with_cuda: 172 | for state in self.optimizer.state.values(): 173 | for k, v in state.items(): 174 | if isinstance(v, torch.Tensor): 175 | state[k] = v.to(self.device) 176 | self.logger_info("resume from checkpoint {} (epoch {})".format(checkpoint_path, self.start_epoch)) 177 | else: 178 | self.logger_info("finetune from checkpoint {}".format(checkpoint_path)) 179 | 180 | def _initialize(self, name, module, *args, **kwargs): 181 | module_name = self.config[name]['type'] 182 | module_args = self.config[name]['args'] 183 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 184 | module_args.update(kwargs) 185 | return getattr(module, module_name)(*args, **module_args) 186 | 187 | def inverse_normalize(self, batch_img): 188 | if self.UN_Normalize: 189 | batch_img[:, 0, :, :] = batch_img[:, 0, :, :] * self.normalize_std[0] + self.normalize_mean[0] 190 | batch_img[:, 1, :, :] = batch_img[:, 1, :, :] * self.normalize_std[1] + self.normalize_mean[1] 191 | batch_img[:, 2, :, :] = batch_img[:, 2, :, :] * self.normalize_std[2] + self.normalize_mean[2] 192 | 193 | def logger_info(self, s): 194 | if self.config['local_rank'] == 0: 195 | self.logger.info(s) 196 | -------------------------------------------------------------------------------- /config/SynthText.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | dataset: 3 | train: 4 | dataset: 5 | type: SynthTextDataset # 数据集类型 6 | args: 7 | data_path: ''# SynthTextDataset 根目录 8 | pre_processes: # 数据的预处理过程,包含augment和标签制作 9 | - type: IaaAugment # 使用imgaug进行变换 10 | args: 11 | - {'type':Fliplr, 'args':{'p':0.5}} 12 | - {'type': Affine, 'args':{'rotate':[-10,10]}} 13 | - {'type':Resize,'args':{'size':[0.5,3]}} 14 | - type: EastRandomCropData 15 | args: 16 | size: [640,640] 17 | max_tries: 50 18 | keep_ratio: true 19 | - type: MakeBorderMap 20 | args: 21 | shrink_ratio: 0.4 22 | - type: MakeShrinkMap 23 | args: 24 | shrink_ratio: 0.4 25 | min_text_size: 8 26 | transforms: # 对图片进行的变换方式 27 | - type: ToTensor 28 | args: {} 29 | - type: Normalize 30 | args: 31 | mean: [0.485, 0.456, 0.406] 32 | std: [0.229, 0.224, 0.225] 33 | img_mode: RGB 34 | filter_keys: ['img_path','img_name','text_polys','texts','ignore_tags','shape'] # 返回数据之前,从数据字典里删除的key 35 | ignore_tags: ['*', '###'] 36 | loader: 37 | batch_size: 1 38 | shuffle: true 39 | pin_memory: false 40 | num_workers: 0 41 | collate_fn: '' -------------------------------------------------------------------------------- /config/SynthText_resnet18_FPN_DBhead_polyLR.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/SynthText.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: resnet18 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: WarmupPolyLR 39 | args: 40 | warmup_epoch: 3 41 | trainer: 42 | seed: 2 43 | epochs: 1200 44 | log_iter: 10 45 | 
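  # interval (in iterations) at which sample images are written to tensorboard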
show_images_iter: 50 46 | resume_checkpoint: '' 47 | finetune_checkpoint: '' 48 | output_dir: output 49 | tensorboard: true 50 | dataset: 51 | train: 52 | dataset: 53 | args: 54 | data_path: ./datasets/SynthText 55 | img_mode: RGB 56 | loader: 57 | batch_size: 2 58 | shuffle: true 59 | pin_memory: true 60 | num_workers: 6 61 | collate_fn: '' -------------------------------------------------------------------------------- /config/icdar2015.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | dataset: 3 | train: 4 | dataset: 5 | type: ICDAR2015Dataset # 数据集类型 6 | args: 7 | data_path: # 一个存放 img_path \t gt_path的文件 8 | - '' 9 | pre_processes: # 数据的预处理过程,包含augment和标签制作 10 | - type: IaaAugment # 使用imgaug进行变换 11 | args: 12 | - {'type':Fliplr, 'args':{'p':0.5}} 13 | - {'type': Affine, 'args':{'rotate':[-10,10]}} 14 | - {'type':Resize,'args':{'size':[0.5,3]}} 15 | - type: EastRandomCropData 16 | args: 17 | size: [640,640] 18 | max_tries: 50 19 | keep_ratio: true 20 | - type: MakeBorderMap 21 | args: 22 | shrink_ratio: 0.4 23 | thresh_min: 0.3 24 | thresh_max: 0.7 25 | - type: MakeShrinkMap 26 | args: 27 | shrink_ratio: 0.4 28 | min_text_size: 8 29 | transforms: # 对图片进行的变换方式 30 | - type: ToTensor 31 | args: {} 32 | - type: Normalize 33 | args: 34 | mean: [0.485, 0.456, 0.406] 35 | std: [0.229, 0.224, 0.225] 36 | img_mode: RGB 37 | filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # 返回数据之前,从数据字典里删除的key 38 | ignore_tags: ['*', '###'] 39 | loader: 40 | batch_size: 1 41 | shuffle: true 42 | pin_memory: false 43 | num_workers: 0 44 | collate_fn: '' 45 | validate: 46 | dataset: 47 | type: ICDAR2015Dataset 48 | args: 49 | data_path: 50 | - '' 51 | pre_processes: 52 | - type: ResizeShortSize 53 | args: 54 | short_size: 736 55 | resize_text_polys: false 56 | transforms: 57 | - type: ToTensor 58 | args: {} 59 | - type: Normalize 60 | args: 61 | mean: [0.485, 0.456, 0.406] 62 | std: [0.229, 0.224, 0.225] 63 | img_mode: RGB 64 | filter_keys: [] 65 | ignore_tags: ['*', '###'] 66 | loader: 67 | batch_size: 1 68 | shuffle: true 69 | pin_memory: false 70 | num_workers: 0 71 | collate_fn: ICDARCollectFN -------------------------------------------------------------------------------- /config/icdar2015_dcn_resnet18_FPN_DBhead_polyLR.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/icdar2015.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: deformable_resnet18 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: WarmupPolyLR 39 | args: 40 | warmup_epoch: 3 41 | trainer: 42 | seed: 2 43 | epochs: 1200 44 | log_iter: 10 45 | show_images_iter: 50 46 | resume_checkpoint: '' 47 | finetune_checkpoint: '' 48 | output_dir: output 49 | tensorboard: true 50 | dataset: 51 | train: 52 | dataset: 53 | args: 54 | data_path: 55 | - ./datasets/train.txt 56 | img_mode: RGB 57 | loader: 58 | batch_size: 1 59 | shuffle: true 60 | pin_memory: true 61 | num_workers: 6 62 | 
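      # an empty string makes get_dataloader fall back to PyTorch's default collate_fn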
collate_fn: '' 63 | validate: 64 | dataset: 65 | args: 66 | data_path: 67 | - ./datasets/test.txt 68 | pre_processes: 69 | - type: ResizeShortSize 70 | args: 71 | short_size: 736 72 | resize_text_polys: false 73 | img_mode: RGB 74 | loader: 75 | batch_size: 1 76 | shuffle: true 77 | pin_memory: false 78 | num_workers: 6 79 | collate_fn: ICDARCollectFN -------------------------------------------------------------------------------- /config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/icdar2015.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: resnet18 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: WarmupPolyLR 39 | args: 40 | warmup_epoch: 3 41 | trainer: 42 | seed: 2 43 | epochs: 1200 44 | log_iter: 10 45 | show_images_iter: 50 46 | resume_checkpoint: '' 47 | finetune_checkpoint: '' 48 | output_dir: output 49 | tensorboard: true 50 | dataset: 51 | train: 52 | dataset: 53 | args: 54 | data_path: 55 | - ./datasets/train.txt 56 | img_mode: RGB 57 | loader: 58 | batch_size: 1 59 | shuffle: true 60 | pin_memory: true 61 | num_workers: 6 62 | collate_fn: '' 63 | validate: 64 | dataset: 65 | args: 66 | data_path: 67 | - ./datasets/test.txt 68 | pre_processes: 69 | - type: ResizeShortSize 70 | args: 71 | short_size: 736 72 | resize_text_polys: false 73 | img_mode: RGB 74 | loader: 75 | batch_size: 1 76 | shuffle: true 77 | pin_memory: false 78 | num_workers: 6 79 | collate_fn: ICDARCollectFN 80 | -------------------------------------------------------------------------------- /config/icdar2015_resnet18_FPN_DBhead_polyLR_finetune.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/icdar2015.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: resnet18 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: StepLR 39 | args: 40 | step_size: 10 41 | gama: 0.8 42 | trainer: 43 | seed: 2 44 | epochs: 500 45 | log_iter: 10 46 | show_images_iter: 50 47 | resume_checkpoint: '' 48 | finetune_checkpoint: '' 49 | output_dir: output 50 | tensorboard: true 51 | dataset: 52 | train: 53 | dataset: 54 | args: 55 | data_path: 56 | - ./datasets/train.txt 57 | img_mode: RGB 58 | loader: 59 | batch_size: 1 60 | shuffle: true 61 | pin_memory: true 62 | num_workers: 6 63 | collate_fn: '' 64 | validate: 65 | dataset: 66 | args: 67 | data_path: 68 | - ./datasets/test.txt 69 | pre_processes: 70 | - 
type: ResizeShortSize 71 | args: 72 | short_size: 736 73 | resize_text_polys: false 74 | img_mode: RGB 75 | loader: 76 | batch_size: 1 77 | shuffle: true 78 | pin_memory: false 79 | num_workers: 6 80 | collate_fn: ICDARCollectFN 81 | -------------------------------------------------------------------------------- /config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/icdar2015.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: resnet50 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: WarmupPolyLR 39 | args: 40 | warmup_epoch: 3 41 | trainer: 42 | seed: 2 43 | epochs: 1200 44 | log_iter: 10 45 | show_images_iter: 50 46 | resume_checkpoint: '' 47 | finetune_checkpoint: '' 48 | output_dir: output 49 | tensorboard: true 50 | dataset: 51 | train: 52 | dataset: 53 | args: 54 | data_path: 55 | - ./datasets/train.txt 56 | img_mode: RGB 57 | loader: 58 | batch_size: 16 59 | shuffle: true 60 | pin_memory: true 61 | num_workers: 6 62 | collate_fn: '' 63 | validate: 64 | dataset: 65 | args: 66 | data_path: 67 | - ./datasets/test.txt 68 | pre_processes: 69 | - type: ResizeShortSize 70 | args: 71 | short_size: 736 72 | resize_text_polys: false 73 | img_mode: RGB 74 | loader: 75 | batch_size: 1 76 | shuffle: true 77 | pin_memory: false 78 | num_workers: 6 79 | collate_fn: ICDARCollectFN 80 | -------------------------------------------------------------------------------- /config/open_dataset.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | dataset: 3 | train: 4 | dataset: 5 | type: DetDataset # 数据集类型 6 | args: 7 | data_path: # 一个存放 img_path \t gt_path的文件 8 | - '' 9 | pre_processes: # 数据的预处理过程,包含augment和标签制作 10 | - type: IaaAugment # 使用imgaug进行变换 11 | args: 12 | - {'type':Fliplr, 'args':{'p':0.5}} 13 | - {'type': Affine, 'args':{'rotate':[-10,10]}} 14 | - {'type':Resize,'args':{'size':[0.5,3]}} 15 | - type: EastRandomCropData 16 | args: 17 | size: [640,640] 18 | max_tries: 50 19 | keep_ratio: true 20 | - type: MakeBorderMap 21 | args: 22 | shrink_ratio: 0.4 23 | thresh_min: 0.3 24 | thresh_max: 0.7 25 | - type: MakeShrinkMap 26 | args: 27 | shrink_ratio: 0.4 28 | min_text_size: 8 29 | transforms: # 对图片进行的变换方式 30 | - type: ToTensor 31 | args: {} 32 | - type: Normalize 33 | args: 34 | mean: [0.485, 0.456, 0.406] 35 | std: [0.229, 0.224, 0.225] 36 | img_mode: RGB 37 | load_char_annotation: false 38 | expand_one_char: false 39 | filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # 返回数据之前,从数据字典里删除的key 40 | ignore_tags: ['*', '###'] 41 | loader: 42 | batch_size: 1 43 | shuffle: true 44 | pin_memory: false 45 | num_workers: 0 46 | collate_fn: '' 47 | validate: 48 | dataset: 49 | type: DetDataset 50 | args: 51 | data_path: 52 | - '' 53 | pre_processes: 54 | - type: ResizeShortSize 55 | args: 56 | short_size: 736 57 | resize_text_polys: false 58 | transforms: 59 | - type: ToTensor 60 | args: {} 61 | - 
type: Normalize 62 | args: 63 | mean: [0.485, 0.456, 0.406] 64 | std: [0.229, 0.224, 0.225] 65 | img_mode: RGB 66 | load_char_annotation: false # 是否加载字符级标注 67 | expand_one_char: false # 是否对只有一个字符的框进行宽度扩充,扩充后w = w+h 68 | filter_keys: [] 69 | ignore_tags: ['*', '###'] 70 | loader: 71 | batch_size: 1 72 | shuffle: true 73 | pin_memory: false 74 | num_workers: 0 75 | collate_fn: ICDARCollectFN -------------------------------------------------------------------------------- /config/open_dataset_dcn_resnet50_FPN_DBhead_polyLR.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/open_dataset.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: deformable_resnet18 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: WarmupPolyLR 39 | args: 40 | warmup_epoch: 3 41 | trainer: 42 | seed: 2 43 | epochs: 1200 44 | log_iter: 1 45 | show_images_iter: 1 46 | resume_checkpoint: '' 47 | finetune_checkpoint: '' 48 | output_dir: output 49 | tensorboard: true 50 | dataset: 51 | train: 52 | dataset: 53 | args: 54 | data_path: 55 | - ./datasets/train.json 56 | img_mode: RGB 57 | load_char_annotation: false 58 | expand_one_char: false 59 | loader: 60 | batch_size: 2 61 | shuffle: true 62 | pin_memory: true 63 | num_workers: 6 64 | collate_fn: '' 65 | validate: 66 | dataset: 67 | args: 68 | data_path: 69 | - ./datasets/test.json 70 | pre_processes: 71 | - type: ResizeShortSize 72 | args: 73 | short_size: 736 74 | resize_text_polys: false 75 | img_mode: RGB 76 | load_char_annotation: false 77 | expand_one_char: false 78 | loader: 79 | batch_size: 1 80 | shuffle: true 81 | pin_memory: false 82 | num_workers: 6 83 | collate_fn: ICDARCollectFN 84 | -------------------------------------------------------------------------------- /config/open_dataset_resnest50_FPN_DBhead_polyLR.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/open_dataset.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: resnest50 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: WarmupPolyLR 39 | args: 40 | warmup_epoch: 3 41 | trainer: 42 | seed: 2 43 | epochs: 1200 44 | log_iter: 1 45 | show_images_iter: 1 46 | resume_checkpoint: '' 47 | finetune_checkpoint: '' 48 | output_dir: output 49 | tensorboard: true 50 | dataset: 51 | train: 52 | dataset: 53 | args: 54 | data_path: 55 | - ./datasets/train.json 56 | img_mode: RGB 57 | load_char_annotation: false 58 | 
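          # width expansion for one-character boxes (w = w + h), see the comments in config/open_dataset.yaml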
expand_one_char: false 59 | loader: 60 | batch_size: 2 61 | shuffle: true 62 | pin_memory: true 63 | num_workers: 6 64 | collate_fn: '' 65 | validate: 66 | dataset: 67 | args: 68 | data_path: 69 | - ./datasets/test.json 70 | pre_processes: 71 | - type: ResizeShortSize 72 | args: 73 | short_size: 736 74 | resize_text_polys: false 75 | img_mode: RGB 76 | load_char_annotation: false 77 | expand_one_char: false 78 | loader: 79 | batch_size: 1 80 | shuffle: true 81 | pin_memory: false 82 | num_workers: 6 83 | collate_fn: ICDARCollectFN 84 | -------------------------------------------------------------------------------- /config/open_dataset_resnet18_FPN_DBhead_polyLR.yaml: -------------------------------------------------------------------------------- 1 | name: DBNet 2 | base: ['config/open_dataset.yaml'] 3 | arch: 4 | type: Model 5 | backbone: 6 | type: resnet18 7 | pretrained: true 8 | neck: 9 | type: FPN 10 | inner_channels: 256 11 | head: 12 | type: DBHead 13 | out_channels: 2 14 | k: 50 15 | post_processing: 16 | type: SegDetectorRepresenter 17 | args: 18 | thresh: 0.3 19 | box_thresh: 0.7 20 | max_candidates: 1000 21 | unclip_ratio: 1.5 # from paper 22 | metric: 23 | type: QuadMetric 24 | args: 25 | is_output_polygon: false 26 | loss: 27 | type: DBLoss 28 | alpha: 1 29 | beta: 10 30 | ohem_ratio: 3 31 | optimizer: 32 | type: Adam 33 | args: 34 | lr: 0.001 35 | weight_decay: 0 36 | amsgrad: true 37 | lr_scheduler: 38 | type: WarmupPolyLR 39 | args: 40 | warmup_epoch: 3 41 | trainer: 42 | seed: 2 43 | epochs: 1200 44 | log_iter: 1 45 | show_images_iter: 1 46 | resume_checkpoint: '' 47 | finetune_checkpoint: '' 48 | output_dir: output 49 | tensorboard: true 50 | dataset: 51 | train: 52 | dataset: 53 | args: 54 | data_path: 55 | - ./datasets/train.json 56 | transforms: # 对图片进行的变换方式 57 | - type: ToTensor 58 | args: {} 59 | - type: Normalize 60 | args: 61 | mean: [0.485, 0.456, 0.406] 62 | std: [0.229, 0.224, 0.225] 63 | img_mode: RGB 64 | load_char_annotation: false 65 | expand_one_char: false 66 | loader: 67 | batch_size: 2 68 | shuffle: true 69 | pin_memory: true 70 | num_workers: 6 71 | collate_fn: '' 72 | validate: 73 | dataset: 74 | args: 75 | data_path: 76 | - ./datasets/test.json 77 | pre_processes: 78 | - type: ResizeShortSize 79 | args: 80 | short_size: 736 81 | resize_text_polys: false 82 | img_mode: RGB 83 | load_char_annotation: false 84 | expand_one_char: false 85 | loader: 86 | batch_size: 1 87 | shuffle: true 88 | pin_memory: false 89 | num_workers: 6 90 | collate_fn: ICDARCollectFN 91 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:52 3 | # @Author : zhoujun 4 | import copy 5 | 6 | import PIL 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | 12 | 13 | def get_dataset(data_path, module_name, transform, dataset_args): 14 | """ 15 | 获取训练dataset 16 | :param data_path: dataset文件列表,每个文件内以如下格式存储 ‘path/to/img\tlabel’ 17 | :param module_name: 所使用的自定义dataset名称,目前只支持data_loaders.ImageDataset 18 | :param transform: 该数据集使用的transforms 19 | :param dataset_args: module_name的参数 20 | :return: 如果data_path列表不为空,返回对于的ConcatDataset对象,否则None 21 | """ 22 | from . 
import dataset 23 | s_dataset = getattr(dataset, module_name)(transform=transform, data_path=data_path, 24 | **dataset_args) 25 | return s_dataset 26 | 27 | 28 | def get_transforms(transforms_config): 29 | tr_list = [] 30 | for item in transforms_config: 31 | if 'args' not in item: 32 | args = {} 33 | else: 34 | args = item['args'] 35 | cls = getattr(transforms, item['type'])(**args) 36 | tr_list.append(cls) 37 | tr_list = transforms.Compose(tr_list) 38 | return tr_list 39 | 40 | 41 | class ICDARCollectFN: 42 | def __init__(self, *args, **kwargs): 43 | pass 44 | 45 | def __call__(self, batch): 46 | data_dict = {} 47 | to_tensor_keys = [] 48 | for sample in batch: 49 | for k, v in sample.items(): 50 | if k not in data_dict: 51 | data_dict[k] = [] 52 | if isinstance(v, (np.ndarray, torch.Tensor, PIL.Image.Image)): 53 | if k not in to_tensor_keys: 54 | to_tensor_keys.append(k) 55 | data_dict[k].append(v) 56 | for k in to_tensor_keys: 57 | data_dict[k] = torch.stack(data_dict[k], 0) 58 | return data_dict 59 | 60 | 61 | def get_dataloader(module_config, distributed=False): 62 | if module_config is None: 63 | return None 64 | config = copy.deepcopy(module_config) 65 | dataset_args = config['dataset']['args'] 66 | if 'transforms' in dataset_args: 67 | img_transfroms = get_transforms(dataset_args.pop('transforms')) 68 | else: 69 | img_transfroms = None 70 | # 创建数据集 71 | dataset_name = config['dataset']['type'] 72 | data_path = dataset_args.pop('data_path') 73 | if data_path == None: 74 | return None 75 | 76 | data_path = [x for x in data_path if x is not None] 77 | if len(data_path) == 0: 78 | return None 79 | if 'collate_fn' not in config['loader'] or config['loader']['collate_fn'] is None or len(config['loader']['collate_fn']) == 0: 80 | config['loader']['collate_fn'] = None 81 | else: 82 | config['loader']['collate_fn'] = eval(config['loader']['collate_fn'])() 83 | 84 | _dataset = get_dataset(data_path=data_path, module_name=dataset_name, transform=img_transfroms, dataset_args=dataset_args) 85 | sampler = None 86 | if distributed: 87 | from torch.utils.data.distributed import DistributedSampler 88 | # 3)使用DistributedSampler 89 | sampler = DistributedSampler(_dataset) 90 | config['loader']['shuffle'] = False 91 | config['loader']['pin_memory'] = True 92 | loader = DataLoader(dataset=_dataset, sampler=sampler, **config['loader']) 93 | return loader 94 | -------------------------------------------------------------------------------- /data_loader/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:54 3 | # @Author : zhoujun 4 | import pathlib 5 | import os 6 | import cv2 7 | import numpy as np 8 | import scipy.io as sio 9 | from tqdm.auto import tqdm 10 | 11 | from base import BaseDataSet 12 | from utils import order_points_clockwise, get_datalist, load,expand_polygon 13 | 14 | 15 | class ICDAR2015Dataset(BaseDataSet): 16 | def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None, **kwargs): 17 | super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform) 18 | 19 | def load_data(self, data_path: str) -> list: 20 | data_list = get_datalist(data_path) 21 | t_data_list = [] 22 | for img_path, label_path in data_list: 23 | data = self._get_annotation(label_path) 24 | if len(data['text_polys']) > 0: 25 | item = {'img_path': img_path, 'img_name': pathlib.Path(img_path).stem} 26 | item.update(data) 27 | t_data_list.append(item) 28 | 
else: 29 | print('there is no suit bbox in {}'.format(label_path)) 30 | return t_data_list 31 | 32 | def _get_annotation(self, label_path: str) -> dict: 33 | boxes = [] 34 | texts = [] 35 | ignores = [] 36 | with open(label_path, encoding='utf-8', mode='r') as f: 37 | for line in f.readlines(): 38 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',') 39 | try: 40 | box = order_points_clockwise(np.array(list(map(float, params[:8]))).reshape(-1, 2)) 41 | if cv2.contourArea(box) > 0: 42 | boxes.append(box) 43 | label = params[8] 44 | texts.append(label) 45 | ignores.append(label in self.ignore_tags) 46 | except: 47 | print('load label failed on {}'.format(label_path)) 48 | data = { 49 | 'text_polys': np.array(boxes), 50 | 'texts': texts, 51 | 'ignore_tags': ignores, 52 | } 53 | return data 54 | 55 | 56 | class DetDataset(BaseDataSet): 57 | def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None, **kwargs): 58 | self.load_char_annotation = kwargs['load_char_annotation'] 59 | self.expand_one_char = kwargs['expand_one_char'] 60 | super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform) 61 | 62 | def load_data(self, data_path: str) -> list: 63 | """ 64 | 从json文件中读取出 文本行的坐标和gt,字符的坐标和gt 65 | :param data_path: 66 | :return: 67 | """ 68 | data_list = [] 69 | for path in data_path: 70 | content = load(path) 71 | for gt in tqdm(content['data_list'], desc='read file {}'.format(path)): 72 | img_path = os.path.join(content['data_root'], gt['img_name']) 73 | polygons = [] 74 | texts = [] 75 | illegibility_list = [] 76 | language_list = [] 77 | for annotation in gt['annotations']: 78 | if len(annotation['polygon']) == 0 or len(annotation['text']) == 0: 79 | continue 80 | if len(annotation['text']) > 1 and self.expand_one_char: 81 | annotation['polygon'] = expand_polygon(annotation['polygon']) 82 | polygons.append(annotation['polygon']) 83 | texts.append(annotation['text']) 84 | illegibility_list.append(annotation['illegibility']) 85 | language_list.append(annotation['language']) 86 | if self.load_char_annotation: 87 | for char_annotation in annotation['chars']: 88 | if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0: 89 | continue 90 | polygons.append(char_annotation['polygon']) 91 | texts.append(char_annotation['char']) 92 | illegibility_list.append(char_annotation['illegibility']) 93 | language_list.append(char_annotation['language']) 94 | data_list.append({'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': np.array(polygons), 95 | 'texts': texts, 'ignore_tags': illegibility_list}) 96 | return data_list 97 | 98 | 99 | class SynthTextDataset(BaseDataSet): 100 | def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, transform=None, **kwargs): 101 | self.transform = transform 102 | self.dataRoot = pathlib.Path(data_path) 103 | if not self.dataRoot.exists(): 104 | raise FileNotFoundError('Dataset folder is not exist.') 105 | 106 | self.targetFilePath = self.dataRoot / 'gt.mat' 107 | if not self.targetFilePath.exists(): 108 | raise FileExistsError('Target file is not exist.') 109 | targets = {} 110 | sio.loadmat(self.targetFilePath, targets, squeeze_me=True, struct_as_record=False, 111 | variable_names=['imnames', 'wordBB', 'txt']) 112 | 113 | self.imageNames = targets['imnames'] 114 | self.wordBBoxes = targets['wordBB'] 115 | self.transcripts = targets['txt'] 116 | super().__init__(data_path, img_mode, pre_processes, filter_keys, transform) 117 | 118 | def 
load_data(self, data_path: str) -> list: 119 | t_data_list = [] 120 | for imageName, wordBBoxes, texts in zip(self.imageNames, self.wordBBoxes, self.transcripts): 121 | item = {} 122 | wordBBoxes = np.expand_dims(wordBBoxes, axis=2) if (wordBBoxes.ndim == 2) else wordBBoxes 123 | _, _, numOfWords = wordBBoxes.shape 124 | text_polys = wordBBoxes.reshape([8, numOfWords], order='F').T # num_words * 8 125 | text_polys = text_polys.reshape(numOfWords, 4, 2) # num_of_words * 4 * 2 126 | transcripts = [word for line in texts for word in line.split()] 127 | if numOfWords != len(transcripts): 128 | continue 129 | item['img_path'] = str(self.dataRoot / imageName) 130 | item['img_name'] = (self.dataRoot / imageName).stem 131 | item['text_polys'] = text_polys 132 | item['texts'] = transcripts 133 | item['ignore_tags'] = [x in self.ignore_tags for x in transcripts] 134 | t_data_list.append(item) 135 | return t_data_list 136 | 137 | 138 | if __name__ == '__main__': 139 | import torch 140 | import anyconfig 141 | from torch.utils.data import DataLoader 142 | from torchvision import transforms 143 | 144 | from utils import parse_config, show_img, plt, draw_bbox 145 | 146 | config = anyconfig.load('config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml') 147 | config = parse_config(config) 148 | dataset_args = config['dataset']['train']['dataset']['args'] 149 | # dataset_args.pop('data_path') 150 | # data_list = [(r'E:/zj/dataset/icdar2015/train/img/img_15.jpg', 'E:/zj/dataset/icdar2015/train/gt/gt_img_15.txt')] 151 | train_data = ICDAR2015Dataset(data_path=dataset_args.pop('data_path'), transform=transforms.ToTensor(), 152 | **dataset_args) 153 | train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=True, num_workers=0) 154 | for i, data in enumerate(tqdm(train_loader)): 155 | # img = data['img'] 156 | # shrink_label = data['shrink_map'] 157 | # threshold_label = data['threshold_map'] 158 | # 159 | # print(threshold_label.shape, threshold_label.shape, img.shape) 160 | # show_img(img[0].numpy().transpose(1, 2, 0), title='img') 161 | # show_img((shrink_label[0].to(torch.float)).numpy(), title='shrink_label') 162 | # show_img((threshold_label[0].to(torch.float)).numpy(), title='threshold_label') 163 | # img = draw_bbox(img[0].numpy().transpose(1, 2, 0),np.array(data['text_polys'])) 164 | # show_img(img, title='draw_bbox') 165 | # plt.show() 166 | pass 167 | -------------------------------------------------------------------------------- /data_loader/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/12/4 10:53 3 | # @Author : zhoujun 4 | from .iaa_augment import IaaAugment 5 | from .augment import * 6 | from .random_crop_data import EastRandomCropData,PSERandomCrop 7 | from .make_border_map import MakeBorderMap 8 | from .make_shrink_map import MakeShrinkMap 9 | -------------------------------------------------------------------------------- /data_loader/modules/augment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:52 3 | # @Author : zhoujun 4 | 5 | import math 6 | import numbers 7 | import random 8 | 9 | import cv2 10 | import numpy as np 11 | from skimage.util import random_noise 12 | 13 | 14 | class RandomNoise: 15 | def __init__(self, random_rate): 16 | self.random_rate = random_rate 17 | 18 | def __call__(self, data: dict): 19 | """ 20 | 对图片加噪声 21 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} 
        :return:
        """
        if random.random() > self.random_rate:
            return data
        # random_noise returns a float image in [0, 1]; rescale and restore the input dtype
        data['img'] = (random_noise(data['img'], mode='gaussian', clip=True) * 255).astype(data['img'].dtype)
        return data


class RandomScale:
    def __init__(self, scales, random_rate):
        """
        :param scales: scale factors to choose from
        :param random_rate: probability of applying this augmentation
        :return:
        """
        self.random_rate = random_rate
        self.scales = scales

    def __call__(self, data: dict) -> dict:
        """
        Randomly pick a scale from scales and resize the image and text boxes accordingly
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        if random.random() > self.random_rate:
            return data
        im = data['img']
        text_polys = data['text_polys']

        tmp_text_polys = text_polys.copy()
        rd_scale = float(np.random.choice(self.scales))
        im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
        tmp_text_polys *= rd_scale

        data['img'] = im
        data['text_polys'] = tmp_text_polys
        return data


class RandomRotateImgBox:
    def __init__(self, degrees, random_rate, same_size=False):
        """
        :param degrees: rotation angle, a single number or a list/tuple of two numbers
        :param random_rate: probability of applying this augmentation
        :param same_size: whether to keep the output the same size as the input image
        :return:
        """
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            degrees = (-degrees, degrees)
        elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            degrees = degrees
        else:
            raise Exception('degrees must be a Number or a list, tuple or np.ndarray')
        self.degrees = degrees
        self.same_size = same_size
        self.random_rate = random_rate

    def __call__(self, data: dict) -> dict:
        """
        Randomly pick an angle within degrees and rotate the image and text boxes accordingly
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        if random.random() > self.random_rate:
            return data
        im = data['img']
        text_polys = data['text_polys']

        # ---------------------- rotate the image ----------------------
        w = im.shape[1]
        h = im.shape[0]
        angle = np.random.uniform(self.degrees[0], self.degrees[1])

        if self.same_size:
            nw = w
            nh = h
        else:
            # degrees to radians
            rangle = np.deg2rad(angle)
            # compute the w, h of the rotated image
            nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
            nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
        # build the affine matrix
        rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
        # offset from the original image centre to the new image centre
        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
        # update the affine matrix with that offset
        rot_mat[0, 2] += rot_move[0]
        rot_mat[1, 2] += rot_move[1]
        # apply the affine warp
        rot_img = cv2.warpAffine(im, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)

        # ---------------------- fix up the bbox coordinates ----------------------
        # rot_mat is the final rotation matrix
        # map the four corner points of each original bbox into the rotated frame
        rot_text_polys = list()
        for bbox in text_polys:
            point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
            point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
            point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
            point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
            rot_text_polys.append([point1, point2, point3, point4])
        data['img'] = rot_img
data['text_polys'] = np.array(rot_text_polys) 130 | return data 131 | 132 | 133 | class RandomResize: 134 | def __init__(self, size, random_rate, keep_ratio=False): 135 | """ 136 | :param input_size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h] 137 | :param ramdon_rate: 随机系数 138 | :param keep_ratio: 是否保持长宽比 139 | :return: 140 | """ 141 | if isinstance(size, numbers.Number): 142 | if size < 0: 143 | raise ValueError("If input_size is a single number, it must be positive.") 144 | size = (size, size) 145 | elif isinstance(size, list) or isinstance(size, tuple) or isinstance(size, np.ndarray): 146 | if len(size) != 2: 147 | raise ValueError("If input_size is a sequence, it must be of len 2.") 148 | size = (size[0], size[1]) 149 | else: 150 | raise Exception('input_size must in Number or list or tuple or np.ndarray') 151 | self.size = size 152 | self.keep_ratio = keep_ratio 153 | self.random_rate = random_rate 154 | 155 | def __call__(self, data: dict) -> dict: 156 | """ 157 | 从scales中随机选择一个尺度,对图片和文本框进行缩放 158 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} 159 | :return: 160 | """ 161 | if random.random() > self.random_rate: 162 | return data 163 | im = data['img'] 164 | text_polys = data['text_polys'] 165 | 166 | if self.keep_ratio: 167 | # 将图片短边pad到和长边一样 168 | h, w, c = im.shape 169 | max_h = max(h, self.size[0]) 170 | max_w = max(w, self.size[1]) 171 | im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8) 172 | im_padded[:h, :w] = im.copy() 173 | im = im_padded 174 | text_polys = text_polys.astype(np.float32) 175 | h, w, _ = im.shape 176 | im = cv2.resize(im, self.size) 177 | w_scale = self.size[0] / float(w) 178 | h_scale = self.size[1] / float(h) 179 | text_polys[:, :, 0] *= w_scale 180 | text_polys[:, :, 1] *= h_scale 181 | 182 | data['img'] = im 183 | data['text_polys'] = text_polys 184 | return data 185 | 186 | 187 | def resize_image(img, short_size): 188 | height, width, _ = img.shape 189 | if height < width: 190 | new_height = short_size 191 | new_width = new_height / height * width 192 | else: 193 | new_width = short_size 194 | new_height = new_width / width * height 195 | new_height = int(round(new_height / 32) * 32) 196 | new_width = int(round(new_width / 32) * 32) 197 | resized_img = cv2.resize(img, (new_width, new_height)) 198 | return resized_img, (new_width / width, new_height / height) 199 | 200 | 201 | class ResizeShortSize: 202 | def __init__(self, short_size, resize_text_polys=True): 203 | """ 204 | :param size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h] 205 | :return: 206 | """ 207 | self.short_size = short_size 208 | self.resize_text_polys = resize_text_polys 209 | 210 | def __call__(self, data: dict) -> dict: 211 | """ 212 | 对图片和文本框进行缩放 213 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} 214 | :return: 215 | """ 216 | im = data['img'] 217 | text_polys = data['text_polys'] 218 | 219 | h, w, _ = im.shape 220 | short_edge = min(h, w) 221 | if short_edge < self.short_size: 222 | # 保证短边 >= short_size 223 | scale = self.short_size / short_edge 224 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale) 225 | scale = (scale, scale) 226 | # im, scale = resize_image(im, self.short_size) 227 | if self.resize_text_polys: 228 | # text_polys *= scale 229 | text_polys[:, 0] *= scale[0] 230 | text_polys[:, 1] *= scale[1] 231 | 232 | data['img'] = im 233 | data['text_polys'] = text_polys 234 | return data 235 | 236 | 237 | class HorizontalFlip: 238 | def __init__(self, random_rate): 239 | """ 240 | 241 | :param random_rate: 随机系数 242 | """ 243 | self.random_rate = 
random_rate 244 | 245 | def __call__(self, data: dict) -> dict: 246 | """ 247 | 从scales中随机选择一个尺度,对图片和文本框进行缩放 248 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} 249 | :return: 250 | """ 251 | if random.random() > self.random_rate: 252 | return data 253 | im = data['img'] 254 | text_polys = data['text_polys'] 255 | 256 | flip_text_polys = text_polys.copy() 257 | flip_im = cv2.flip(im, 1) 258 | h, w, _ = flip_im.shape 259 | flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0] 260 | 261 | data['img'] = flip_im 262 | data['text_polys'] = flip_text_polys 263 | return data 264 | 265 | 266 | class VerticallFlip: 267 | def __init__(self, random_rate): 268 | """ 269 | 270 | :param random_rate: 随机系数 271 | """ 272 | self.random_rate = random_rate 273 | 274 | def __call__(self, data: dict) -> dict: 275 | """ 276 | 从scales中随机选择一个尺度,对图片和文本框进行缩放 277 | :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} 278 | :return: 279 | """ 280 | if random.random() > self.random_rate: 281 | return data 282 | im = data['img'] 283 | text_polys = data['text_polys'] 284 | 285 | flip_text_polys = text_polys.copy() 286 | flip_im = cv2.flip(im, 0) 287 | h, w, _ = flip_im.shape 288 | flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1] 289 | data['img'] = flip_im 290 | data['text_polys'] = flip_text_polys 291 | return data 292 | -------------------------------------------------------------------------------- /data_loader/modules/iaa_augment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/12/4 18:06 3 | # @Author : zhoujun 4 | import numpy as np 5 | import imgaug 6 | import imgaug.augmenters as iaa 7 | 8 | class AugmenterBuilder(object): 9 | def __init__(self): 10 | pass 11 | 12 | def build(self, args, root=True): 13 | if args is None or len(args) == 0: 14 | return None 15 | elif isinstance(args, list): 16 | if root: 17 | sequence = [self.build(value, root=False) for value in args] 18 | return iaa.Sequential(sequence) 19 | else: 20 | return getattr(iaa, args[0])(*[self.to_tuple_if_list(a) for a in args[1:]]) 21 | elif isinstance(args, dict): 22 | cls = getattr(iaa, args['type']) 23 | return cls(**{k: self.to_tuple_if_list(v) for k, v in args['args'].items()}) 24 | else: 25 | raise RuntimeError('unknown augmenter arg: ' + str(args)) 26 | 27 | def to_tuple_if_list(self, obj): 28 | if isinstance(obj, list): 29 | return tuple(obj) 30 | return obj 31 | 32 | 33 | class IaaAugment(): 34 | def __init__(self, augmenter_args): 35 | self.augmenter_args = augmenter_args 36 | self.augmenter = AugmenterBuilder().build(self.augmenter_args) 37 | 38 | def __call__(self, data): 39 | image = data['img'] 40 | shape = image.shape 41 | 42 | if self.augmenter: 43 | aug = self.augmenter.to_deterministic() 44 | data['img'] = aug.augment_image(image) 45 | data = self.may_augment_annotation(aug, data, shape) 46 | return data 47 | 48 | def may_augment_annotation(self, aug, data, shape): 49 | if aug is None: 50 | return data 51 | 52 | line_polys = [] 53 | for poly in data['text_polys']: 54 | new_poly = self.may_augment_poly(aug, shape, poly) 55 | line_polys.append(new_poly) 56 | data['text_polys'] = np.array(line_polys) 57 | return data 58 | 59 | def may_augment_poly(self, aug, img_shape, poly): 60 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 61 | keypoints = aug.augment_keypoints( 62 | [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints 63 | poly = [(p.x, p.y) for p in keypoints] 64 | return poly 65 | 
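

# Illustrative usage (an editor-added sketch, not part of the original file): the
# `IaaAugment` entry under `pre_processes` in the yaml configs passes a list of
# dicts like the one below, which AugmenterBuilder turns into an iaa.Sequential.
if __name__ == '__main__':
    augmenter_args = [
        {'type': 'Fliplr', 'args': {'p': 0.5}},
        {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
        {'type': 'Resize', 'args': {'size': [0.5, 3]}},
    ]
    aug = IaaAugment(augmenter_args)
    sample = {
        'img': np.zeros((100, 200, 3), dtype=np.uint8),
        'text_polys': np.array([[[10., 10.], [50., 10.], [50., 30.], [10., 30.]]]),
    }
    out = aug(sample)  # img and text_polys are transformed consistently
    print(out['img'].shape, out['text_polys'].shape)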
-------------------------------------------------------------------------------- /data_loader/modules/make_border_map.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | np.seterr(divide='ignore', invalid='ignore')
4 | import pyclipper
5 | from shapely.geometry import Polygon
6 | 
7 | 
8 | class MakeBorderMap():
9 |     def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
10 |         self.shrink_ratio = shrink_ratio
11 |         self.thresh_min = thresh_min
12 |         self.thresh_max = thresh_max
13 | 
14 |     def __call__(self, data: dict) -> dict:
15 |         """
16 |         Compute the threshold (border) map and its mask for the text polygons.
17 |         :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
18 |         :return:
19 |         """
20 |         im = data['img']
21 |         text_polys = data['text_polys']
22 |         ignore_tags = data['ignore_tags']
23 | 
24 |         canvas = np.zeros(im.shape[:2], dtype=np.float32)
25 |         mask = np.zeros(im.shape[:2], dtype=np.float32)
26 | 
27 |         for i in range(len(text_polys)):
28 |             if ignore_tags[i]:
29 |                 continue
30 |             self.draw_border_map(text_polys[i], canvas, mask=mask)
31 |         canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
32 | 
33 |         data['threshold_map'] = canvas
34 |         data['threshold_mask'] = mask
35 |         return data
36 | 
37 |     def draw_border_map(self, polygon, canvas, mask):
38 |         polygon = np.array(polygon)
39 |         assert polygon.ndim == 2
40 |         assert polygon.shape[1] == 2
41 | 
42 |         polygon_shape = Polygon(polygon)
43 |         if polygon_shape.area <= 0:
44 |             return
45 |         distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
46 |         subject = [tuple(l) for l in polygon]
47 |         padding = pyclipper.PyclipperOffset()
48 |         padding.AddPath(subject, pyclipper.JT_ROUND,
49 |                         pyclipper.ET_CLOSEDPOLYGON)
50 | 
51 |         padded_polygon = np.array(padding.Execute(distance)[0])
52 |         cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
53 | 
54 |         xmin = padded_polygon[:, 0].min()
55 |         xmax = padded_polygon[:, 0].max()
56 |         ymin = padded_polygon[:, 1].min()
57 |         ymax = padded_polygon[:, 1].max()
58 |         width = xmax - xmin + 1
59 |         height = ymax - ymin + 1
60 | 
61 |         polygon[:, 0] = polygon[:, 0] - xmin
62 |         polygon[:, 1] = polygon[:, 1] - ymin
63 | 
64 |         xs = np.broadcast_to(
65 |             np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
66 |         ys = np.broadcast_to(
67 |             np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
68 | 
69 |         distance_map = np.zeros(
70 |             (polygon.shape[0], height, width), dtype=np.float32)
71 |         for i in range(polygon.shape[0]):
72 |             j = (i + 1) % polygon.shape[0]
73 |             absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
74 |             distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
75 |         distance_map = distance_map.min(axis=0)
76 | 
77 |         xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
78 |         xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
79 |         ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
80 |         ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
81 |         canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
82 |             1 - distance_map[
83 |                 ymin_valid - ymin:ymax_valid - ymax + height,
84 |                 xmin_valid - xmin:xmax_valid - xmax + width],
85 |             canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
86 | 
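    # Geometry behind the distance() helper below: per grid point P, with
    # a = |P - p1|, b = |P - p2| and c = |p1 - p2|, the law of cosines gives
    # the angle at P, and the distance of P to the line is the triangle
    # height over the base c:
    #     dist = a * b * sin(P) / c
    # The `cosine < 0` branch (angle at P acute, i.e. c^2 < a^2 + b^2) falls
    # back to the nearer endpoint distance min(a, b), approximating the
    # point-to-segment distance beyond the segment's ends.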
87 |     def distance(self, xs, ys, point_1, point_2):
88 |         '''
89 |         compute the distance from each grid point to the segment (point_1, point_2)
90 |         ys: coordinates in the first axis
91 |         xs: coordinates in the second axis
92 |         point_1, point_2: (x, y), the endpoints of the segment
93 |         '''
94 |         height, width = xs.shape[:2]
95 |         square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
96 |         square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
97 |         square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])
98 | 
99 |         cosine = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
100 |         square_sin = 1 - np.square(cosine)
101 |         square_sin = np.nan_to_num(square_sin)
102 | 
103 |         result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
104 |         result[cosine < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosine < 0]
105 |         # self.extend_line(point_1, point_2, result)
106 |         return result
107 | 
108 |     def extend_line(self, point_1, point_2, result):
109 |         ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
110 |                       int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
111 |         cv2.line(result, tuple(ex_point_1), tuple(point_1), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
112 |         ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
113 |                       int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
114 |         cv2.line(result, tuple(ex_point_2), tuple(point_2), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
115 |         return ex_point_1, ex_point_2
116 | 
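# A quick numeric check of the offset distance shared by MakeBorderMap above
# (expand by +D) and the shrink step in make_shrink_map.py (shrink by -D):
#     D = A * (1 - r^2) / L,   with r = shrink_ratio
if __name__ == '__main__':
    _square = Polygon([(0, 0), (100, 0), (100, 100), (0, 100)])
    print(_square.area * (1 - 0.4 ** 2) / _square.length)  # 10000 * 0.84 / 400 = 21.0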
-------------------------------------------------------------------------------- /data_loader/modules/make_shrink_map.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | 
4 | 
5 | def shrink_polygon_py(polygon, shrink_ratio):
6 |     """
7 |     Shrink the polygon towards its center; calling it again with 1/shrink_ratio roughly undoes the shrink.
8 |     """
9 |     cx = polygon[:, 0].mean()
10 |     cy = polygon[:, 1].mean()
11 |     polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
12 |     polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
13 |     return polygon
14 | 
15 | 
16 | def shrink_polygon_pyclipper(polygon, shrink_ratio):
17 |     from shapely.geometry import Polygon
18 |     import pyclipper
19 |     polygon_shape = Polygon(polygon)
20 |     distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
21 |     subject = [tuple(l) for l in polygon]
22 |     padding = pyclipper.PyclipperOffset()
23 |     padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
24 |     shrinked = padding.Execute(-distance)
25 |     if shrinked == []:
26 |         shrinked = np.array(shrinked)
27 |     else:
28 |         shrinked = np.array(shrinked[0]).reshape(-1, 2)
29 |     return shrinked
30 | 
31 | 
32 | class MakeShrinkMap():
33 |     r'''
34 |     Making binary mask from detection data with ICDAR format.
35 |     Typically following the process of class `MakeICDARData`.
36 |     '''
37 | 
38 |     def __init__(self, min_text_size=8, shrink_ratio=0.4, shrink_type='pyclipper'):
39 |         shrink_func_dict = {'py': shrink_polygon_py, 'pyclipper': shrink_polygon_pyclipper}
40 |         self.shrink_func = shrink_func_dict[shrink_type]
41 |         self.min_text_size = min_text_size
42 |         self.shrink_ratio = shrink_ratio
43 | 
44 |     def __call__(self, data: dict) -> dict:
45 |         """
46 |         Compute the shrink map (probability-map ground truth) and its mask.
47 |         :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
48 |         :return:
49 |         """
50 |         image = data['img']
51 |         text_polys = data['text_polys']
52 |         ignore_tags = data['ignore_tags']
53 | 
54 |         h, w = image.shape[:2]
55 |         text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
56 |         gt = np.zeros((h, w), dtype=np.float32)
57 |         mask = np.ones((h, w), dtype=np.float32)
58 |         for i in range(len(text_polys)):
59 |             polygon = text_polys[i]
60 |             height = max(polygon[:, 1]) - min(polygon[:, 1])
61 |             width = max(polygon[:, 0]) - min(polygon[:, 0])
62 |             if ignore_tags[i] or min(height, width) < self.min_text_size:
63 |                 cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
64 |                 ignore_tags[i] = True
65 |             else:
66 |                 shrinked = self.shrink_func(polygon, self.shrink_ratio)
67 |                 if shrinked.size == 0:
68 |                     cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
69 |                     ignore_tags[i] = True
70 |                     continue
71 |                 cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
72 | 
73 |         data['shrink_map'] = gt
74 |         data['shrink_mask'] = mask
75 |         return data
76 | 
77 |     def validate_polygons(self, polygons, ignore_tags, h, w):
78 |         '''
79 |         polygons (numpy.array, required): of shape (num_instances, num_points, 2)
80 |         '''
81 |         if len(polygons) == 0:
82 |             return polygons, ignore_tags
83 |         assert len(polygons) == len(ignore_tags)
84 |         for polygon in polygons:
85 |             polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
86 |             polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
87 | 
88 |         for i in range(len(polygons)):
89 |             area = self.polygon_area(polygons[i])
90 |             if abs(area) < 1:
91 |                 ignore_tags[i] = True
92 |             if area > 0:
93 |                 polygons[i] = polygons[i][::-1, :]
94 |         return polygons, ignore_tags
95 | 
96 |     def polygon_area(self, polygon):
97 |         return cv2.contourArea(polygon, oriented=True)  # signed area: without oriented=True the `area > 0` check above would reverse every polygon
98 |         # edge = 0
99 |         # for i in range(polygon.shape[0]):
100 |         #     next_index = (i + 1) % polygon.shape[0]
101 |         #     edge += (polygon[next_index, 0] - polygon[i, 0]) * (polygon[next_index, 1] - polygon[i, 1])
102 |         #
103 |         # return edge / 2.
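# A minimal usage sketch for MakeShrinkMap (inputs assumed for illustration;
# the real pipeline feeds it from the dataset transform chain):
#
#   data = {'img': np.zeros((160, 160, 3), np.uint8),
#           'text_polys': np.array([[[20, 20], [120, 20], [120, 60], [20, 60]]], np.float32),
#           'ignore_tags': [False]}
#   data = MakeShrinkMap(min_text_size=8, shrink_ratio=0.4)(data)
#   data['shrink_map']   # (160, 160) float32, 1.0 inside the shrunk polygon
#   data['shrink_mask']  # (160, 160) float32, all ones here: nothing ignored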
104 | 
105 | 
106 | if __name__ == '__main__':
107 |     from shapely.geometry import Polygon
108 |     import pyclipper
109 | 
110 |     polygon = np.array([[0, 0], [100, 10], [100, 100], [10, 90]])
111 |     a = shrink_polygon_py(polygon, 0.4)
112 |     print(a)
113 |     print(shrink_polygon_py(a, 1 / 0.4))
114 |     b = shrink_polygon_pyclipper(polygon, 0.4)
115 |     print(b)
116 |     poly = Polygon(b)
117 |     distance = poly.area * 1.5 / poly.length
118 |     offset = pyclipper.PyclipperOffset()
119 |     offset.AddPath(b, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
120 |     expanded = np.array(offset.Execute(distance)[0], dtype=np.int32)  # int32 so cv2.minAreaRect accepts it
121 |     bounding_box = cv2.minAreaRect(expanded)
122 |     points = cv2.boxPoints(bounding_box)
123 |     print(points)
124 | 
-------------------------------------------------------------------------------- /data_loader/modules/random_crop_data.py: --------------------------------------------------------------------------------
1 | import random
2 | 
3 | import cv2
4 | import numpy as np
5 | 
6 | 
7 | # random crop algorithm similar to https://github.com/argman/EAST
8 | class EastRandomCropData():
9 |     def __init__(self, size=(640, 640), max_tries=50, min_crop_side_ratio=0.1, require_original_image=False, keep_ratio=True):
10 |         self.size = size
11 |         self.max_tries = max_tries
12 |         self.min_crop_side_ratio = min_crop_side_ratio
13 |         self.require_original_image = require_original_image
14 |         self.keep_ratio = keep_ratio
15 | 
16 |     def __call__(self, data: dict) -> dict:
17 |         """
18 |         Randomly crop a region that keeps text instances and resize it to self.size.
19 |         :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
20 |         :return:
21 |         """
22 |         im = data['img']
23 |         text_polys = data['text_polys']
24 |         ignore_tags = data['ignore_tags']
25 |         texts = data['texts']
26 |         all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
27 |         # compute the crop region
28 |         crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
29 |         # crop the image, padding it to keep the aspect ratio
30 |         scale_w = self.size[0] / crop_w
31 |         scale_h = self.size[1] / crop_h
32 |         scale = min(scale_w, scale_h)
33 |         h = int(crop_h * scale)
34 |         w = int(crop_w * scale)
35 |         if self.keep_ratio:
36 |             if len(im.shape) == 3:
37 |                 padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
38 |             else:
39 |                 padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
40 |             padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
41 |             img = padimg
42 |         else:
43 |             img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], tuple(self.size))
44 |         # crop the text polygons
45 |         text_polys_crop = []
46 |         ignore_tags_crop = []
47 |         texts_crop = []
48 |         for poly, text, tag in zip(text_polys, texts, ignore_tags):
49 |             poly = ((poly - (crop_x, crop_y)) * scale).tolist()
50 |             if not self.is_poly_outside_rect(poly, 0, 0, w, h):
51 |                 text_polys_crop.append(poly)
52 |                 ignore_tags_crop.append(tag)
53 |                 texts_crop.append(text)
54 |         data['img'] = img
55 |         data['text_polys'] = np.float32(text_polys_crop)
56 |         data['ignore_tags'] = ignore_tags_crop
57 |         data['texts'] = texts_crop
58 |         return data
59 | 
60 |     def is_poly_in_rect(self, poly, x, y, w, h):
61 |         poly = np.array(poly)
62 |         if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
63 |             return False
64 |         if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
65 |             return False
66 |         return True
67 | 
68 |     def is_poly_outside_rect(self, poly, x, y, w, h):
69 |         poly = np.array(poly)
70 |         if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
71 |             return True
72 |         if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
73 |             return True
74 |         return False
75 | 
76 |     def split_regions(self, axis):
77 |         regions = []
78 |         min_axis = 0
79 |         for i in range(1, axis.shape[0]):
80 |             if axis[i] != axis[i - 1] + 1:
81 |                 region = axis[min_axis:i]
82 |                 min_axis = i
83 |                 regions.append(region)
84 |         return regions
85 | 
86 |     def random_select(self, axis, max_size):
87 |         xx = np.random.choice(axis, size=2)
88 |         xmin = np.min(xx)
89 |         xmax = np.max(xx)
90 |         xmin = np.clip(xmin, 0, max_size - 1)
91 |         xmax = np.clip(xmax, 0, max_size - 1)
92 |         return xmin, xmax
93 | 
94 |     def region_wise_random_select(self, regions, max_size):
95 |         selected_index = list(np.random.choice(len(regions), 2))
96 |         selected_values = []
97 |         for index in selected_index:
98 |             axis = regions[index]
99 |             xx = int(np.random.choice(axis, size=1))
100 |             selected_values.append(xx)
101 |         xmin = min(selected_values)
102 |         xmax = max(selected_values)
103 |         return xmin, xmax
104 | 
105 |     def crop_area(self, im, text_polys):
106 |         h, w = im.shape[:2]
107 |         h_array = np.zeros(h, dtype=np.int32)
108 |         w_array = np.zeros(w, dtype=np.int32)
109 |         for points in text_polys:
110 |             points = np.round(points, decimals=0).astype(np.int32)
111 |             minx = np.min(points[:, 0])
112 |             maxx = np.max(points[:, 0])
113 |             w_array[minx:maxx] = 1
114 |             miny = np.min(points[:, 1])
115 |             maxy = np.max(points[:, 1])
116 |             h_array[miny:maxy] = 1
117 |         # make sure the cropped area does not cut across a text region
118 |         h_axis = np.where(h_array == 0)[0]
119 |         w_axis = np.where(w_array == 0)[0]
120 | 
121 |         if len(h_axis) == 0 or len(w_axis) == 0:
122 |             return 0, 0, w, h
123 | 
124 |         h_regions = self.split_regions(h_axis)
125 |         w_regions = self.split_regions(w_axis)
126 | 
127 |         for i in range(self.max_tries):
128 |             if len(w_regions) > 1:
129 |                 xmin, xmax = self.region_wise_random_select(w_regions, w)
130 |             else:
131 |                 xmin, xmax = self.random_select(w_axis, w)
132 |             if len(h_regions) > 1:
133 |                 ymin, ymax = self.region_wise_random_select(h_regions, h)
134 |             else:
135 |                 ymin, ymax = self.random_select(h_axis, h)
136 | 
137 |             if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
138 |                 # area too small
139 |                 continue
140 |             num_poly_in_rect = 0
141 |             for poly in text_polys:
142 |                 if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
143 |                     num_poly_in_rect += 1
144 |                     break
145 | 
146 |             if num_poly_in_rect > 0:
147 |                 return xmin, ymin, xmax - xmin, ymax - ymin
148 | 
149 |         return 0, 0, w, h
150 | 
151 | 
152 | class PSERandomCrop():
153 |     def __init__(self, size):
154 |         self.size = size
155 | 
156 |     def __call__(self, data):
157 |         imgs = data['imgs']
158 | 
159 |         h, w = imgs[0].shape[0:2]
160 |         th, tw = self.size
161 |         if w == tw and h == th:
162 |             return imgs
163 | 
164 |         # if the label contains text instances, crop near them with some probability, steered by the threshold_label_map
165 |         if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
166 |             # top-left corner of the text instances
167 |             tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
168 |             tl[tl < 0] = 0
169 |             # bottom-right corner of the text instances
170 |             br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
171 |             br[br < 0] = 0
172 |             # make sure there is enough room left to crop when the bottom-right point is picked
173 |             br[0] = min(br[0], h - th)
174 |             br[1] = min(br[1], w - tw)
175 | 
176 |             for _ in range(50000):
177 |                 i = random.randint(tl[0], br[0])
178 |                 j = random.randint(tl[1], br[1])
179 |                 # make sure the shrink_label_map contains text
180 |                 if imgs[1][i:i + th, j:j + tw].sum() <= 0:
181 |                     continue
182 |                 else:
183 |                     break
184 |         else:
185 |             i = random.randint(0, h - th)
186 |             j = random.randint(0, w - tw)
187 | 
188 |         # return i, j, th, tw
189 |         for idx in range(len(imgs)):
190 |             if len(imgs[idx].shape) == 3:
191 |                 imgs[idx] = 
imgs[idx][i:i + th, j:j + tw, :] 192 | else: 193 | imgs[idx] = imgs[idx][i:i + th, j:j + tw] 194 | data['imgs'] = imgs 195 | return data 196 | -------------------------------------------------------------------------------- /datasets/test.txt: -------------------------------------------------------------------------------- 1 | ./datasets/test/img/001.jpg ./datasets/test/gt/001.txt 2 | ./datasets/test/img/002.jpg ./datasets/test/gt/002.txt 3 | -------------------------------------------------------------------------------- /datasets/test/gt/README.MD: -------------------------------------------------------------------------------- 1 | Place the `.txt` ground truth here, with format of `x1, y1, x2, y2, x3, y3, x4, y4, annotation` 2 | -------------------------------------------------------------------------------- /datasets/test/img/README.MD: -------------------------------------------------------------------------------- 1 | Place the `.jpg` files here. 2 | -------------------------------------------------------------------------------- /datasets/train.txt: -------------------------------------------------------------------------------- 1 | ./datasets/train/img/001.jpg ./datasets/train/gt/001.txt 2 | ./datasets/train/img/002.jpg ./datasets/train/gt/002.txt 3 | -------------------------------------------------------------------------------- /datasets/train/gt/README.MD: -------------------------------------------------------------------------------- 1 | Place the `.txt` ground truth here, with format of `x1, y1, x2, y2, x3, y3, x4, y4, annotation` 2 | -------------------------------------------------------------------------------- /datasets/train/img/README.MD: -------------------------------------------------------------------------------- 1 | Place the `.jpg` files here. 
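For reference, a minimal sketch of reading one image/ground-truth pair from these lists (the `###` ignore convention below follows ICDAR2015; adjust it to whatever marker your annotations use):

```python
import numpy as np

def load_pair(line):
    img_path, gt_path = line.rstrip('\n').split('\t')
    polys, texts, ignore_tags = [], [], []
    with open(gt_path, encoding='utf-8') as f:
        for gt_line in f:
            parts = gt_line.rstrip('\n').split(',')
            polys.append(np.array(parts[:8], dtype=np.float32).reshape(4, 2))
            text = ','.join(parts[8:])  # the annotation itself may contain commas
            texts.append(text)
            ignore_tags.append(text == '###')
    return img_path, np.array(polys, dtype=np.float32), texts, ignore_tags
```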
2 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dbnet 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - pytorch==1.4 8 | - torchvision==0.5 9 | - anyconfig==0.9.10 10 | - future==0.18.2 11 | - imgaug==0.4.0 12 | - matplotlib==3.1.2 13 | - numpy==1.17.4 14 | - opencv 15 | - pyclipper 16 | - PyYAML==5.2 17 | - scikit-image==0.16.2 18 | - Shapely==1.6.4 19 | - tensorboard=2 20 | - tqdm==4.40.1 21 | - ipython 22 | - pip 23 | - pytorch 24 | - torchvision 25 | - pip: 26 | - polygon3 27 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py --model_path '' -------------------------------------------------------------------------------- /generate_lists.sh: -------------------------------------------------------------------------------- 1 | #Only use if your file names of the images and txts are identical 2 | rm ./datasets/train_img.txt 3 | rm ./datasets/train_gt.txt 4 | rm ./datasets/test_img.txt 5 | rm ./datasets/test_gt.txt 6 | rm ./datasets/train.txt 7 | rm ./datasets/test.txt 8 | ls ./datasets/train/img/*.jpg > ./datasets/train_img.txt 9 | ls ./datasets/train/gt/*.txt > ./datasets/train_gt.txt 10 | ls ./datasets/test/img/*.jpg > ./datasets/test_img.txt 11 | ls ./datasets/test/gt/*.txt > ./datasets/test_gt.txt 12 | paste ./datasets/train_img.txt ./datasets/train_gt.txt > ./datasets/train.txt 13 | paste ./datasets/test_img.txt ./datasets/test_gt.txt > ./datasets/test.txt 14 | rm ./datasets/train_img.txt 15 | rm ./datasets/train_gt.txt 16 | rm ./datasets/test_img.txt 17 | rm ./datasets/test_gt.txt 18 | -------------------------------------------------------------------------------- /imgs/paper/db.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/DBNet.pytorch/e03acf0e6b3b62f7d1dc7e10a6d2587456ac9ea1/imgs/paper/db.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:55 3 | # @Author : zhoujun 4 | import copy 5 | from .model import Model 6 | from .losses import build_loss 7 | 8 | __all__ = ['build_loss', 'build_model'] 9 | support_model = ['Model'] 10 | 11 | 12 | def build_model(config): 13 | """ 14 | get architecture model class 15 | """ 16 | copy_config = copy.deepcopy(config) 17 | arch_type = copy_config.pop('type') 18 | assert arch_type in support_model, f'{arch_type} is not developed yet!, only {support_model} are support now' 19 | arch_model = eval(arch_type)(copy_config) 20 | return arch_model 21 | -------------------------------------------------------------------------------- /models/backbone/MobilenetV3.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class HSwish(nn.Module): 9 | def forward(self, x): 10 | out = x * F.relu6(x + 3, inplace=True) / 6 11 | return out 12 | 13 | 14 | class HardSigmoid(nn.Module): 15 | def __init__(self, slope=.2, offset=.5): 16 
| super().__init__() 17 | self.slope = slope 18 | self.offset = offset 19 | 20 | def forward(self, x): 21 | x = (self.slope * x) + self.offset 22 | x = F.threshold(-x, -1, -1) 23 | x = F.threshold(-x, 0, 0) 24 | return x 25 | 26 | 27 | class ConvBNACT(nn.Module): 28 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None): 29 | super().__init__() 30 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 31 | stride=stride, padding=padding, groups=groups, 32 | bias=False) 33 | self.bn = nn.BatchNorm2d(out_channels) 34 | if act == 'relu': 35 | self.act = nn.ReLU() 36 | elif act == 'hard_swish': 37 | self.act = HSwish() 38 | elif act is None: 39 | self.act = None 40 | 41 | def forward(self, x): 42 | x = self.conv(x) 43 | x = self.bn(x) 44 | if self.act is not None: 45 | x = self.act(x) 46 | return x 47 | 48 | 49 | class SEBlock(nn.Module): 50 | def __init__(self, in_channels, out_channels, ratio=4): 51 | super().__init__() 52 | num_mid_filter = out_channels // ratio 53 | self.pool = nn.AdaptiveAvgPool2d(1) 54 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_mid_filter, kernel_size=1, bias=True) 55 | self.relu1 = nn.ReLU() 56 | self.conv2 = nn.Conv2d(in_channels=num_mid_filter, kernel_size=1, out_channels=out_channels, bias=True) 57 | self.relu2 = HardSigmoid() 58 | 59 | def forward(self, x): 60 | attn = self.pool(x) 61 | attn = self.conv1(attn) 62 | attn = self.relu1(attn) 63 | attn = self.conv2(attn) 64 | attn = self.relu2(attn) 65 | return x * attn 66 | 67 | 68 | class ResidualUnit(nn.Module): 69 | def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False): 70 | super().__init__() 71 | self.conv0 = ConvBNACT(in_channels=num_in_filter, out_channels=num_mid_filter, kernel_size=1, stride=1, 72 | padding=0, act=act) 73 | 74 | self.conv1 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter, kernel_size=kernel_size, 75 | stride=stride, 76 | padding=int((kernel_size - 1) // 2), act=act, groups=num_mid_filter) 77 | if use_se: 78 | self.se = SEBlock(in_channels=num_mid_filter, out_channels=num_mid_filter) 79 | else: 80 | self.se = None 81 | 82 | self.conv2 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter, kernel_size=1, stride=1, 83 | padding=0) 84 | self.not_add = num_in_filter != num_out_filter or stride != 1 85 | 86 | def forward(self, x): 87 | y = self.conv0(x) 88 | y = self.conv1(y) 89 | if self.se is not None: 90 | y = self.se(y) 91 | y = self.conv2(y) 92 | if not self.not_add: 93 | y = x + y 94 | return y 95 | 96 | 97 | class MobileNetV3(nn.Module): 98 | def __init__(self, in_channels=3, **kwargs): 99 | """ 100 | the MobilenetV3 backbone network for detection module. 
101 | Args: 102 | params(dict): the super parameters for build network 103 | """ 104 | super().__init__() 105 | self.scale = kwargs.get('scale', 0.5) 106 | model_name = kwargs.get('model_name', 'large') 107 | self.inplanes = 16 108 | if model_name == "large": 109 | self.cfg = [ 110 | # k, exp, c, se, nl, s, 111 | [3, 16, 16, False, 'relu', 1], 112 | [3, 64, 24, False, 'relu', 2], 113 | [3, 72, 24, False, 'relu', 1], 114 | [5, 72, 40, True, 'relu', 2], 115 | [5, 120, 40, True, 'relu', 1], 116 | [5, 120, 40, True, 'relu', 1], 117 | [3, 240, 80, False, 'hard_swish', 2], 118 | [3, 200, 80, False, 'hard_swish', 1], 119 | [3, 184, 80, False, 'hard_swish', 1], 120 | [3, 184, 80, False, 'hard_swish', 1], 121 | [3, 480, 112, True, 'hard_swish', 1], 122 | [3, 672, 112, True, 'hard_swish', 1], 123 | [5, 672, 160, True, 'hard_swish', 2], 124 | [5, 960, 160, True, 'hard_swish', 1], 125 | [5, 960, 160, True, 'hard_swish', 1], 126 | ] 127 | self.cls_ch_squeeze = 960 128 | self.cls_ch_expand = 1280 129 | elif model_name == "small": 130 | self.cfg = [ 131 | # k, exp, c, se, nl, s, 132 | [3, 16, 16, True, 'relu', 2], 133 | [3, 72, 24, False, 'relu', 2], 134 | [3, 88, 24, False, 'relu', 1], 135 | [5, 96, 40, True, 'hard_swish', 2], 136 | [5, 240, 40, True, 'hard_swish', 1], 137 | [5, 240, 40, True, 'hard_swish', 1], 138 | [5, 120, 48, True, 'hard_swish', 1], 139 | [5, 144, 48, True, 'hard_swish', 1], 140 | [5, 288, 96, True, 'hard_swish', 2], 141 | [5, 576, 96, True, 'hard_swish', 1], 142 | [5, 576, 96, True, 'hard_swish', 1], 143 | ] 144 | self.cls_ch_squeeze = 576 145 | self.cls_ch_expand = 1280 146 | else: 147 | raise NotImplementedError("mode[" + model_name + 148 | "_model] is not implemented!") 149 | 150 | supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25] 151 | assert self.scale in supported_scale, \ 152 | "supported scale are {} but input scale is {}".format(supported_scale, self.scale) 153 | 154 | scale = self.scale 155 | inplanes = self.inplanes 156 | cfg = self.cfg 157 | cls_ch_squeeze = self.cls_ch_squeeze 158 | # conv1 159 | self.conv1 = ConvBNACT(in_channels=in_channels, 160 | out_channels=self.make_divisible(inplanes * scale), 161 | kernel_size=3, 162 | stride=2, 163 | padding=1, 164 | groups=1, 165 | act='hard_swish') 166 | i = 0 167 | inplanes = self.make_divisible(inplanes * scale) 168 | self.stages = nn.ModuleList() 169 | block_list = [] 170 | self.out_channels = [] 171 | for layer_cfg in cfg: 172 | if layer_cfg[5] == 2 and i > 2: 173 | self.out_channels.append(inplanes) 174 | self.stages.append(nn.Sequential(*block_list)) 175 | block_list = [] 176 | block = ResidualUnit(num_in_filter=inplanes, 177 | num_mid_filter=self.make_divisible(scale * layer_cfg[1]), 178 | num_out_filter=self.make_divisible(scale * layer_cfg[2]), 179 | act=layer_cfg[4], 180 | stride=layer_cfg[5], 181 | kernel_size=layer_cfg[0], 182 | use_se=layer_cfg[3]) 183 | block_list.append(block) 184 | inplanes = self.make_divisible(scale * layer_cfg[2]) 185 | i += 1 186 | self.stages.append(nn.Sequential(*block_list)) 187 | self.conv2 = ConvBNACT( 188 | in_channels=inplanes, 189 | out_channels=self.make_divisible(scale * cls_ch_squeeze), 190 | kernel_size=1, 191 | stride=1, 192 | padding=0, 193 | groups=1, 194 | act='hard_swish') 195 | self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze)) 196 | 197 | def make_divisible(self, v, divisor=8, min_value=None): 198 | if min_value is None: 199 | min_value = divisor 200 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 201 | if new_v < 0.9 * v: 202 | new_v 
+= divisor 203 | return new_v 204 | 205 | def forward(self, x): 206 | x = self.conv1(x) 207 | out = [] 208 | for stage in self.stages: 209 | x = stage(x) 210 | out.append(x) 211 | out[-1] = self.conv2(out[-1]) 212 | return out 213 | -------------------------------------------------------------------------------- /models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:54 3 | # @Author : zhoujun 4 | 5 | from .resnet import * 6 | from .resnest import * 7 | from .shufflenetv2 import * 8 | from .MobilenetV3 import MobileNetV3 9 | 10 | __all__ = ['build_backbone'] 11 | 12 | support_backbone = ['resnet18', 'deformable_resnet18', 'deformable_resnet50', 13 | 'resnet50', 'resnet34', 'resnet101', 'resnet152', 14 | 'resnest50', 'resnest101', 'resnest200', 'resnest269', 15 | 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 16 | 'MobileNetV3'] 17 | 18 | 19 | def build_backbone(backbone_name, **kwargs): 20 | assert backbone_name in support_backbone, f'all support backbone is {support_backbone}' 21 | backbone = eval(backbone_name)(**kwargs) 22 | return backbone 23 | -------------------------------------------------------------------------------- /models/backbone/resnest/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * -------------------------------------------------------------------------------- /models/backbone/resnest/ablation.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt ablation study models""" 9 | 10 | import torch 11 | from .resnet import ResNet, Bottleneck 12 | 13 | __all__ = ['resnest50_fast_1s1x64d', 'resnest50_fast_2s1x64d', 'resnest50_fast_4s1x64d', 14 | 'resnest50_fast_1s2x40d', 'resnest50_fast_2s2x40d', 'resnest50_fast_4s2x40d', 15 | 'resnest50_fast_1s4x24d'] 16 | 17 | _url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth' 18 | 19 | _model_sha256 = {name: checksum for checksum, name in [ 20 | ('d8fbf808', 'resnest50_fast_1s1x64d'), 21 | ('44938639', 'resnest50_fast_2s1x64d'), 22 | ('f74f3fc3', 'resnest50_fast_4s1x64d'), 23 | ('32830b84', 'resnest50_fast_1s2x40d'), 24 | ('9d126481', 'resnest50_fast_2s2x40d'), 25 | ('41d14ed0', 'resnest50_fast_4s2x40d'), 26 | ('d4a4f76f', 'resnest50_fast_1s4x24d'), 27 | ]} 28 | 29 | def short_hash(name): 30 | if name not in _model_sha256: 31 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 32 | return _model_sha256[name][:8] 33 | 34 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 35 | name in _model_sha256.keys() 36 | } 37 | 38 | def resnest50_fast_1s1x64d(pretrained=False, root='~/.encoding/models', **kwargs): 39 | model = ResNet(Bottleneck, [3, 4, 6, 3], 40 | radix=1, groups=1, bottleneck_width=64, 41 | deep_stem=True, stem_width=32, avg_down=True, 42 | avd=True, avd_first=True, **kwargs) 43 | if pretrained: 44 | model.load_state_dict(torch.hub.load_state_dict_from_url( 45 | resnest_model_urls['resnest50_fast_1s1x64d'], progress=True, check_hash=True)) 46 | return model 47 | 48 | def 
resnest50_fast_2s1x64d(pretrained=False, root='~/.encoding/models', **kwargs): 49 | model = ResNet(Bottleneck, [3, 4, 6, 3], 50 | radix=2, groups=1, bottleneck_width=64, 51 | deep_stem=True, stem_width=32, avg_down=True, 52 | avd=True, avd_first=True, **kwargs) 53 | if pretrained: 54 | model.load_state_dict(torch.hub.load_state_dict_from_url( 55 | resnest_model_urls['resnest50_fast_2s1x64d'], progress=True, check_hash=True)) 56 | return model 57 | 58 | def resnest50_fast_4s1x64d(pretrained=False, root='~/.encoding/models', **kwargs): 59 | model = ResNet(Bottleneck, [3, 4, 6, 3], 60 | radix=4, groups=1, bottleneck_width=64, 61 | deep_stem=True, stem_width=32, avg_down=True, 62 | avd=True, avd_first=True, **kwargs) 63 | if pretrained: 64 | model.load_state_dict(torch.hub.load_state_dict_from_url( 65 | resnest_model_urls['resnest50_fast_4s1x64d'], progress=True, check_hash=True)) 66 | return model 67 | 68 | def resnest50_fast_1s2x40d(pretrained=False, root='~/.encoding/models', **kwargs): 69 | model = ResNet(Bottleneck, [3, 4, 6, 3], 70 | radix=1, groups=2, bottleneck_width=40, 71 | deep_stem=True, stem_width=32, avg_down=True, 72 | avd=True, avd_first=True, **kwargs) 73 | if pretrained: 74 | model.load_state_dict(torch.hub.load_state_dict_from_url( 75 | resnest_model_urls['resnest50_fast_1s2x40d'], progress=True, check_hash=True)) 76 | return model 77 | 78 | def resnest50_fast_2s2x40d(pretrained=False, root='~/.encoding/models', **kwargs): 79 | model = ResNet(Bottleneck, [3, 4, 6, 3], 80 | radix=2, groups=2, bottleneck_width=40, 81 | deep_stem=True, stem_width=32, avg_down=True, 82 | avd=True, avd_first=True, **kwargs) 83 | if pretrained: 84 | model.load_state_dict(torch.hub.load_state_dict_from_url( 85 | resnest_model_urls['resnest50_fast_2s2x40d'], progress=True, check_hash=True)) 86 | return model 87 | 88 | def resnest50_fast_4s2x40d(pretrained=False, root='~/.encoding/models', **kwargs): 89 | model = ResNet(Bottleneck, [3, 4, 6, 3], 90 | radix=4, groups=2, bottleneck_width=40, 91 | deep_stem=True, stem_width=32, avg_down=True, 92 | avd=True, avd_first=True, **kwargs) 93 | if pretrained: 94 | model.load_state_dict(torch.hub.load_state_dict_from_url( 95 | resnest_model_urls['resnest50_fast_4s2x40d'], progress=True, check_hash=True)) 96 | return model 97 | 98 | def resnest50_fast_1s4x24d(pretrained=False, root='~/.encoding/models', **kwargs): 99 | model = ResNet(Bottleneck, [3, 4, 6, 3], 100 | radix=1, groups=4, bottleneck_width=24, 101 | deep_stem=True, stem_width=32, avg_down=True, 102 | avd=True, avd_first=True, **kwargs) 103 | if pretrained: 104 | model.load_state_dict(torch.hub.load_state_dict_from_url( 105 | resnest_model_urls['resnest50_fast_1s4x24d'], progress=True, check_hash=True)) 106 | return model 107 | -------------------------------------------------------------------------------- /models/backbone/resnest/resnest.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt models""" 9 | 10 | import torch 11 | from models.backbone.resnest.resnet import ResNet, Bottleneck 12 | 13 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 14 | 15 | _url_format = 
'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth' 16 | 17 | _model_sha256 = {name: checksum for checksum, name in [ 18 | ('528c19ca', 'resnest50'), 19 | ('22405ba7', 'resnest101'), 20 | ('75117900', 'resnest200'), 21 | ('0cc87c48', 'resnest269'), 22 | ]} 23 | 24 | def short_hash(name): 25 | if name not in _model_sha256: 26 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 27 | return _model_sha256[name][:8] 28 | 29 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 30 | name in _model_sha256.keys() 31 | } 32 | 33 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 34 | model = ResNet(Bottleneck, [3, 4, 6, 3], 35 | radix=2, groups=1, bottleneck_width=64, 36 | deep_stem=True, stem_width=32, avg_down=True, 37 | avd=True, avd_first=False, **kwargs) 38 | if pretrained: 39 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 40 | model.load_state_dict(torch.hub.load_state_dict_from_url( 41 | resnest_model_urls['resnest50'], progress=True, check_hash=True)) 42 | return model 43 | 44 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 45 | model = ResNet(Bottleneck, [3, 4, 23, 3], 46 | radix=2, groups=1, bottleneck_width=64, 47 | deep_stem=True, stem_width=64, avg_down=True, 48 | avd=True, avd_first=False, **kwargs) 49 | if pretrained: 50 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 51 | model.load_state_dict(torch.hub.load_state_dict_from_url( 52 | resnest_model_urls['resnest101'], progress=True, check_hash=True)) 53 | return model 54 | 55 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 56 | model = ResNet(Bottleneck, [3, 24, 36, 3], 57 | radix=2, groups=1, bottleneck_width=64, 58 | deep_stem=True, stem_width=64, avg_down=True, 59 | avd=True, avd_first=False, **kwargs) 60 | if pretrained: 61 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 62 | model.load_state_dict(torch.hub.load_state_dict_from_url( 63 | resnest_model_urls['resnest200'], progress=True, check_hash=True)) 64 | return model 65 | 66 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 67 | model = ResNet(Bottleneck, [3, 30, 48, 8], 68 | radix=2, groups=1, bottleneck_width=64, 69 | deep_stem=True, stem_width=64, avg_down=True, 70 | avd=True, avd_first=False, **kwargs) 71 | if pretrained: 72 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 73 | model.load_state_dict(torch.hub.load_state_dict_from_url( 74 | resnest_model_urls['resnest269'], progress=True, check_hash=True)) 75 | return model 76 | 77 | if __name__ == '__main__': 78 | x = torch.zeros(2,3,640,640) 79 | net = resnest269(pretrained=False) 80 | y = net(x) 81 | for u in y: 82 | print(u.shape) 83 | print(net.out_channels) -------------------------------------------------------------------------------- /models/backbone/resnest/splat.py: -------------------------------------------------------------------------------- 1 | """Split-Attention""" 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU 7 | from torch.nn.modules.utils import _pair 8 | 9 | __all__ = ['SplAtConv2d'] 10 | 11 | class SplAtConv2d(Module): 12 | """Split-Attention Conv2d 13 | """ 14 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 15 | dilation=(1, 1), groups=1, bias=True, 16 | 
radix=2, reduction_factor=4, 17 | rectify=False, rectify_avg=False, norm_layer=None, 18 | dropblock_prob=0.0, **kwargs): 19 | super(SplAtConv2d, self).__init__() 20 | padding = _pair(padding) 21 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 22 | self.rectify_avg = rectify_avg 23 | inter_channels = max(in_channels*radix//reduction_factor, 32) 24 | self.radix = radix 25 | self.cardinality = groups 26 | self.channels = channels 27 | self.dropblock_prob = dropblock_prob 28 | if self.rectify: 29 | from rfconv import RFConv2d 30 | self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 31 | groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs) 32 | else: 33 | self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 34 | groups=groups*radix, bias=bias, **kwargs) 35 | self.use_bn = norm_layer is not None 36 | if self.use_bn: 37 | self.bn0 = norm_layer(channels*radix) 38 | self.relu = ReLU(inplace=True) 39 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 40 | if self.use_bn: 41 | self.bn1 = norm_layer(inter_channels) 42 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality) 43 | if dropblock_prob > 0.0: 44 | self.dropblock = DropBlock2D(dropblock_prob, 3) 45 | self.rsoftmax = rSoftMax(radix, groups) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | if self.use_bn: 50 | x = self.bn0(x) 51 | if self.dropblock_prob > 0.0: 52 | x = self.dropblock(x) 53 | x = self.relu(x) 54 | 55 | batch, rchannel = x.shape[:2] 56 | if self.radix > 1: 57 | if torch.__version__ < '1.5': 58 | splited = torch.split(x, int(rchannel//self.radix), dim=1) 59 | else: 60 | splited = torch.split(x, rchannel//self.radix, dim=1) 61 | gap = sum(splited) 62 | else: 63 | gap = x 64 | gap = F.adaptive_avg_pool2d(gap, 1) 65 | gap = self.fc1(gap) 66 | 67 | if self.use_bn: 68 | gap = self.bn1(gap) 69 | gap = self.relu(gap) 70 | 71 | atten = self.fc2(gap) 72 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 73 | 74 | if self.radix > 1: 75 | if torch.__version__ < '1.5': 76 | attens = torch.split(atten, int(rchannel//self.radix), dim=1) 77 | else: 78 | attens = torch.split(atten, rchannel//self.radix, dim=1) 79 | out = sum([att*split for (att, split) in zip(attens, splited)]) 80 | else: 81 | out = atten * x 82 | return out.contiguous() 83 | 84 | class rSoftMax(nn.Module): 85 | def __init__(self, radix, cardinality): 86 | super().__init__() 87 | self.radix = radix 88 | self.cardinality = cardinality 89 | 90 | def forward(self, x): 91 | batch = x.size(0) 92 | if self.radix > 1: 93 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 94 | x = F.softmax(x, dim=1) 95 | x = x.reshape(batch, -1) 96 | else: 97 | x = torch.sigmoid(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /models/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | BatchNorm2d = nn.BatchNorm2d 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'deformable_resnet18', 'deformable_resnet50', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def constant_init(module, constant, bias=0): 20 | nn.init.constant_(module.weight, constant) 21 | if hasattr(module, 'bias'): 22 | nn.init.constant_(module.bias, bias) 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=1, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 35 | super(BasicBlock, self).__init__() 36 | self.with_dcn = dcn is not None 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = BatchNorm2d(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.with_modulated_dcn = False 41 | if not self.with_dcn: 42 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 43 | else: 44 | from torchvision.ops import DeformConv2d 45 | deformable_groups = dcn.get('deformable_groups', 1) 46 | offset_channels = 18 47 | self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1) 48 | self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, bias=False) 49 | self.bn2 = BatchNorm2d(planes) 50 | self.downsample = downsample 51 | self.stride = stride 52 | 53 | def forward(self, x): 54 | residual = x 55 | 56 | out = self.conv1(x) 57 | out = self.bn1(out) 58 | out = self.relu(out) 59 | 60 | # out = self.conv2(out) 61 | if not self.with_dcn: 62 | out = self.conv2(out) 63 | else: 64 | offset = self.conv2_offset(out) 65 | out = self.conv2(out, offset) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 81 | super(Bottleneck, self).__init__() 82 | self.with_dcn = dcn is not None 83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 84 | self.bn1 = BatchNorm2d(planes) 85 | self.with_modulated_dcn = False 86 | if not self.with_dcn: 87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 88 | else: 89 | deformable_groups = dcn.get('deformable_groups', 1) 90 | from torchvision.ops import DeformConv2d 91 | offset_channels = 18 92 | self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, stride=stride, kernel_size=3, padding=1) 93 | self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False) 94 | self.bn2 = BatchNorm2d(planes) 95 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 96 | self.bn3 = BatchNorm2d(planes * 4) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | self.dcn = dcn 101 | self.with_dcn = dcn is not None 102 | 103 | def forward(self, x): 104 | residual = x 105 | 106 | out = self.conv1(x) 107 | out = self.bn1(out) 108 | out = self.relu(out) 109 | 110 | # out = self.conv2(out) 111 | if not self.with_dcn: 112 | out = self.conv2(out) 113 | else: 114 | offset = self.conv2_offset(out) 115 | out = self.conv2(out, offset) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if 
self.downsample is not None: 123 | residual = self.downsample(x) 124 | 125 | out += residual 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | class ResNet(nn.Module): 132 | def __init__(self, block, layers, in_channels=3, dcn=None): 133 | self.dcn = dcn 134 | self.inplanes = 64 135 | super(ResNet, self).__init__() 136 | self.out_channels = [] 137 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, 138 | bias=False) 139 | self.bn1 = BatchNorm2d(64) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dcn=dcn) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dcn=dcn) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dcn=dcn) 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 150 | m.weight.data.normal_(0, math.sqrt(2. / n)) 151 | elif isinstance(m, BatchNorm2d): 152 | m.weight.data.fill_(1) 153 | m.bias.data.zero_() 154 | if self.dcn is not None: 155 | for m in self.modules(): 156 | if isinstance(m, Bottleneck) or isinstance(m, BasicBlock): 157 | if hasattr(m, 'conv2_offset'): 158 | constant_init(m.conv2_offset, 0) 159 | 160 | def _make_layer(self, block, planes, blocks, stride=1, dcn=None): 161 | downsample = None 162 | if stride != 1 or self.inplanes != planes * block.expansion: 163 | downsample = nn.Sequential( 164 | nn.Conv2d(self.inplanes, planes * block.expansion, 165 | kernel_size=1, stride=stride, bias=False), 166 | BatchNorm2d(planes * block.expansion), 167 | ) 168 | 169 | layers = [] 170 | layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn)) 171 | self.inplanes = planes * block.expansion 172 | for i in range(1, blocks): 173 | layers.append(block(self.inplanes, planes, dcn=dcn)) 174 | self.out_channels.append(planes * block.expansion) 175 | return nn.Sequential(*layers) 176 | 177 | def forward(self, x): 178 | x = self.conv1(x) 179 | x = self.bn1(x) 180 | x = self.relu(x) 181 | x = self.maxpool(x) 182 | 183 | x2 = self.layer1(x) 184 | x3 = self.layer2(x2) 185 | x4 = self.layer3(x3) 186 | x5 = self.layer4(x4) 187 | 188 | return x2, x3, x4, x5 189 | 190 | 191 | def resnet18(pretrained=True, **kwargs): 192 | """Constructs a ResNet-18 model. 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 197 | if pretrained: 198 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 199 | print('load from imagenet') 200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 201 | return model 202 | 203 | 204 | def deformable_resnet18(pretrained=True, **kwargs): 205 | """Constructs a ResNet-18 model. 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1), **kwargs) 210 | if pretrained: 211 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 212 | print('load from imagenet') 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 214 | return model 215 | 216 | 217 | def resnet34(pretrained=True, **kwargs): 218 | """Constructs a ResNet-34 model. 
219 | Args: 220 | pretrained (bool): If True, returns a model pre-trained on ImageNet 221 | """ 222 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 223 | if pretrained: 224 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 225 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False) 226 | return model 227 | 228 | 229 | def resnet50(pretrained=True, **kwargs): 230 | """Constructs a ResNet-50 model. 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | """ 234 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 235 | if pretrained: 236 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 237 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 238 | return model 239 | 240 | 241 | def deformable_resnet50(pretrained=True, **kwargs): 242 | """Constructs a ResNet-50 model with deformable conv. 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | """ 246 | model = ResNet(Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs) 247 | if pretrained: 248 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 249 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 250 | return model 251 | 252 | 253 | def resnet101(pretrained=True, **kwargs): 254 | """Constructs a ResNet-101 model. 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | """ 258 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 259 | if pretrained: 260 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 261 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False) 262 | return model 263 | 264 | 265 | def resnet152(pretrained=True, **kwargs): 266 | """Constructs a ResNet-152 model. 
267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | """ 270 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 271 | if pretrained: 272 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 273 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False) 274 | return model 275 | 276 | 277 | if __name__ == '__main__': 278 | import torch 279 | 280 | x = torch.zeros(2, 3, 640, 640) 281 | net = deformable_resnet50(pretrained=False) 282 | y = net(x) 283 | for u in y: 284 | print(u.shape) 285 | 286 | print(net.out_channels) 287 | -------------------------------------------------------------------------------- /models/backbone/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/11/1 15:31 3 | # @Author : zhoujun 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.models.utils import load_state_dict_from_url 8 | 9 | __all__ = [ 10 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 11 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' 12 | ] 13 | 14 | model_urls = { 15 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 16 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 17 | 'shufflenetv2_x1.5': None, 18 | 'shufflenetv2_x2.0': None, 19 | } 20 | 21 | 22 | def channel_shuffle(x, groups): 23 | batchsize, num_channels, height, width = x.data.size() 24 | channels_per_group = num_channels // groups 25 | 26 | # reshape 27 | x = x.view(batchsize, groups, 28 | channels_per_group, height, width) 29 | 30 | x = torch.transpose(x, 1, 2).contiguous() 31 | 32 | # flatten 33 | x = x.view(batchsize, -1, height, width) 34 | 35 | return x 36 | 37 | 38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride): 40 | super(InvertedResidual, self).__init__() 41 | 42 | if not (1 <= stride <= 3): 43 | raise ValueError('illegal stride value') 44 | self.stride = stride 45 | 46 | branch_features = oup // 2 47 | assert (self.stride != 1) or (inp == branch_features << 1) 48 | 49 | if self.stride > 1: 50 | self.branch1 = nn.Sequential( 51 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 52 | nn.BatchNorm2d(inp), 53 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 54 | nn.BatchNorm2d(branch_features), 55 | nn.ReLU(inplace=True), 56 | ) 57 | 58 | self.branch2 = nn.Sequential( 59 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 60 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 61 | nn.BatchNorm2d(branch_features), 62 | nn.ReLU(inplace=True), 63 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 64 | nn.BatchNorm2d(branch_features), 65 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 66 | nn.BatchNorm2d(branch_features), 67 | nn.ReLU(inplace=True), 68 | ) 69 | 70 | @staticmethod 71 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 72 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 73 | 74 | def forward(self, x): 75 | if self.stride == 1: 76 | x1, x2 = x.chunk(2, dim=1) 77 | out = torch.cat((x1, self.branch2(x2)), dim=1) 78 | else: 79 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 80 | 81 | out = channel_shuffle(out, 2) 82 | 83 | return out 84 | 85 | 86 | class 
ShuffleNetV2(nn.Module): 87 | def __init__(self, stages_repeats, stages_out_channels, in_channels=3, **kwargs): 88 | super(ShuffleNetV2, self).__init__() 89 | self.out_channels = [] 90 | if len(stages_repeats) != 3: 91 | raise ValueError('expected stages_repeats as list of 3 positive ints') 92 | if len(stages_out_channels) != 5: 93 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 94 | self._stage_out_channels = stages_out_channels 95 | 96 | output_channels = self._stage_out_channels[0] 97 | self.conv1 = nn.Sequential( 98 | nn.Conv2d(in_channels, output_channels, 3, 2, 1, bias=False), 99 | nn.BatchNorm2d(output_channels), 100 | nn.ReLU(inplace=True), 101 | ) 102 | input_channels = output_channels 103 | 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.out_channels.append(input_channels) 106 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 107 | for name, repeats, output_channels in zip( 108 | stage_names, stages_repeats, self._stage_out_channels[1:]): 109 | seq = [InvertedResidual(input_channels, output_channels, 2)] 110 | for i in range(repeats - 1): 111 | seq.append(InvertedResidual(output_channels, output_channels, 1)) 112 | setattr(self, name, nn.Sequential(*seq)) 113 | input_channels = output_channels 114 | self.out_channels.append(input_channels) 115 | output_channels = self._stage_out_channels[-1] 116 | self.conv5 = nn.Sequential( 117 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 118 | nn.BatchNorm2d(output_channels), 119 | nn.ReLU(inplace=True), 120 | ) 121 | 122 | def forward(self, x): 123 | x = self.conv1(x) 124 | c2 = self.maxpool(x) 125 | c3 = self.stage2(c2) 126 | c4 = self.stage3(c3) 127 | c5 = self.stage4(c4) 128 | # c5 = self.conv5(c5) 129 | return c2, c3, c4, c5 130 | 131 | 132 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 133 | model = ShuffleNetV2(*args, **kwargs) 134 | 135 | if pretrained: 136 | model_url = model_urls[arch] 137 | if model_url is None: 138 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 139 | else: 140 | assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' 141 | state_dict = load_state_dict_from_url(model_url, progress=progress) 142 | model.load_state_dict(state_dict, strict=False) 143 | 144 | return model 145 | 146 | 147 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 148 | """ 149 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 150 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 151 | `_. 152 | 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | progress (bool): If True, displays a progress bar of the download to stderr 156 | """ 157 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 158 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 159 | 160 | 161 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 162 | """ 163 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 164 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 165 | `_. 
147 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
148 |     """
149 |     Constructs a ShuffleNetV2 with 0.5x output channels, as described in
150 |     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
151 |     <https://arxiv.org/abs/1807.11164>`_.
152 | 
153 |     Args:
154 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
155 |         progress (bool): If True, displays a progress bar of the download to stderr
156 |     """
157 |     return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
158 |                          [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
159 | 
160 | 
161 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
162 |     """
163 |     Constructs a ShuffleNetV2 with 1.0x output channels, as described in
164 |     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
165 |     <https://arxiv.org/abs/1807.11164>`_.
166 | 
167 |     Args:
168 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
169 |         progress (bool): If True, displays a progress bar of the download to stderr
170 |     """
171 |     return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
172 |                          [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
173 | 
174 | 
175 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
176 |     """
177 |     Constructs a ShuffleNetV2 with 1.5x output channels, as described in
178 |     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
179 |     <https://arxiv.org/abs/1807.11164>`_.
180 | 
181 |     Args:
182 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
183 |         progress (bool): If True, displays a progress bar of the download to stderr
184 |     """
185 |     return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
186 |                          [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
187 | 
188 | 
189 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
190 |     """
191 |     Constructs a ShuffleNetV2 with 2.0x output channels, as described in
192 |     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
193 |     <https://arxiv.org/abs/1807.11164>`_.
194 | 
195 |     Args:
196 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
197 |         progress (bool): If True, displays a progress bar of the download to stderr
198 |     """
199 |     return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
200 |                          [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
201 | 
--------------------------------------------------------------------------------
/models/basic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/6 11:19
3 | # @Author : zhoujun
4 | from torch import nn
5 | 
6 | 
7 | class ConvBnRelu(nn.Module):
8 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', inplace=True):
9 |         super().__init__()
10 |         self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
11 |                               groups=groups, bias=bias, padding_mode=padding_mode)
12 |         self.bn = nn.BatchNorm2d(out_channels)
13 |         self.relu = nn.ReLU(inplace=inplace)
14 | 
15 |     def forward(self, x):
16 |         x = self.conv(x)
17 |         x = self.bn(x)
18 |         x = self.relu(x)
19 |         return x
20 | 
--------------------------------------------------------------------------------
/models/head/ConvHead.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 14:54
3 | # @Author : zhoujun
4 | import torch
5 | from torch import nn
6 | 
7 | 
8 | class ConvHead(nn.Module):
9 |     def __init__(self, in_channels, out_channels, **kwargs):
10 |         super().__init__()
11 |         self.conv = nn.Sequential(
12 |             nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
13 |             nn.Sigmoid()
14 |         )
15 | 
16 |     def forward(self, x):
17 |         return self.conv(x)
--------------------------------------------------------------------------------
/models/head/DBHead.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 14:54
3 | # @Author : zhoujun
4 | import torch
5 | from torch import nn
6 | 
7 | class DBHead(nn.Module):
8 |     def __init__(self, in_channels, out_channels, k=50):
9 |         super().__init__()
10 |         self.k = k
11 |         self.binarize = nn.Sequential(
12 |             nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
13 |             nn.BatchNorm2d(in_channels // 4),
14 |             nn.ReLU(inplace=True),
15 |             nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
16 |             nn.BatchNorm2d(in_channels // 4),
17 |             nn.ReLU(inplace=True),
18 |             nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
19 |             nn.Sigmoid())
20 |         self.binarize.apply(self.weights_init)
21 | 
22 |         self.thresh = self._init_thresh(in_channels)
23 |         self.thresh.apply(self.weights_init)
24 | 
25 |     def forward(self, x):
26 |         shrink_maps = self.binarize(x)
27 |         threshold_maps = self.thresh(x)
28 |         if self.training:
29 |             binary_maps = self.step_function(shrink_maps, threshold_maps)
30 |             y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
31 |         else:
32 |             y = torch.cat((shrink_maps, threshold_maps), dim=1)
33 |         return y
34 | 
35 |     def weights_init(self, m):
36 |         classname = m.__class__.__name__
37 |         if classname.find('Conv') != -1:
38 |             nn.init.kaiming_normal_(m.weight.data)
39 |         elif classname.find('BatchNorm') != -1:
40 |             m.weight.data.fill_(1.)
41 |             m.bias.data.fill_(1e-4)
42 | 
43 |     def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
44 |         in_channels = inner_channels
45 |         if serial:
46 |             in_channels += 1
47 |         self.thresh = nn.Sequential(
48 |             nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
49 |             nn.BatchNorm2d(inner_channels // 4),
50 |             nn.ReLU(inplace=True),
51 |             self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
52 |             nn.BatchNorm2d(inner_channels // 4),
53 |             nn.ReLU(inplace=True),
54 |             self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
55 |             nn.Sigmoid())
56 |         return self.thresh
57 | 
58 |     def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
59 |         if smooth:
60 |             inter_out_channels = out_channels
61 |             if out_channels == 1:
62 |                 inter_out_channels = in_channels
63 |             module_list = [
64 |                 nn.Upsample(scale_factor=2, mode='nearest'),
65 |                 nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
66 |             if out_channels == 1:
67 |                 module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
68 |             return nn.Sequential(*module_list)  # unpack: nn.Sequential takes Modules, not a list
69 |         else:
70 |             return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
71 | 
72 |     def step_function(self, x, y):
73 |         return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
74 | 
--------------------------------------------------------------------------------
/models/head/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/6/5 11:35
3 | # @Author : zhoujun
4 | from .DBHead import DBHead
5 | from .ConvHead import ConvHead
6 | 
7 | __all__ = ['build_head']
8 | support_head = ['ConvHead', 'DBHead']
9 | 
10 | 
11 | def build_head(head_name, **kwargs):
12 |     assert head_name in support_head, f'all supported heads are {support_head}'
13 |     head = eval(head_name)(**kwargs)
14 |     return head
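`step_function` is DB's differentiable binarization, B = 1 / (1 + e^(-k·(P - T))): with k = 50 it behaves almost like hard-thresholding the shrink map P at the learned threshold map T, while keeping usable gradients. A small numeric sketch (illustrative values only):

```python
import torch

P = torch.tensor([0.3, 0.5, 0.7])  # probability (shrink) map samples
T = torch.tensor([0.5, 0.5, 0.5])  # threshold map samples
k = 50
B = torch.reciprocal(1 + torch.exp(-k * (P - T)))
print(B)  # ~[0.0000, 0.5000, 1.0000] -> near-binary, but differentiable
```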
--------------------------------------------------------------------------------
/models/losses/DB_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:56
3 | # @Author : zhoujun
4 | from torch import nn
5 | 
6 | from models.losses.basic_loss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss
7 | 
8 | 
9 | class DBLoss(nn.Module):
10 |     def __init__(self, alpha=1.0, beta=10, ohem_ratio=3, reduction='mean', eps=1e-6):
11 |         """
12 |         Implement DB Loss.
13 |         :param alpha: coefficient of the shrink_map loss
14 |         :param beta: coefficient of the threshold_map loss
15 |         :param ohem_ratio: negative/positive sampling ratio for OHEM
16 |         :param reduction: 'mean' or 'sum', average or sum the loss over the batch
17 |         """
18 |         super().__init__()
19 |         assert reduction in ['mean', 'sum'], "reduction must be in ['mean','sum']"
20 |         self.alpha = alpha
21 |         self.beta = beta
22 |         self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
23 |         self.dice_loss = DiceLoss(eps=eps)
24 |         self.l1_loss = MaskL1Loss(eps=eps)
25 |         self.ohem_ratio = ohem_ratio
26 |         self.reduction = reduction
27 | 
28 |     def forward(self, pred, batch):
29 |         shrink_maps = pred[:, 0, :, :]
30 |         threshold_maps = pred[:, 1, :, :]
31 |         binary_maps = pred[:, 2, :, :]
32 | 
33 |         loss_shrink_maps = self.bce_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])
34 |         loss_threshold_maps = self.l1_loss(threshold_maps, batch['threshold_map'], batch['threshold_mask'])
35 |         metrics = dict(loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps)
36 |         if pred.size()[1] > 2:
37 |             loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'], batch['shrink_mask'])
38 |             metrics['loss_binary_maps'] = loss_binary_maps
39 |             loss_all = self.alpha * loss_shrink_maps + self.beta * loss_threshold_maps + loss_binary_maps
40 |             metrics['loss'] = loss_all
41 |         else:
42 |             metrics['loss'] = loss_shrink_maps
43 |         return metrics
44 | 
--------------------------------------------------------------------------------
/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/6/5 11:36
3 | # @Author : zhoujun
4 | import copy
5 | from .DB_loss import DBLoss
6 | 
7 | __all__ = ['build_loss']
8 | support_loss = ['DBLoss']
9 | 
10 | def build_loss(config):
11 |     copy_config = copy.deepcopy(config)
12 |     loss_type = copy_config.pop('type')
13 |     assert loss_type in support_loss, f'all supported losses are {support_loss}'
14 |     criterion = eval(loss_type)(**copy_config)
15 |     return criterion
16 | 
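A minimal sketch of driving `DBLoss` through `build_loss` with synthetic tensors shaped like a training batch (the dict keys mirror the ones read in `forward` above; the maps are random, so the values are meaningless):

```python
import torch
from models.losses import build_loss

criterion = build_loss({'type': 'DBLoss', 'alpha': 1.0, 'beta': 10})
pred = torch.rand(2, 3, 160, 160)  # channels 0/1/2: shrink, threshold, binary maps
batch = {
    'shrink_map': (torch.rand(2, 160, 160) > 0.5).float(),
    'shrink_mask': torch.ones(2, 160, 160),
    'threshold_map': torch.rand(2, 160, 160),
    'threshold_mask': torch.ones(2, 160, 160),
}
metrics = criterion(pred, batch)
print(metrics['loss'])  # alpha * L_shrink + beta * L_threshold + L_binary
```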
--------------------------------------------------------------------------------
/models/losses/basic_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/4 14:39
3 | # @Author : zhoujun
4 | import torch
5 | import torch.nn as nn
6 | 
7 | 
8 | class BalanceCrossEntropyLoss(nn.Module):
9 |     '''
10 |     Balanced cross entropy loss.
11 |     Shape:
12 |         - Input: :math:`(N, 1, H, W)`
13 |         - GT: :math:`(N, 1, H, W)`, same shape as the input
14 |         - Mask: :math:`(N, H, W)`, same spatial shape as the input
15 |         - Output: scalar.
16 | 
17 |     Examples::
18 | 
19 |         >>> m = nn.Sigmoid()
20 |         >>> loss = nn.BCELoss()
21 |         >>> input = torch.randn(3, requires_grad=True)
22 |         >>> target = torch.empty(3).random_(2)
23 |         >>> output = loss(m(input), target)
24 |         >>> output.backward()
25 |     '''
26 | 
27 |     def __init__(self, negative_ratio=3.0, eps=1e-6):
28 |         super(BalanceCrossEntropyLoss, self).__init__()
29 |         self.negative_ratio = negative_ratio
30 |         self.eps = eps
31 | 
32 |     def forward(self,
33 |                 pred: torch.Tensor,
34 |                 gt: torch.Tensor,
35 |                 mask: torch.Tensor,
36 |                 return_origin=False):
37 |         '''
38 |         Args:
39 |             pred: shape :math:`(N, 1, H, W)`, the prediction of network
40 |             gt: shape :math:`(N, 1, H, W)`, the target
41 |             mask: shape :math:`(N, H, W)`, the mask indicates positive regions
42 |         '''
43 |         positive = (gt * mask).byte()
44 |         negative = ((1 - gt) * mask).byte()
45 |         positive_count = int(positive.float().sum())
46 |         negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))
47 |         loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
48 |         positive_loss = loss * positive.float()
49 |         negative_loss = loss * negative.float()
50 |         # negative_loss, _ = torch.topk(negative_loss.view(-1).contiguous(), negative_count)
51 |         negative_loss, _ = negative_loss.view(-1).topk(negative_count)
52 | 
53 |         balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + self.eps)
54 | 
55 |         if return_origin:
56 |             return balance_loss, loss
57 |         return balance_loss
58 | 
59 | 
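The class above implements OHEM for the shrink map: every positive pixel contributes, but only the `negative_ratio` × positives hardest negatives do. A toy sketch (illustrative values, not part of the repo):

```python
import torch
from models.losses.basic_loss import BalanceCrossEntropyLoss

loss_fn = BalanceCrossEntropyLoss(negative_ratio=3.0)
gt = torch.tensor([[[[1., 0., 0.], [0., 0., 0.]]]])          # 1 positive, 5 negatives
pred = torch.tensor([[[[0.9, 0.8, 0.1], [0.2, 0.6, 0.1]]]])  # two confident false positives
mask = torch.ones(1, 2, 3)
# keeps the positive plus the min(5, 1 * 3) = 3 hardest negatives (0.8, 0.6, 0.2)
print(loss_fn(pred, gt, mask))
```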
60 | class DiceLoss(nn.Module):
61 |     '''
62 |     Loss function from https://arxiv.org/abs/1707.03237,
63 |     where the IoU computation is introduced in a heatmap manner to measure the
64 |     diversity between two heatmaps.
65 |     '''
66 | 
67 |     def __init__(self, eps=1e-6):
68 |         super(DiceLoss, self).__init__()
69 |         self.eps = eps
70 | 
71 |     def forward(self, pred: torch.Tensor, gt, mask, weights=None):
72 |         '''
73 |         pred: one or two heatmaps of shape (N, 1, H, W),
74 |             the losses of two heatmaps are added together.
75 |         gt: (N, 1, H, W)
76 |         mask: (N, H, W)
77 |         '''
78 |         return self._compute(pred, gt, mask, weights)
79 | 
80 |     def _compute(self, pred, gt, mask, weights):
81 |         if pred.dim() == 4:
82 |             pred = pred[:, 0, :, :]
83 |             gt = gt[:, 0, :, :]
84 |         assert pred.shape == gt.shape
85 |         assert pred.shape == mask.shape
86 |         if weights is not None:
87 |             assert weights.shape == mask.shape
88 |             mask = weights * mask
89 |         intersection = (pred * gt * mask).sum()
90 | 
91 |         union = (pred * mask).sum() + (gt * mask).sum() + self.eps
92 |         loss = 1 - 2.0 * intersection / union
93 |         assert loss <= 1
94 |         return loss
95 | 
96 | 
97 | class MaskL1Loss(nn.Module):
98 |     def __init__(self, eps=1e-6):
99 |         super(MaskL1Loss, self).__init__()
100 |         self.eps = eps
101 | 
102 |     def forward(self, pred: torch.Tensor, gt, mask):
103 |         loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
104 |         return loss
105 | 
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:57
3 | # @Author : zhoujun
4 | from addict import Dict
5 | from torch import nn
6 | import torch.nn.functional as F
7 | 
8 | from models.backbone import build_backbone
9 | from models.neck import build_neck
10 | from models.head import build_head
11 | 
12 | 
13 | class Model(nn.Module):
14 |     def __init__(self, model_config: dict):
15 |         """
16 |         Detection model assembled as backbone -> neck -> head
17 |         :param model_config: model configuration dict
18 |         """
19 |         super().__init__()
20 |         model_config = Dict(model_config)
21 |         backbone_type = model_config.backbone.pop('type')
22 |         neck_type = model_config.neck.pop('type')
23 |         head_type = model_config.head.pop('type')
24 |         self.backbone = build_backbone(backbone_type, **model_config.backbone)
25 |         self.neck = build_neck(neck_type, in_channels=self.backbone.out_channels, **model_config.neck)
26 |         self.head = build_head(head_type, in_channels=self.neck.out_channels, **model_config.head)
27 |         self.name = f'{backbone_type}_{neck_type}_{head_type}'
28 | 
29 |     def forward(self, x):
30 |         _, _, H, W = x.size()
31 |         backbone_out = self.backbone(x)
32 |         neck_out = self.neck(backbone_out)
33 |         y = self.head(neck_out)
34 |         y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)
35 |         return y
36 | 
37 | 
38 | if __name__ == '__main__':
39 |     import torch
40 | 
41 |     device = torch.device('cpu')
42 |     x = torch.zeros(2, 3, 640, 640).to(device)
43 | 
44 |     model_config = {
45 |         'backbone': {'type': 'resnest50', 'pretrained': True, "in_channels": 3},
46 |         'neck': {'type': 'FPN', 'inner_channels': 256},  # segmentation neck: FPN or FPEM_FFM
47 |         'head': {'type': 'DBHead', 'out_channels': 2, 'k': 50},
48 |     }
49 |     model = Model(model_config=model_config).to(device)
50 |     import time
51 | 
52 |     tic = time.time()
53 |     y = model(x)
54 |     print(time.time() - tic)
55 |     print(y.shape)
56 |     print(model.name)
57 |     print(model)
58 |     # torch.save(model.state_dict(), 'PAN.pth')
59 | 
--------------------------------------------------------------------------------
/models/neck/FPEM_FFM.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/9/13 10:29
3 | # @Author : zhoujun
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | 
8 | from models.basic import ConvBnRelu
9 | 
10 | 
11 | class FPEM_FFM(nn.Module):
12 |     def __init__(self, in_channels, inner_channels=128, fpem_repeat=2, **kwargs):
13 |         """
14 |         PANnet
15 |         :param in_channels: output channels of the backbone
16 |         """
17 |         super().__init__()
18 |         self.conv_out = inner_channels
19 |         inplace = True
20 |         # reduce layers
21 |         self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
22 |         self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
23 |         self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
24 |         self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
25 |         self.fpems = nn.ModuleList()
26 |         for i in range(fpem_repeat):
27 |             self.fpems.append(FPEM(self.conv_out))
28 |         self.out_channels = self.conv_out * 4
29 | 
30 |     def forward(self, x):
31 |         c2, c3, c4, c5 = x
32 |         # reduce channel
33 |         c2 = self.reduce_conv_c2(c2)
34 |         c3 = self.reduce_conv_c3(c3)
35 |         c4 = self.reduce_conv_c4(c4)
36 |         c5 = self.reduce_conv_c5(c5)
37 | 
38 |         # FPEM
39 |         for i, fpem in enumerate(self.fpems):
40 |             c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
41 |             if i == 0:
42 |                 c2_ffm = c2
43 |                 c3_ffm = c3
44 |                 c4_ffm = c4
45 |                 c5_ffm = c5
46 |             else:
47 |                 c2_ffm += c2
48 |                 c3_ffm += c3
49 |                 c4_ffm += c4
50 |                 c5_ffm += c5
51 | 
52 |         # FFM
53 |         c5 = F.interpolate(c5_ffm, c2_ffm.size()[-2:])
54 |         c4 = F.interpolate(c4_ffm, c2_ffm.size()[-2:])
55 |         c3 = F.interpolate(c3_ffm, c2_ffm.size()[-2:])
56 |         Fy = torch.cat([c2_ffm, c3, c4, c5], dim=1)
57 |         return Fy
58 | 
59 | 
60 | class FPEM(nn.Module):
61 |     def __init__(self, in_channels=128):
62 |         super().__init__()
63 |         self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
64 |         self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
65 |         self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
66 |         self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
67 |         self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
68 |         self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
69 | 
70 |     def forward(self, c2, c3, c4, c5):
71 |         # up (top-down) stage
72 |         c4 = self.up_add1(self._upsample_add(c5, c4))
73 |         c3 = self.up_add2(self._upsample_add(c4, c3))
74 |         c2 = self.up_add3(self._upsample_add(c3, c2))
75 | 
76 |         # down (bottom-up) stage
77 |         c3 = self.down_add1(self._upsample_add(c3, c2))
78 |         c4 = self.down_add2(self._upsample_add(c4, c3))
79 |         c5 = self.down_add3(self._upsample_add(c5, c4))
80 |         return c2, c3, c4, c5
81 | 
82 |     def _upsample_add(self, x, y):
83 |         return F.interpolate(x, size=y.size()[2:]) + y
84 | 
85 | 
86 | class SeparableConv2d(nn.Module):
87 |     def __init__(self, in_channels, out_channels, stride=1):
88 |         super(SeparableConv2d, self).__init__()
89 | 
90 |         self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1,
91 |                                         stride=stride, groups=in_channels)
92 |         self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
93 |         self.bn = nn.BatchNorm2d(out_channels)
94 |         self.relu = nn.ReLU()
95 | 
96 |     def forward(self, x):
97 |         x = self.depthwise_conv(x)
98 |         x = self.pointwise_conv(x)
99 |         x = self.bn(x)
100 |         x = self.relu(x)
101 |         return x
102 | 
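`SeparableConv2d` is the depthwise-separable block that keeps FPEM cheap; at the same width it uses roughly an eighth of the weights of a full 3x3 convolution. A quick check (a sketch, not part of the repo):

```python
from torch import nn
from models.neck.FPEM_FFM import SeparableConv2d

n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(SeparableConv2d(128, 128)))                      # ~18k (3x3 depthwise + 1x1 pointwise + BN)
print(n_params(nn.Conv2d(128, 128, kernel_size=3, padding=1)))  # ~148k
```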
13 | """ 14 | :param in_channels: 基础网络输出的维度 15 | :param kwargs: 16 | """ 17 | super().__init__() 18 | inplace = True 19 | self.conv_out = inner_channels 20 | inner_channels = inner_channels // 4 21 | # reduce layers 22 | self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace) 23 | self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace) 24 | self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace) 25 | self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace) 26 | # Smooth layers 27 | self.smooth_p4 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace) 28 | self.smooth_p3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace) 29 | self.smooth_p2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace) 30 | 31 | self.conv = nn.Sequential( 32 | nn.Conv2d(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1), 33 | nn.BatchNorm2d(self.conv_out), 34 | nn.ReLU(inplace=inplace) 35 | ) 36 | self.out_channels = self.conv_out 37 | 38 | def forward(self, x): 39 | c2, c3, c4, c5 = x 40 | # Top-down 41 | p5 = self.reduce_conv_c5(c5) 42 | p4 = self._upsample_add(p5, self.reduce_conv_c4(c4)) 43 | p4 = self.smooth_p4(p4) 44 | p3 = self._upsample_add(p4, self.reduce_conv_c3(c3)) 45 | p3 = self.smooth_p3(p3) 46 | p2 = self._upsample_add(p3, self.reduce_conv_c2(c2)) 47 | p2 = self.smooth_p2(p2) 48 | 49 | x = self._upsample_cat(p2, p3, p4, p5) 50 | x = self.conv(x) 51 | return x 52 | 53 | def _upsample_add(self, x, y): 54 | return F.interpolate(x, size=y.size()[2:]) + y 55 | 56 | def _upsample_cat(self, p2, p3, p4, p5): 57 | h, w = p2.size()[2:] 58 | p3 = F.interpolate(p3, size=(h, w)) 59 | p4 = F.interpolate(p4, size=(h, w)) 60 | p5 = F.interpolate(p5, size=(h, w)) 61 | return torch.cat([p2, p3, p4, p5], dim=1) 62 | -------------------------------------------------------------------------------- /models/neck/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/5 11:34 3 | # @Author : zhoujun 4 | from .FPN import FPN 5 | from .FPEM_FFM import FPEM_FFM 6 | 7 | __all__ = ['build_neck'] 8 | support_neck = ['FPN', 'FPEM_FFM'] 9 | 10 | 11 | def build_neck(neck_name, **kwargs): 12 | assert neck_name in support_neck, f'all support neck is {support_neck}' 13 | neck = eval(neck_name)(**kwargs) 14 | return neck 15 | -------------------------------------------------------------------------------- /multi_gpu_train.sh: -------------------------------------------------------------------------------- 1 | # export NCCL_P2P_DISABLE=1 2 | export NGPUS=4 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train.py --config_file "config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml" -------------------------------------------------------------------------------- /post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/12/5 15:17 3 | # @Author : zhoujun 4 | 5 | from .seg_detector_representer import SegDetectorRepresenter 6 | 7 | 8 | def get_post_processing(config): 9 | try: 10 | cls = eval(config['type'])(**config['args']) 11 | return cls 12 | except: 13 | return None 
--------------------------------------------------------------------------------
/post_processing/seg_detector_representer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import pyclipper
4 | from shapely.geometry import Polygon
5 | 
6 | 
7 | class SegDetectorRepresenter():
8 |     def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5):
9 |         self.min_size = 3
10 |         self.thresh = thresh
11 |         self.box_thresh = box_thresh
12 |         self.max_candidates = max_candidates
13 |         self.unclip_ratio = unclip_ratio
14 | 
15 |     def __call__(self, batch, pred, is_output_polygon=False):
16 |         '''
17 |         batch: a dict produced by dataloaders, containing:
18 |             image: tensor of shape (N, C, H, W).
19 |             polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
20 |             ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
21 |             shape: the original shape of images.
22 |             filename: the original filenames of images.
23 |         pred:
24 |             binary: text region segmentation map, with shape (N, H, W)
25 |             thresh: [if exists] threshold prediction with shape (N, H, W)
26 |             thresh_binary: [if exists] binarized with threshold, (N, H, W)
27 |         is_output_polygon: whether to return polygons instead of quad boxes
28 |         '''
29 |         pred = pred[:, 0, :, :]
30 |         segmentation = self.binarize(pred)
31 |         boxes_batch = []
32 |         scores_batch = []
33 |         for batch_index in range(pred.size(0)):
34 |             height, width = batch['shape'][batch_index]
35 |             if is_output_polygon:
36 |                 boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
37 |             else:
38 |                 boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
39 |             boxes_batch.append(boxes)
40 |             scores_batch.append(scores)
41 |         return boxes_batch, scores_batch
42 | 
43 |     def binarize(self, pred):
44 |         return pred > self.thresh
45 | 
46 |     def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
47 |         '''
48 |         _bitmap: single map with shape (H, W),
49 |             whose values are binarized as {0, 1}
50 |         '''
51 | 
52 |         assert len(_bitmap.shape) == 2
53 |         bitmap = _bitmap.cpu().numpy()  # The first channel
54 |         pred = pred.cpu().detach().numpy()
55 |         height, width = bitmap.shape
56 |         boxes = []
57 |         scores = []
58 | 
59 |         contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
60 | 
61 |         for contour in contours[:self.max_candidates]:
62 |             epsilon = 0.005 * cv2.arcLength(contour, True)
63 |             approx = cv2.approxPolyDP(contour, epsilon, True)
64 |             points = approx.reshape((-1, 2))
65 |             if points.shape[0] < 4:
66 |                 continue
67 |             # _, sside = self.get_mini_boxes(contour)
68 |             # if sside < self.min_size:
69 |             #     continue
70 |             score = self.box_score_fast(pred, contour.squeeze(1))
71 |             if self.box_thresh > score:
72 |                 continue
73 | 
74 |             if points.shape[0] > 2:
75 |                 box = self.unclip(points, unclip_ratio=self.unclip_ratio)
76 |                 if len(box) > 1:
77 |                     continue
78 |             else:
79 |                 continue
80 |             box = box.reshape(-1, 2)
81 |             _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
82 |             if sside < self.min_size + 2:
83 |                 continue
84 | 
85 |             if not isinstance(dest_width, int):
86 |                 dest_width = dest_width.item()
87 |                 dest_height = dest_height.item()
88 | 
89 |             box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
90 |             box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
91 |             boxes.append(box)
92 |             scores.append(score)
93 |         return boxes, scores
94 | 
95 |     def 
boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): 96 | ''' 97 | _bitmap: single map with shape (H, W), 98 | whose values are binarized as {0, 1} 99 | ''' 100 | 101 | assert len(_bitmap.shape) == 2 102 | bitmap = _bitmap.cpu().numpy() # The first channel 103 | pred = pred.cpu().detach().numpy() 104 | height, width = bitmap.shape 105 | contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 106 | num_contours = min(len(contours), self.max_candidates) 107 | boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) 108 | scores = np.zeros((num_contours,), dtype=np.float32) 109 | 110 | for index in range(num_contours): 111 | contour = contours[index].squeeze(1) 112 | points, sside = self.get_mini_boxes(contour) 113 | if sside < self.min_size: 114 | continue 115 | points = np.array(points) 116 | score = self.box_score_fast(pred, contour) 117 | if self.box_thresh > score: 118 | continue 119 | 120 | box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2) 121 | box, sside = self.get_mini_boxes(box) 122 | if sside < self.min_size + 2: 123 | continue 124 | box = np.array(box) 125 | if not isinstance(dest_width, int): 126 | dest_width = dest_width.item() 127 | dest_height = dest_height.item() 128 | 129 | box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) 130 | box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) 131 | boxes[index, :, :] = box.astype(np.int16) 132 | scores[index] = score 133 | return boxes, scores 134 | 135 | def unclip(self, box, unclip_ratio=1.5): 136 | poly = Polygon(box) 137 | distance = poly.area * unclip_ratio / poly.length 138 | offset = pyclipper.PyclipperOffset() 139 | offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 140 | expanded = np.array(offset.Execute(distance)) 141 | return expanded 142 | 143 | def get_mini_boxes(self, contour): 144 | bounding_box = cv2.minAreaRect(contour) 145 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 146 | 147 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3 148 | if points[1][1] > points[0][1]: 149 | index_1 = 0 150 | index_4 = 1 151 | else: 152 | index_1 = 1 153 | index_4 = 0 154 | if points[3][1] > points[2][1]: 155 | index_2 = 2 156 | index_3 = 3 157 | else: 158 | index_2 = 3 159 | index_3 = 2 160 | 161 | box = [points[index_1], points[index_2], points[index_3], points[index_4]] 162 | return box, min(bounding_box[1]) 163 | 164 | def box_score_fast(self, bitmap, _box): 165 | h, w = bitmap.shape[:2] 166 | box = _box.copy() 167 | xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) 168 | xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) 169 | ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) 170 | ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) 171 | 172 | mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) 173 | box[:, 0] = box[:, 0] - xmin 174 | box[:, 1] = box[:, 1] - ymin 175 | cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) 176 | return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] 177 | -------------------------------------------------------------------------------- /predict.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tools/predict.py --model_path model_best.pth --input_folder ./input --output_folder ./output --thre 0.7 --polygon --show --save_result 
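Both `polygons_from_bitmap` and `boxes_from_bitmap` above rely on `unclip` to grow the shrunk text kernels back to full size, with offset D = A × unclip_ratio / L (area over perimeter, as in the DB paper). A standalone sketch with a hypothetical 100x20 quad:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

box = np.array([[0, 0], [100, 0], [100, 20], [0, 20]])
poly = Polygon(box)
distance = poly.area * 1.5 / poly.length  # D = A * r / L = 2000 * 1.5 / 240 = 12.5
offset = pyclipper.PyclipperOffset()
offset.AddPath(box.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = offset.Execute(distance)  # the quad pushed out by 12.5 px on every side
```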
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | anyconfig==0.9.10
2 | future==0.18.2
3 | imgaug==0.4.0
4 | matplotlib==3.1.2
5 | numpy==1.17.4
6 | opencv-python==4.1.2.30
7 | Polygon3==3.0.8
8 | pyclipper==1.1.0.post3
9 | PyYAML==5.2
10 | scikit-image==0.16.2
11 | Shapely==1.6.4.post2
12 | tensorboard==2.1.0
13 | tqdm==4.40.1
14 | torch==1.4
15 | torchvision==0.5
16 | 
--------------------------------------------------------------------------------
/singlel_gpu_train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python3 tools/train.py --config_file "config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml"
--------------------------------------------------------------------------------
/test/README.MD:
--------------------------------------------------------------------------------
1 | Place the images that you want to detect here, preferably named like:
2 | img_10.jpg
3 | img_11.jpg
4 | img_{img_id}.jpg
5 | 
6 | To predict a single image, change `img_path` in `/tools/predict.py` to point at your image.
7 | 
8 | The results will be saved in the output_folder (default: test/output) that you pass in predict.sh
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/8 13:14
3 | # @Author : zhoujun
--------------------------------------------------------------------------------
/tools/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/6/11 15:54
3 | # @Author : zhoujun
4 | import os
5 | import sys
6 | import pathlib
7 | __dir__ = pathlib.Path(os.path.abspath(__file__))
8 | sys.path.append(str(__dir__))
9 | sys.path.append(str(__dir__.parent.parent))
10 | # project = 'DBNet.pytorch'  # project root directory
11 | # sys.path.append(os.getcwd().split(project)[0] + project)
12 | 
13 | import argparse
14 | import time
15 | import torch
16 | from tqdm.auto import tqdm
17 | 
18 | 
19 | class EVAL():
20 |     def __init__(self, model_path, gpu_id=0):
21 |         from models import build_model
22 |         from data_loader import get_dataloader
23 |         from post_processing import get_post_processing
24 |         from utils import get_metric
25 |         self.gpu_id = gpu_id
26 |         if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
27 |             self.device = torch.device("cuda:%s" % self.gpu_id)
28 |             torch.backends.cudnn.benchmark = True
29 |         else:
30 |             self.device = torch.device("cpu")
31 |         checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
32 |         config = checkpoint['config']
33 |         config['arch']['backbone']['pretrained'] = False
34 | 
35 |         self.validate_loader = get_dataloader(config['dataset']['validate'], config['distributed'])
36 | 
37 |         self.model = build_model(config['arch'])
38 |         self.model.load_state_dict(checkpoint['state_dict'])
39 |         self.model.to(self.device)
40 | 
41 |         self.post_process = get_post_processing(config['post_processing'])
42 |         self.metric_cls = get_metric(config['metric'])
43 | 
44 |     def eval(self):
45 |         self.model.eval()
46 |         # torch.cuda.empty_cache()  # speed up evaluating after training finished
47 |         raw_metrics = []
48 |         total_frame = 0.0
49 |         total_time = 0.0
50 |         for i, batch in tqdm(enumerate(self.validate_loader), total=len(self.validate_loader), desc='test model'):
51 |             with torch.no_grad():
52 |                 # move every tensor in the batch to the device
53 |                 for key, value in batch.items():
54 |                     if value is not None:
55 |                         if isinstance(value, torch.Tensor):
56 |                             batch[key] = value.to(self.device)
57 |                 start = time.time()
58 |                 preds = self.model(batch['img'])
59 |                 boxes, scores = self.post_process(batch, preds, is_output_polygon=self.metric_cls.is_output_polygon)
60 |                 total_frame += batch['img'].size()[0]
61 |                 total_time += time.time() - start
62 |                 raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
63 |                 raw_metrics.append(raw_metric)
64 |         metrics = self.metric_cls.gather_measure(raw_metrics)
65 |         print('FPS:{}'.format(total_frame / total_time))
66 |         return metrics['recall'].avg, metrics['precision'].avg, metrics['fmeasure'].avg
67 | 
68 | 
69 | def init_args():
70 |     parser = argparse.ArgumentParser(description='DBNet.pytorch')
71 |     parser.add_argument('--model_path', required=False, default='output/DBNet_resnet18_FPN_DBHead/checkpoint/1.pth', type=str)
72 |     args = parser.parse_args()
73 |     return args
74 | 
75 | 
76 | if __name__ == '__main__':
77 |     args = init_args()
78 |     eval = EVAL(args.model_path)
79 |     result = eval.eval()
80 |     print(result)
81 | 
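`predict.py` below brackets the forward pass with `torch.cuda.synchronize` so the reported time includes the asynchronous GPU work; the pattern in isolation (a sketch, assuming a CUDA device is present):

```python
import time
import torch

def timed_forward(model, tensor, device):
    if device.type == 'cuda':
        torch.cuda.synchronize(device)  # wait for pending kernels before starting the clock
    start = time.time()
    with torch.no_grad():
        preds = model(tensor)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)  # wait for the forward pass to actually finish
    return preds, time.time() - start
```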
--------------------------------------------------------------------------------
/tools/predict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/24 12:06
3 | # @Author : zhoujun
4 | 
5 | import os
6 | import sys
7 | import pathlib
8 | __dir__ = pathlib.Path(os.path.abspath(__file__))
9 | sys.path.append(str(__dir__))
10 | sys.path.append(str(__dir__.parent.parent))
11 | 
12 | # project = 'DBNet.pytorch'  # project root directory
13 | # sys.path.append(os.getcwd().split(project)[0] + project)
14 | import time
15 | import cv2
16 | import torch
17 | 
18 | from data_loader import get_transforms
19 | from models import build_model
20 | from post_processing import get_post_processing
21 | 
22 | 
23 | def resize_image(img, short_size):
24 |     height, width, _ = img.shape
25 |     if height < width:
26 |         new_height = short_size
27 |         new_width = new_height / height * width
28 |     else:
29 |         new_width = short_size
30 |         new_height = new_width / width * height
31 |     new_height = int(round(new_height / 32) * 32)
32 |     new_width = int(round(new_width / 32) * 32)
33 |     resized_img = cv2.resize(img, (new_width, new_height))
34 |     return resized_img
35 | 
36 | 
37 | class Pytorch_model:
38 |     def __init__(self, model_path, post_p_thre=0.7, gpu_id=None):
39 |         '''
40 |         Initialize the PyTorch model
41 |         :param model_path: path to the model (weights only, or weights saved together with the graph)
42 |         :param gpu_id: which GPU to run on
43 |         '''
44 |         self.gpu_id = gpu_id
45 | 
46 |         if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
47 |             self.device = torch.device("cuda:%s" % self.gpu_id)
48 |         else:
49 |             self.device = torch.device("cpu")
50 |         print('device:', self.device)
51 |         checkpoint = torch.load(model_path, map_location=self.device)
52 | 
53 |         config = checkpoint['config']
54 |         config['arch']['backbone']['pretrained'] = False
55 |         self.model = build_model(config['arch'])
56 |         self.post_process = get_post_processing(config['post_processing'])
57 |         self.post_process.box_thresh = post_p_thre
58 |         self.img_mode = config['dataset']['train']['dataset']['args']['img_mode']
59 |         self.model.load_state_dict(checkpoint['state_dict'])
60 |         self.model.to(self.device)
61 |         self.model.eval()
62 | 
63 |         self.transform = []
64 |         for t in config['dataset']['train']['dataset']['args']['transforms']:
65 |             if t['type'] in ['ToTensor', 'Normalize']:
66 |                 self.transform.append(t)
67 |         self.transform = get_transforms(self.transform)
68 | 
69 |     def predict(self, img_path: str, is_output_polygon=False, short_size: int = 1024):
70 |         '''
71 |         Predict on the given image; takes an image path and reads it with OpenCV (somewhat slow)
72 |         :param img_path: image path
73 |         :param is_output_polygon: whether to return polygons instead of quad boxes
74 |         :return:
75 |         '''
76 |         assert os.path.exists(img_path), 'file does not exist'
77 |         img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0)
78 |         if self.img_mode == 'RGB':
79 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
80 |         h, w = img.shape[:2]
81 |         img = resize_image(img, short_size)
82 |         # convert the image to a (1, img_channel, h, w) tensor
83 |         tensor = self.transform(img)
84 |         tensor = tensor.unsqueeze_(0)
85 | 
86 |         tensor = tensor.to(self.device)
87 |         batch = {'shape': [(h, w)]}
88 |         with torch.no_grad():
89 |             if str(self.device).__contains__('cuda'):
90 |                 torch.cuda.synchronize(self.device)
91 |             start = time.time()
92 |             preds = self.model(tensor)
93 |             if str(self.device).__contains__('cuda'):
94 |                 torch.cuda.synchronize(self.device)
95 |             box_list, score_list = self.post_process(batch, preds, is_output_polygon=is_output_polygon)
96 |             box_list, score_list = box_list[0], score_list[0]
97 |             if len(box_list) > 0:
98 |                 if is_output_polygon:
99 |                     idx = [x.sum() > 0 for x in box_list]
100 |                     box_list = [box_list[i] for i, v in enumerate(idx) if v]
101 |                     score_list = [score_list[i] for i, v in enumerate(idx) if v]
102 |                 else:
103 |                     idx = box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0  # drop boxes that are all zeros
104 |                     box_list, score_list = box_list[idx], score_list[idx]
105 |             else:
106 |                 box_list, score_list = [], []
107 |             t = time.time() - start
108 |         return preds[0, 0, :, :].detach().cpu().numpy(), box_list, score_list, t
109 | 
110 | 
111 | def save_depoly(model, input, save_path):
112 |     traced_script_model = torch.jit.trace(model, input)
113 |     traced_script_model.save(save_path)
114 | 
115 | 
116 | def init_args():
117 |     import argparse
118 |     parser = argparse.ArgumentParser(description='DBNet.pytorch')
119 |     parser.add_argument('--model_path', default=r'model_best.pth', type=str)
120 |     parser.add_argument('--input_folder', default='./test/input', type=str, help='img path for predict')
121 |     parser.add_argument('--output_folder', default='./test/output', type=str, help='img path for output')
122 |     parser.add_argument('--thre', default=0.3, type=float, help='the thresh of post_processing')
123 |     parser.add_argument('--polygon', action='store_true', help='output polygon or box')
124 |     parser.add_argument('--show', action='store_true', help='show result')
125 |     parser.add_argument('--save_result', action='store_true', help='save box and score to txt file')
126 |     args = parser.parse_args()
127 |     return args
128 | 
129 | 
130 | if __name__ == '__main__':
131 |     import pathlib
132 |     from tqdm import tqdm
133 |     import matplotlib.pyplot as plt
134 |     from utils.util import show_img, draw_bbox, save_result, get_file_list
135 | 
136 |     args = init_args()
137 |     print(args)
138 |     os.environ['CUDA_VISIBLE_DEVICES'] = str('0')
139 |     # initialize the network
140 |     model = Pytorch_model(args.model_path, post_p_thre=args.thre, gpu_id=0)
141 |     img_folder = pathlib.Path(args.input_folder)
142 |     for img_path in tqdm(get_file_list(args.input_folder, p_postfix=['.jpg'])):
143 |         preds, boxes_list, score_list, t = model.predict(img_path, is_output_polygon=args.polygon)
144 |         img = draw_bbox(cv2.imread(img_path)[:, :, ::-1], boxes_list)
145 |         if args.show:
146 |             show_img(preds)
147 |             show_img(img, title=os.path.basename(img_path))
148 |             plt.show()
149 |         # save results to the output folder
150 |         os.makedirs(args.output_folder, exist_ok=True)
151 |         img_path = pathlib.Path(img_path)
152 |         output_path = os.path.join(args.output_folder, img_path.stem + '_result.jpg')
153 |         pred_path = os.path.join(args.output_folder, img_path.stem + '_pred.jpg')
154 |         cv2.imwrite(output_path, img[:, :, ::-1])
155 |         cv2.imwrite(pred_path, preds * 255)
156 |         save_result(output_path.replace('_result.jpg', '.txt'), boxes_list, score_list, args.polygon)
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 22:00
3 | # @Author : zhoujun
4 | 
5 | from __future__ import print_function
6 | 
7 | import argparse
8 | import os
9 | 
10 | import anyconfig
11 | 
12 | 
13 | def init_args():
14 |     parser = argparse.ArgumentParser(description='DBNet.pytorch')
15 |     parser.add_argument('--config_file', default='config/open_dataset_resnet18_FPN_DBhead_polyLR.yaml', type=str)
16 |     parser.add_argument('--local_rank', dest='local_rank', default=0, type=int, help='Use distributed training')
17 | 
18 |     args = parser.parse_args()
19 |     return args
20 | 
21 | 
22 | def main(config):
23 |     import torch
24 |     from models import build_model, build_loss
25 |     from data_loader import get_dataloader
26 |     from trainer import Trainer
27 |     from post_processing import get_post_processing
28 |     from utils import get_metric
29 |     if torch.cuda.device_count() > 1:
30 |         torch.cuda.set_device(args.local_rank)
31 |         torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=torch.cuda.device_count(), rank=args.local_rank)
32 |         config['distributed'] = True
33 |     else:
34 |         config['distributed'] = False
35 |     config['local_rank'] = args.local_rank
36 | 
37 |     train_loader = get_dataloader(config['dataset']['train'], config['distributed'])
38 |     assert train_loader is not None
39 |     if 'validate' in config['dataset']:
40 |         validate_loader = get_dataloader(config['dataset']['validate'], False)
41 |     else:
42 |         validate_loader = None
43 | 
44 |     criterion = build_loss(config['loss']).cuda()
45 | 
46 |     config['arch']['backbone']['in_channels'] = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1
47 |     model = build_model(config['arch'])
48 | 
49 |     post_p = get_post_processing(config['post_processing'])
50 |     metric = get_metric(config['metric'])
51 | 
52 |     trainer = Trainer(config=config,
53 |                       model=model,
54 |                       criterion=criterion,
55 |                       train_loader=train_loader,
56 |                       post_process=post_p,
57 |                       metric_cls=metric,
58 |                       validate_loader=validate_loader)
59 |     trainer.train()
60 | 
61 | 
62 | if __name__ == '__main__':
63 |     import sys
64 |     import pathlib
65 |     __dir__ = pathlib.Path(os.path.abspath(__file__))
66 |     sys.path.append(str(__dir__))
67 |     sys.path.append(str(__dir__.parent.parent))
68 |     # project = 'DBNet.pytorch'  # project root directory
69 |     # sys.path.append(os.getcwd().split(project)[0] + project)
70 | 
71 |     from utils import parse_config
72 | 
73 |     args = init_args()
74 |     assert os.path.exists(args.config_file)
75 |     config = anyconfig.load(open(args.config_file, 'rb'))
76 |     if 'base' in config:
77 |         config = parse_config(config)
78 |     main(config)
79 | 
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:58
3 | # @Author : zhoujun
4 | from .trainer import 
Trainer -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:58 3 | # @Author : zhoujun 4 | import time 5 | 6 | import torch 7 | import torchvision.utils as vutils 8 | from tqdm import tqdm 9 | 10 | from base import BaseTrainer 11 | from utils import WarmupPolyLR, runningScore, cal_text_score 12 | 13 | 14 | class Trainer(BaseTrainer): 15 | def __init__(self, config, model, criterion, train_loader, validate_loader, metric_cls, post_process=None): 16 | super(Trainer, self).__init__(config, model, criterion) 17 | self.show_images_iter = self.config['trainer']['show_images_iter'] 18 | self.train_loader = train_loader 19 | if validate_loader is not None: 20 | assert post_process is not None and metric_cls is not None 21 | self.validate_loader = validate_loader 22 | self.post_process = post_process 23 | self.metric_cls = metric_cls 24 | self.train_loader_len = len(train_loader) 25 | if self.config['lr_scheduler']['type'] == 'WarmupPolyLR': 26 | warmup_iters = config['lr_scheduler']['args']['warmup_epoch'] * self.train_loader_len 27 | if self.start_epoch > 1: 28 | self.config['lr_scheduler']['args']['last_epoch'] = (self.start_epoch - 1) * self.train_loader_len 29 | self.scheduler = WarmupPolyLR(self.optimizer, max_iters=self.epochs * self.train_loader_len, 30 | warmup_iters=warmup_iters, **config['lr_scheduler']['args']) 31 | if self.validate_loader is not None: 32 | self.logger_info( 33 | 'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'.format( 34 | len(self.train_loader.dataset), self.train_loader_len, len(self.validate_loader.dataset), len(self.validate_loader))) 35 | else: 36 | self.logger_info('train dataset has {} samples,{} in dataloader'.format(len(self.train_loader.dataset), self.train_loader_len)) 37 | 38 | def _train_epoch(self, epoch): 39 | self.model.train() 40 | epoch_start = time.time() 41 | batch_start = time.time() 42 | train_loss = 0. 
43 |         running_metric_text = runningScore(2)
44 |         lr = self.optimizer.param_groups[0]['lr']
45 | 
46 |         for i, batch in enumerate(self.train_loader):
47 |             if i >= self.train_loader_len:
48 |                 break
49 |             self.global_step += 1
50 |             lr = self.optimizer.param_groups[0]['lr']
51 | 
52 |             # move every tensor in the batch to the device
53 |             for key, value in batch.items():
54 |                 if value is not None:
55 |                     if isinstance(value, torch.Tensor):
56 |                         batch[key] = value.to(self.device)
57 |             cur_batch_size = batch['img'].size()[0]
58 | 
59 |             preds = self.model(batch['img'])
60 |             loss_dict = self.criterion(preds, batch)
61 |             # backward
62 |             self.optimizer.zero_grad()
63 |             loss_dict['loss'].backward()
64 |             self.optimizer.step()
65 |             if self.config['lr_scheduler']['type'] == 'WarmupPolyLR':
66 |                 self.scheduler.step()
67 |             # acc iou
68 |             score_shrink_map = cal_text_score(preds[:, 0, :, :], batch['shrink_map'], batch['shrink_mask'], running_metric_text,
69 |                                               thred=self.config['post_processing']['args']['thresh'])
70 | 
71 |             # log loss and acc
72 |             loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
73 |             for idx, (key, value) in enumerate(loss_dict.items()):
74 |                 loss_dict[key] = value.item()
75 |                 if key == 'loss':
76 |                     continue
77 |                 loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
78 |                 if idx < len(loss_dict) - 1:
79 |                     loss_str += ', '
80 | 
81 |             train_loss += loss_dict['loss']
82 |             acc = score_shrink_map['Mean Acc']
83 |             iou_shrink_map = score_shrink_map['Mean IoU']
84 | 
85 |             if self.global_step % self.log_iter == 0:
86 |                 batch_time = time.time() - batch_start
87 |                 self.logger_info(
88 |                     '[{}/{}], [{}/{}], global_step: {}, speed: {:.1f} samples/sec, acc: {:.4f}, iou_shrink_map: {:.4f}, {}, lr:{:.6}, time:{:.2f}'.format(
89 |                         epoch, self.epochs, i + 1, self.train_loader_len, self.global_step, self.log_iter * cur_batch_size / batch_time, acc,
90 |                         iou_shrink_map, loss_str, lr, batch_time))
91 |                 batch_start = time.time()
92 | 
93 |             if self.tensorboard_enable and self.config['local_rank'] == 0:
94 |                 # write tensorboard
95 |                 for key, value in loss_dict.items():
96 |                     self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value, self.global_step)
97 |                 self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc, self.global_step)
98 |                 self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map', iou_shrink_map, self.global_step)
99 |                 self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
100 |                 if self.global_step % self.show_images_iter == 0:
101 |                     # show images on tensorboard
102 |                     self.inverse_normalize(batch['img'])
103 |                     self.writer.add_images('TRAIN/imgs', batch['img'], self.global_step)
104 |                     # shrink_labels and threshold_labels
105 |                     shrink_labels = batch['shrink_map']
106 |                     threshold_labels = batch['threshold_map']
107 |                     shrink_labels[shrink_labels <= 0.5] = 0
108 |                     shrink_labels[shrink_labels > 0.5] = 1
109 |                     show_label = torch.cat([shrink_labels, threshold_labels])
110 |                     show_label = vutils.make_grid(show_label.unsqueeze(1), nrow=cur_batch_size, normalize=False, padding=20, pad_value=1)
111 |                     self.writer.add_image('TRAIN/gt', show_label, self.global_step)
112 |                     # model output
113 |                     show_pred = []
114 |                     for kk in range(preds.shape[1]):
115 |                         show_pred.append(preds[:, kk, :, :])
116 |                     show_pred = torch.cat(show_pred)
117 |                     show_pred = vutils.make_grid(show_pred.unsqueeze(1), nrow=cur_batch_size, normalize=False, padding=20, pad_value=1)
118 |                     self.writer.add_image('TRAIN/preds', show_pred, self.global_step)
119 |         return {'train_loss': train_loss / self.train_loader_len, 'lr': lr, 'time': time.time() - epoch_start,
120 |                 'epoch': epoch}
121 | 
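`WarmupPolyLR` (from `utils/schedulers.py`, not shown in this section) is stepped per iteration above. Assuming the usual poly-with-warmup rule, the schedule looks like this sketch:

```python
def poly_lr(base_lr, cur_iter, max_iters, warmup_iters, power=0.9):
    # assumed form; see utils/schedulers.py for the actual implementation
    if cur_iter < warmup_iters:
        return base_lr * cur_iter / warmup_iters           # linear warmup
    frac = (cur_iter - warmup_iters) / (max_iters - warmup_iters)
    return base_lr * (1 - frac) ** power                   # polynomial decay to 0
```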
122 |     def _eval(self, epoch):
123 |         self.model.eval()
124 |         # torch.cuda.empty_cache()  # speed up evaluating after training finished
125 |         raw_metrics = []
126 |         total_frame = 0.0
127 |         total_time = 0.0
128 |         for i, batch in tqdm(enumerate(self.validate_loader), total=len(self.validate_loader), desc='test model'):
129 |             with torch.no_grad():
130 |                 # move every tensor in the batch to the device
131 |                 for key, value in batch.items():
132 |                     if value is not None:
133 |                         if isinstance(value, torch.Tensor):
134 |                             batch[key] = value.to(self.device)
135 |                 start = time.time()
136 |                 preds = self.model(batch['img'])
137 |                 boxes, scores = self.post_process(batch, preds, is_output_polygon=self.metric_cls.is_output_polygon)
138 |                 total_frame += batch['img'].size()[0]
139 |                 total_time += time.time() - start
140 |                 raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
141 |                 raw_metrics.append(raw_metric)
142 |         metrics = self.metric_cls.gather_measure(raw_metrics)
143 |         self.logger_info('FPS:{}'.format(total_frame / total_time))
144 |         return metrics['recall'].avg, metrics['precision'].avg, metrics['fmeasure'].avg
145 | 
146 |     def _on_epoch_finish(self):
147 |         self.logger_info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
148 |             self.epoch_result['epoch'], self.epochs, self.epoch_result['train_loss'], self.epoch_result['time'],
149 |             self.epoch_result['lr']))
150 |         net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)
151 |         net_save_path_best = '{}/model_best.pth'.format(self.checkpoint_dir)
152 | 
153 |         if self.config['local_rank'] == 0:
154 |             self._save_checkpoint(self.epoch_result['epoch'], net_save_path)
155 |             save_best = False
156 |             if self.validate_loader is not None and self.metric_cls is not None:  # use F1 (hmean) as the best-model metric
157 |                 recall, precision, hmean = self._eval(self.epoch_result['epoch'])
158 | 
159 |                 if self.tensorboard_enable:
160 |                     self.writer.add_scalar('EVAL/recall', recall, self.global_step)
161 |                     self.writer.add_scalar('EVAL/precision', precision, self.global_step)
162 |                     self.writer.add_scalar('EVAL/hmean', hmean, self.global_step)
163 |                 self.logger_info('test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.format(recall, precision, hmean))
164 | 
165 |                 if hmean >= self.metrics['hmean']:
166 |                     save_best = True
167 |                     self.metrics['train_loss'] = self.epoch_result['train_loss']
168 |                     self.metrics['hmean'] = hmean
169 |                     self.metrics['precision'] = precision
170 |                     self.metrics['recall'] = recall
171 |                     self.metrics['best_model_epoch'] = self.epoch_result['epoch']
172 |             else:
173 |                 if self.epoch_result['train_loss'] <= self.metrics['train_loss']:
174 |                     save_best = True
175 |                     self.metrics['train_loss'] = self.epoch_result['train_loss']
176 |                     self.metrics['best_model_epoch'] = self.epoch_result['epoch']
177 |             best_str = 'current best, '
178 |             for k, v in self.metrics.items():
179 |                 best_str += '{}: {:.6f}, '.format(k, v)
180 |             self.logger_info(best_str)
181 |             if save_best:
182 |                 import shutil
183 |                 shutil.copy(net_save_path, net_save_path_best)
184 |                 self.logger_info("Saving current best: {}".format(net_save_path_best))
185 |             else:
186 |                 self.logger_info("Saving checkpoint: {}".format(net_save_path))
187 | 
188 | 
189 |     def _on_train_finish(self):
190 |         for k, v in self.metrics.items():
191 |             self.logger_info('{}:{}'.format(k, v))
192 |         self.logger_info('finish train')
193 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:58
3 | # @Author : zhoujun
4 | from .util import *
5 | from .metrics import *
6 | from .schedulers import *
7 | from .cal_recall.script import cal_recall_precison_f1
8 | from .ocr_metric import get_metric
--------------------------------------------------------------------------------
/utils/cal_recall/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/16/19 6:40 AM
3 | # @Author : zhoujun
4 | from .script import cal_recall_precison_f1
5 | __all__ = ['cal_recall_precison_f1']
--------------------------------------------------------------------------------
/utils/compute_mean_std.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/12/7 14:46
3 | # @Author : zhoujun
4 | 
5 | import numpy as np
6 | import cv2
7 | import os
8 | import random
9 | from tqdm import tqdm
10 | # calculate means and std
11 | train_txt_path = './train_val_list.txt'
12 | 
13 | CNum = 10000  # how many images to sample for the statistics
14 | 
15 | img_h, img_w = 640, 640
16 | imgs = np.zeros([img_w, img_h, 3, 1])
17 | means, stdevs = [], []
18 | 
19 | with open(train_txt_path, 'r') as f:
20 |     lines = f.readlines()
21 |     random.shuffle(lines)  # shuffle to pick images at random
22 | 
23 |     for i in tqdm(range(CNum)):
24 |         img_path = lines[i].split('\t')[0]
25 | 
26 |         img = cv2.imread(img_path)
27 |         img = cv2.resize(img, (img_h, img_w))
28 |         img = img[:, :, :, np.newaxis]
29 | 
30 |         imgs = np.concatenate((imgs, img), axis=3)
31 |         # print(i)
32 | 
33 | imgs = imgs.astype(np.float32) / 255.
34 | 
35 | for i in tqdm(range(3)):
36 |     pixels = imgs[:, :, i, :].ravel()  # flatten to one row
37 |     means.append(np.mean(pixels))
38 |     stdevs.append(np.std(pixels))
39 | 
40 | # cv2 reads images as BGR; PIL/Skimage read RGB, so those need no conversion
41 | means.reverse()  # BGR --> RGB
42 | stdevs.reverse()
43 | 
44 | print("normMean = {}".format(means))
45 | print("normStd = {}".format(stdevs))
46 | print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
--------------------------------------------------------------------------------
/utils/make_trainfile.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/24 12:06
3 | # @Author : zhoujun
4 | import os
5 | import glob
6 | import pathlib
7 | 
8 | data_path = r'test'
9 | # data_path/img holds the images
10 | # data_path/gt holds the label files
11 | 
12 | f_w = open(os.path.join(data_path, 'test.txt'), 'w', encoding='utf8')
13 | for img_path in glob.glob(data_path + '/img/*.jpg', recursive=True):
14 |     d = pathlib.Path(img_path)
15 |     label_path = os.path.join(data_path, 'gt', ('gt_' + str(d.stem) + '.txt'))
16 |     if os.path.exists(img_path) and os.path.exists(label_path):
17 |         print(img_path, label_path)
18 |     else:
19 |         print('missing', img_path, label_path)
20 |     f_w.write('{}\t{}\n'.format(img_path, label_path))
21 | f_w.close()
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Adapted from score written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 | 
4 | import numpy as np
5 | 
6 | 
7 | class runningScore(object):
8 | 
9 |     def __init__(self, n_classes):
10 |         self.n_classes = n_classes
11 |         self.confusion_matrix = np.zeros((n_classes, n_classes))
12 | 
13 |     def _fast_hist(self, label_true, label_pred, n_class):
14 |         mask = (label_true >= 0) & (label_true < n_class)
15 | 
16 |         if 
np.sum((label_pred[mask] < 0)) > 0: 17 | print(label_pred[label_pred < 0]) 18 | hist = np.bincount(n_class * label_true[mask].astype(int) + 19 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) 20 | return hist 21 | 22 | def update(self, label_trues, label_preds): 23 | # print label_trues.dtype, label_preds.dtype 24 | for lt, lp in zip(label_trues, label_preds): 25 | try: 26 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 27 | except: 28 | pass 29 | 30 | def get_scores(self): 31 | """Returns accuracy score evaluation result. 32 | - overall accuracy 33 | - mean accuracy 34 | - mean IU 35 | - fwavacc 36 | """ 37 | hist = self.confusion_matrix 38 | acc = np.diag(hist).sum() / (hist.sum() + 0.0001) 39 | acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001) 40 | acc_cls = np.nanmean(acc_cls) 41 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001) 42 | mean_iu = np.nanmean(iu) 43 | freq = hist.sum(axis=1) / (hist.sum() + 0.0001) 44 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 45 | cls_iu = dict(zip(range(self.n_classes), iu)) 46 | 47 | return {'Overall Acc': acc, 48 | 'Mean Acc': acc_cls, 49 | 'FreqW Acc': fwavacc, 50 | 'Mean IoU': mean_iu, }, cls_iu 51 | 52 | def reset(self): 53 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 54 | -------------------------------------------------------------------------------- /utils/ocr_metric/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/12/5 15:36 3 | # @Author : zhoujun 4 | from .icdar2015 import QuadMetric 5 | 6 | 7 | def get_metric(config): 8 | try: 9 | if 'args' not in config: 10 | args = {} 11 | else: 12 | args = config['args'] 13 | if isinstance(args, dict): 14 | cls = eval(config['type'])(**args) 15 | else: 16 | cls = eval(config['type'])(args) 17 | return cls 18 | except: 19 | return None -------------------------------------------------------------------------------- /utils/ocr_metric/icdar2015/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/12/5 15:36 3 | # @Author : zhoujun 4 | 5 | from .quad_metric import QuadMetric -------------------------------------------------------------------------------- /utils/ocr_metric/icdar2015/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/DBNet.pytorch/e03acf0e6b3b62f7d1dc7e10a6d2587456ac9ea1/utils/ocr_metric/icdar2015/detection/__init__.py -------------------------------------------------------------------------------- /utils/ocr_metric/icdar2015/detection/iou.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from collections import namedtuple 4 | import numpy as np 5 | from shapely.geometry import Polygon 6 | import cv2 7 | 8 | 9 | def iou_rotate(box_a, box_b, method='union'): 10 | rect_a = cv2.minAreaRect(box_a) 11 | rect_b = cv2.minAreaRect(box_b) 12 | r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b) 13 | if r1[0] == 0: 14 | return 0 15 | else: 16 | inter_area = cv2.contourArea(r1[1]) 17 | area_a = cv2.contourArea(box_a) 18 | area_b = cv2.contourArea(box_b) 19 | union_area = area_a + area_b - inter_area 20 | if union_area == 0 or inter_area == 0: 21 | return 0 22 | if method == 'union': 23 | iou = inter_area / 
-------------------------------------------------------------------------------- /utils/ocr_metric/icdar2015/detection/iou.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from collections import namedtuple 4 | import numpy as np 5 | from shapely.geometry import Polygon 6 | import cv2 7 | 8 | 9 | def iou_rotate(box_a, box_b, method='union'): 10 | rect_a = cv2.minAreaRect(box_a) 11 | rect_b = cv2.minAreaRect(box_b) 12 | r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b) 13 | if r1[0] == 0: 14 | return 0 15 | else: 16 | inter_area = cv2.contourArea(r1[1]) 17 | area_a = cv2.contourArea(box_a) 18 | area_b = cv2.contourArea(box_b) 19 | union_area = area_a + area_b - inter_area 20 | if union_area == 0 or inter_area == 0: 21 | return 0 22 | if method == 'union': 23 | iou = inter_area / union_area 24 | elif method == 'intersection': 25 | iou = inter_area / min(area_a, area_b) 26 | else: 27 | raise NotImplementedError 28 | return iou 29 | 30 | 31 | class DetectionIoUEvaluator(object): 32 | def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5): 33 | self.is_output_polygon = is_output_polygon 34 | self.iou_constraint = iou_constraint 35 | self.area_precision_constraint = area_precision_constraint 36 | 37 | def evaluate_image(self, gt, pred): 38 | 39 | def get_union(pD, pG): 40 | return Polygon(pD).union(Polygon(pG)).area 41 | 42 | def get_intersection_over_union(pD, pG): 43 | return get_intersection(pD, pG) / get_union(pD, pG) 44 | 45 | def get_intersection(pD, pG): 46 | return Polygon(pD).intersection(Polygon(pG)).area 47 | 48 | def compute_ap(confList, matchList, numGtCare): 49 | correct = 0 50 | AP = 0 51 | if len(confList) > 0: 52 | confList = np.array(confList) 53 | matchList = np.array(matchList) 54 | sorted_ind = np.argsort(-confList) 55 | confList = confList[sorted_ind] 56 | matchList = matchList[sorted_ind] 57 | for n in range(len(confList)): 58 | match = matchList[n] 59 | if match: 60 | correct += 1 61 | AP += float(correct) / (n + 1) 62 | 63 | if numGtCare > 0: 64 | AP /= numGtCare 65 | 66 | return AP 67 | 68 | perSampleMetrics = {} 69 | 70 | matchedSum = 0 71 | 72 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') 73 | 74 | numGlobalCareGt = 0 75 | numGlobalCareDet = 0 76 | 77 | arrGlobalConfidences = [] 78 | arrGlobalMatches = [] 79 | 80 | recall = 0 81 | precision = 0 82 | hmean = 0 83 | 84 | detMatched = 0 85 | 86 | iouMat = np.empty([1, 1]) 87 | 88 | gtPols = [] 89 | detPols = [] 90 | 91 | gtPolPoints = [] 92 | detPolPoints = [] 93 | 94 | # Array of Ground Truth Polygons' keys marked as don't Care 95 | gtDontCarePolsNum = [] 96 | # Array of Detected Polygons' matched with a don't Care GT 97 | detDontCarePolsNum = [] 98 | 99 | pairs = [] 100 | detMatchedNums = [] 101 | 102 | arrSampleConfidences = [] 103 | arrSampleMatch = [] 104 | 105 | evaluationLog = "" 106 | 107 | for n in range(len(gt)): 108 | points = gt[n]['points'] 109 | # transcription = gt[n]['text'] 110 | dontCare = gt[n]['ignore'] 111 | 112 | if not Polygon(points).is_valid or not Polygon(points).is_simple: 113 | continue 114 | 115 | gtPol = points 116 | gtPols.append(gtPol) 117 | gtPolPoints.append(points) 118 | if dontCare: 119 | gtDontCarePolsNum.append(len(gtPols) - 1) 120 | 121 | evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len( 122 | gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n") 123 | 124 | for n in range(len(pred)): 125 | points = pred[n]['points'] 126 | if not Polygon(points).is_valid or not Polygon(points).is_simple: 127 | continue 128 | 129 | detPol = points 130 | detPols.append(detPol) 131 | detPolPoints.append(points) 132 | if len(gtDontCarePolsNum) > 0: 133 | for dontCarePol in gtDontCarePolsNum: 134 | dontCarePol = gtPols[dontCarePol] 135 | intersected_area = get_intersection(dontCarePol, detPol) 136 | pdDimensions = Polygon(detPol).area 137 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions 138 | if (precision > self.area_precision_constraint): 139 | detDontCarePolsNum.append(len(detPols) - 1) 140 | break 141 | 142 | evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len( 143 | detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n") 144 |
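Before the matching step below, every GT/detection pair is assigned an IoU. For the polygon branch this is exactly the shapely computation in `get_intersection_over_union`; a standalone spot check with two made-up axis-aligned quads:

```python
from shapely.geometry import Polygon

pG = [(0, 0), (2, 0), (2, 2), (0, 2)]  # ground-truth quad
pD = [(1, 0), (3, 0), (3, 2), (1, 2)]  # detection, shifted right by 1
inter = Polygon(pD).intersection(Polygon(pG)).area  # 2.0
union = Polygon(pD).union(Polygon(pG)).area         # 6.0
print(inter / union)  # 0.333... -> below the default iou_constraint of 0.5, so no match
```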
145 | if len(gtPols) > 0 and len(detPols) > 0: 146 | # Calculate IoU and precision matrices 147 | outputShape = [len(gtPols), len(detPols)] 148 | iouMat = np.empty(outputShape) 149 | gtRectMat = np.zeros(len(gtPols), np.int8) 150 | detRectMat = np.zeros(len(detPols), np.int8) 151 | if self.is_output_polygon: 152 | for gtNum in range(len(gtPols)): 153 | for detNum in range(len(detPols)): 154 | pG = gtPols[gtNum] 155 | pD = detPols[detNum] 156 | iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) 157 | else: 158 | # gtPols = np.float32(gtPols) 159 | # detPols = np.float32(detPols) 160 | for gtNum in range(len(gtPols)): 161 | for detNum in range(len(detPols)): 162 | pG = np.float32(gtPols[gtNum]) 163 | pD = np.float32(detPols[detNum]) 164 | iouMat[gtNum, detNum] = iou_rotate(pD, pG) 165 | for gtNum in range(len(gtPols)): 166 | for detNum in range(len(detPols)): 167 | if gtRectMat[gtNum] == 0 and detRectMat[ 168 | detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: 169 | if iouMat[gtNum, detNum] > self.iou_constraint: 170 | gtRectMat[gtNum] = 1 171 | detRectMat[detNum] = 1 172 | detMatched += 1 173 | pairs.append({'gt': gtNum, 'det': detNum}) 174 | detMatchedNums.append(detNum) 175 | evaluationLog += "Match GT #" + \ 176 | str(gtNum) + " with Det #" + str(detNum) + "\n" 177 | 178 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) 179 | numDetCare = (len(detPols) - len(detDontCarePolsNum)) 180 | if numGtCare == 0: 181 | recall = float(1) 182 | precision = float(0) if numDetCare > 0 else float(1) 183 | else: 184 | recall = float(detMatched) / numGtCare 185 | precision = 0 if numDetCare == 0 else float( 186 | detMatched) / numDetCare 187 | 188 | hmean = 0 if (precision + recall) == 0 else 2.0 * \ 189 | precision * recall / (precision + recall) 190 | 191 | matchedSum += detMatched 192 | numGlobalCareGt += numGtCare 193 | numGlobalCareDet += numDetCare 194 | 195 | perSampleMetrics = { 196 | 'precision': precision, 197 | 'recall': recall, 198 | 'hmean': hmean, 199 | 'pairs': pairs, 200 | 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), 201 | 'gtPolPoints': gtPolPoints, 202 | 'detPolPoints': detPolPoints, 203 | 'gtCare': numGtCare, 204 | 'detCare': numDetCare, 205 | 'gtDontCare': gtDontCarePolsNum, 206 | 'detDontCare': detDontCarePolsNum, 207 | 'detMatched': detMatched, 208 | 'evaluationLog': evaluationLog 209 | } 210 | 211 | return perSampleMetrics 212 | 213 | def combine_results(self, results): 214 | numGlobalCareGt = 0 215 | numGlobalCareDet = 0 216 | matchedSum = 0 217 | for result in results: 218 | numGlobalCareGt += result['gtCare'] 219 | numGlobalCareDet += result['detCare'] 220 | matchedSum += result['detMatched'] 221 | 222 | methodRecall = 0 if numGlobalCareGt == 0 else float( 223 | matchedSum) / numGlobalCareGt 224 | methodPrecision = 0 if numGlobalCareDet == 0 else float( 225 | matchedSum) / numGlobalCareDet 226 | methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \ 227 | methodRecall * methodPrecision / ( 228 | methodRecall + methodPrecision) 229 | 230 | methodMetrics = {'precision': methodPrecision, 231 | 'recall': methodRecall, 'hmean': methodHmean} 232 | 233 | return methodMetrics 234 | 235 | 236 | if __name__ == '__main__': 237 | evaluator = DetectionIoUEvaluator() 238 | preds = [[{ 239 | 'points': [(0.1, 0.1), (0.5, 0), (0.5, 1), (0, 1)], 240 | 'text': 1234, 241 | 'ignore': False, 242 | }, { 243 | 'points': [(0.5, 0.1), (1, 0), (1, 1), (0.5, 1)], 244 | 'text': 5678, 245 | 'ignore': False, 246 | }]] 247 | gts = [[{ 248 | 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], 249 | 'text': 123, 250 | 'ignore': False, 251 | }]]
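For intuition, the per-sample metrics computed in `evaluate_image` reduce to the following arithmetic (all counts made up):

```python
# Suppose 8 "care" GT boxes, 10 "care" detections, 6 matches over the IoU threshold.
detMatched, numGtCare, numDetCare = 6, 8, 10
recall = detMatched / numGtCare                        # 0.75
precision = detMatched / numDetCare                    # 0.6
hmean = 2 * precision * recall / (precision + recall)  # ~0.667
print(recall, precision, hmean)
```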
252 | results = [] 253 | for gt, pred in zip(gts, preds): 254 | results.append(evaluator.evaluate_image(gt, pred)) 255 | metrics = evaluator.combine_results(results) 256 | print(metrics) 257 | -------------------------------------------------------------------------------- /utils/ocr_metric/icdar2015/quad_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .detection.iou import DetectionIoUEvaluator 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | return self 24 | 25 | 26 | class QuadMetric(): 27 | def __init__(self, is_output_polygon=False): 28 | self.is_output_polygon = is_output_polygon 29 | self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon) 30 | 31 | def measure(self, batch, output, box_thresh=0.6): 32 | ''' 33 | batch: a dict produced by the dataloader. 34 | It contains the following keys: 35 | image: tensor of shape (N, C, H, W). 36 | polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. 37 | ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. 38 | shape: the original shape of images. 39 | filename: the original filenames of images. 40 | output: (polygons, scores) from the post-processing step. 41 | ''' 42 | results = [] 43 | gt_polygons_batch = batch['text_polys'] 44 | ignore_tags_batch = batch['ignore_tags'] 45 | pred_polygons_batch = np.array(output[0]) 46 | pred_scores_batch = np.array(output[1]) 47 | for polygons, pred_polygons, pred_scores, ignore_tags in zip(gt_polygons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch): 48 | gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i]) for i in range(len(polygons))] 49 | if self.is_output_polygon: 50 | pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))] 51 | else: 52 | pred = [] 53 | # print(pred_polygons.shape) 54 | for i in range(pred_polygons.shape[0]): 55 | if pred_scores[i] >= box_thresh: 56 | # print(pred_polygons[i,:,:].tolist()) 57 | pred.append(dict(points=pred_polygons[i, :, :].astype(np.int64))) 58 | # pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])] 59 | results.append(self.evaluator.evaluate_image(gt, pred)) 60 | return results 61 | 62 | def validate_measure(self, batch, output, box_thresh=0.6): 63 | return self.measure(batch, output, box_thresh) 64 | 65 | def evaluate_measure(self, batch, output): 66 | return self.measure(batch, output), np.linspace(0, batch['image'].shape[0]).tolist() 67 | 68 | def gather_measure(self, raw_metrics): 69 | raw_metrics = [image_metrics 70 | for batch_metrics in raw_metrics 71 | for image_metrics in batch_metrics] 72 | 73 | result = self.evaluator.combine_results(raw_metrics) 74 | 75 | precision = AverageMeter() 76 | recall = AverageMeter() 77 | fmeasure = AverageMeter() 78 | 79 | precision.update(result['precision'], n=len(raw_metrics)) 80 | recall.update(result['recall'], n=len(raw_metrics)) 81 | fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8) 82 | fmeasure.update(fmeasure_score) 83 | 84 | return { 85 | 'precision': precision, 86 | 'recall': recall, 87 | 'fmeasure': fmeasure 88 | } 89 |
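A quick sanity check of `AverageMeter`'s weighted running average (values made up):

```python
# from utils.ocr_metric.icdar2015.quad_metric import AverageMeter  # assuming the repo layout
m = AverageMeter()
m.update(0.5, n=2)   # e.g. precision 0.5 measured over 2 images
m.update(1.0, n=1)   # then 1.0 over a single image
print(m.avg)         # (0.5 * 2 + 1.0 * 1) / 3 ~ 0.667
```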
-------------------------------------------------------------------------------- /utils/schedulers.py: -------------------------------------------------------------------------------- 1 | """Popular Learning Rate Schedulers""" 2 | from __future__ import division 3 | 4 | import math 5 | from bisect import bisect_right 6 | 7 | import torch 8 | 9 | __all__ = ['LRScheduler', 'WarmupMultiStepLR', 'WarmupPolyLR'] 10 | 11 | 12 | class LRScheduler(object): 13 | r"""Learning Rate Scheduler 14 | Parameters 15 | ---------- 16 | mode : str 17 | Modes for learning rate scheduler. 18 | Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'. 19 | base_lr : float 20 | Base learning rate, i.e. the starting learning rate. 21 | target_lr : float 22 | Target learning rate, i.e. the ending learning rate. 23 | With constant mode target_lr is ignored. 24 | niters : int 25 | Number of iterations to be scheduled. 26 | nepochs : int 27 | Number of epochs to be scheduled. 28 | iters_per_epoch : int 29 | Number of iterations in each epoch. 30 | offset : int 31 | Number of iterations before this scheduler. 32 | power : float 33 | Power parameter of poly scheduler. 34 | step_iter : list 35 | A list of iterations to decay the learning rate. 36 | step_epoch : list 37 | A list of epochs to decay the learning rate. 38 | step_factor : float 39 | Learning rate decay factor. 40 | """ 41 | 42 | def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0, 43 | offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0): 44 | super(LRScheduler, self).__init__() 45 | assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine']) 46 | 47 | if mode == 'step': 48 | assert (step_iter is not None or step_epoch is not None) 49 | self.niters = niters 50 | self.step = step_iter 51 | epoch_iters = nepochs * iters_per_epoch 52 | if epoch_iters > 0: 53 | self.niters = epoch_iters 54 | if step_epoch is not None: 55 | self.step = [s * iters_per_epoch for s in step_epoch] 56 | 57 | self.step_factor = step_factor 58 | self.base_lr = base_lr 59 | self.target_lr = base_lr if mode == 'constant' else target_lr 60 | self.offset = offset 61 | self.power = power 62 | self.warmup_iters = warmup_epochs * iters_per_epoch 63 | self.mode = mode 64 | 65 | def __call__(self, optimizer, num_update): 66 | self.update(num_update) 67 | assert self.learning_rate >= 0 68 | self._adjust_learning_rate(optimizer, self.learning_rate) 69 | 70 | def update(self, num_update): 71 | N = self.niters - 1 72 | T = num_update - self.offset 73 | T = min(max(0, T), N) 74 | 75 | if self.mode == 'constant': 76 | factor = 0 77 | elif self.mode == 'linear': 78 | factor = 1 - T / N 79 | elif self.mode == 'poly': 80 | factor = pow(1 - T / N, self.power) 81 | elif self.mode == 'cosine': 82 | factor = (1 + math.cos(math.pi * T / N)) / 2 83 | elif self.mode == 'step': 84 | if self.step is not None: 85 | count = sum([1 for s in self.step if s <= T]) 86 | factor = pow(self.step_factor, count) 87 | else: 88 | factor = 1 89 | else: 90 | raise NotImplementedError 91 | 92 | # warm up lr schedule 93 | if self.warmup_iters > 0 and T < self.warmup_iters: 94 | factor = factor * 1.0 * T / self.warmup_iters 95 | 96 | if self.mode == 'step': 97 | self.learning_rate = self.base_lr * factor 98 | else: 99 | self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor 100 |
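The decay factors computed in `update` are easy to verify by hand; this standalone sketch recomputes them 25 steps into a 100-step schedule (warmup and the `N = niters - 1` offset omitted for clarity):

```python
import math

N, T, power = 100, 25, 0.9
linear = 1 - T / N                            # 0.75
poly = (1 - T / N) ** power                   # ~0.772
cosine = (1 + math.cos(math.pi * T / N)) / 2  # ~0.854
# the learning rate is then: target_lr + (base_lr - target_lr) * factor
print(linear, poly, cosine)
```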
101 | def _adjust_learning_rate(self, optimizer, lr): 102 | optimizer.param_groups[0]['lr'] = lr 103 | # enlarge the lr at the head 104 | for i in range(1, len(optimizer.param_groups)): 105 | optimizer.param_groups[i]['lr'] = lr * 10 106 | 107 | 108 | # ideally MultiStepLR would be separated from WarmupLR, 109 | # but the current LRScheduler design doesn't allow it 110 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py 111 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 112 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, 113 | warmup_iters=500, warmup_method="linear", last_epoch=-1, **kwargs): 114 | if not list(milestones) == sorted(milestones): 115 | raise ValueError( 116 | "Milestones should be a list of" " increasing integers. Got {}".format(milestones)) 117 | if warmup_method not in ("constant", "linear"): 118 | raise ValueError( 119 | "Only 'constant' or 'linear' warmup_method accepted got {}".format(warmup_method)) 120 | 121 | self.milestones = milestones 122 | self.gamma = gamma 123 | self.warmup_factor = warmup_factor 124 | self.warmup_iters = warmup_iters 125 | self.warmup_method = warmup_method 126 | # set the attributes before calling the base constructor, which invokes get_lr() 127 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 128 | 129 | def get_lr(self): 130 | warmup_factor = 1 131 | if self.last_epoch < self.warmup_iters: 132 | if self.warmup_method == 'constant': 133 | warmup_factor = self.warmup_factor 134 | elif self.warmup_method == 'linear': 135 | alpha = float(self.last_epoch) / self.warmup_iters 136 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 137 | return [base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 138 | for base_lr in self.base_lrs] 139 | 140 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler): 141 | def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3, 142 | warmup_iters=500, warmup_method='linear', last_epoch=-1, **kwargs): 143 | if warmup_method not in ("constant", "linear"): 144 | raise ValueError( 145 | "Only 'constant' or 'linear' warmup_method accepted " 146 | "got {}".format(warmup_method)) 147 | 148 | self.target_lr = target_lr 149 | self.max_iters = max_iters 150 | self.power = power 151 | self.warmup_factor = warmup_factor 152 | self.warmup_iters = warmup_iters 153 | self.warmup_method = warmup_method 154 | 155 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch) 156 | 157 | def get_lr(self): 158 | N = self.max_iters - self.warmup_iters 159 | T = self.last_epoch - self.warmup_iters 160 | if self.last_epoch < self.warmup_iters: 161 | if self.warmup_method == 'constant': 162 | warmup_factor = self.warmup_factor 163 | elif self.warmup_method == 'linear': 164 | alpha = float(self.last_epoch) / self.warmup_iters 165 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 166 | else: 167 | raise ValueError("Unknown warmup type.") 168 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs] 169 | factor = pow(1 - T / N, self.power) 170 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs] 171 | 172 | 173 | if __name__ == '__main__': 174 | import torch 175 | from torchvision.models import resnet18 176 | 177 | max_iter = 600 * 63 178 | model = resnet18() 179 | op = torch.optim.SGD(model.parameters(), 1e-3) 180 | sc = WarmupPolyLR(op, max_iters=max_iter, power=0.9, warmup_iters=3 * 63, warmup_method='constant') 181 | lr = [] 182 | for i in range(max_iter): 183 | sc.step() 184 | print(i, sc.last_epoch, sc.get_lr()[0]) 185 | lr.append(sc.get_lr()[0])
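`WarmupPolyLR` can be spot-checked by hand against its two `get_lr` branches; all numbers below are made up: `base_lr=0.01`, `target_lr=0`, `max_iters=1000`, `warmup_iters=100`, `warmup_method='linear'`, default `warmup_factor=1/3`:

```python
base_lr, warmup_factor, power = 0.01, 1 / 3, 0.9
max_iters, warmup_iters = 1000, 100

# warmup branch, e.g. at step 50: alpha = 50 / 100
alpha = 0.5
lr_warmup = base_lr * (warmup_factor * (1 - alpha) + alpha)  # ~0.00667

# poly branch, e.g. at step 550: T = 550 - 100, N = 1000 - 100
lr_poly = base_lr * (1 - 450 / 900) ** power                 # ~0.00536
print(lr_warmup, lr_poly)
```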
186 | from matplotlib import pyplot as plt 187 | 188 | plt.plot(list(range(max_iter)), lr) 189 | plt.show() 190 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:59 3 | # @Author : zhoujun 4 | import json 5 | import pathlib 6 | import time 7 | import os 8 | import glob 9 | from natsort import natsorted 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | 15 | def get_file_list(folder_path: str, p_postfix: list = None, sub_dir: bool = True) -> list: 16 | """ 17 | Collect the files with the given suffixes from a directory; os.walk and os.listdir are used for listing because they are currently much faster than pathlib 18 | :param folder_path: the directory to search 19 | :param p_postfix: list of file suffixes; ['.*'] returns every file 20 | :param sub_dir: whether to search subdirectories 21 | :return: the list of matched files 22 | """ 23 | assert os.path.exists(folder_path) and os.path.isdir(folder_path) 24 | if p_postfix is None: 25 | p_postfix = ['.jpg'] 26 | if isinstance(p_postfix, str): 27 | p_postfix = [p_postfix] 28 | file_list = [x for x in glob.glob(folder_path + '/**/*.*', recursive=True) if 29 | os.path.splitext(x)[-1] in p_postfix or '.*' in p_postfix] 30 | return natsorted(file_list) 31 | 32 | 33 | def setup_logger(log_file_path: str = None): 34 | import logging 35 | logging._warn_preinit_stderr = 0 36 | logger = logging.getLogger('DBNet.pytorch') 37 | formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s') 38 | ch = logging.StreamHandler() 39 | ch.setFormatter(formatter) 40 | logger.addHandler(ch) 41 | if log_file_path is not None: 42 | file_handle = logging.FileHandler(log_file_path) 43 | file_handle.setFormatter(formatter) 44 | logger.addHandler(file_handle) 45 | logger.setLevel(logging.DEBUG) 46 | return logger 47 | 48 | 49 | # --exeTime 50 | def exe_time(func): 51 | def newFunc(*args, **args2): 52 | t0 = time.time() 53 | back = func(*args, **args2) 54 | print("{} cost {:.3f}s".format(func.__name__, time.time() - t0)) 55 | return back 56 | 57 | return newFunc 58 | 59 | 60 | def load(file_path: str): 61 | file_path = pathlib.Path(file_path) 62 | func_dict = {'.txt': _load_txt, '.json': _load_json, '.list': _load_txt} 63 | assert file_path.suffix in func_dict 64 | return func_dict[file_path.suffix](file_path) 65 | 66 | 67 | def _load_txt(file_path: str): 68 | with open(file_path, 'r', encoding='utf8') as f: 69 | content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()] 70 | return content 71 | 72 | 73 | def _load_json(file_path: str): 74 | with open(file_path, 'r', encoding='utf8') as f: 75 | content = json.load(f) 76 | return content 77 | 78 | 79 | def save(data, file_path): 80 | file_path = pathlib.Path(file_path) 81 | func_dict = {'.txt': _save_txt, '.json': _save_json} 82 | assert file_path.suffix in func_dict 83 | return func_dict[file_path.suffix](data, file_path) 84 | 85 | 86 | def _save_txt(data, file_path): 87 | """ 88 | Write a list of strings into a txt file, one item per line 89 | :param data: 90 | :param file_path: 91 | :return: 92 | """ 93 | if not isinstance(data, list): 94 | data = [data] 95 | with open(file_path, mode='w', encoding='utf8') as f: 96 | f.write('\n'.join(data)) 97 | 98 | 99 | def _save_json(data, file_path): 100 | with open(file_path, 'w', encoding='utf-8') as json_file: 101 | json.dump(data, json_file, ensure_ascii=False, indent=4) 102 | 103 |
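A minimal usage sketch for the two helpers above, assuming the repo's `datasets/train/img` layout and imports from `utils.util`:

```python
# from utils.util import setup_logger, get_file_list  # assuming the repo layout
logger = setup_logger()  # console only; pass a file path to also log to disk
for path in get_file_list('./datasets/train/img', p_postfix=['.jpg', '.png']):
    logger.debug(path)
```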
104 | def show_img(imgs: np.ndarray, title='img'): 105 | color = (len(imgs.shape) == 3 and imgs.shape[-1] == 3) 106 | imgs = np.expand_dims(imgs, axis=0) 107 | for i, img in enumerate(imgs): 108 | plt.figure() 109 | plt.title('{}_{}'.format(title, i)) 110 | plt.imshow(img, cmap=None if color else 'gray') 111 | plt.show() 112 | 113 | 114 | def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2): 115 | if isinstance(img_path, str): 116 | img_path = cv2.imread(img_path) 117 | # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB) 118 | img_path = img_path.copy() 119 | for point in result: 120 | point = point.astype(int) 121 | cv2.polylines(img_path, [point], True, color, thickness) 122 | return img_path 123 | 124 | 125 | def cal_text_score(texts, gt_texts, training_masks, running_metric_text, thred=0.5): 126 | training_masks = training_masks.data.cpu().numpy() 127 | pred_text = texts.data.cpu().numpy() * training_masks 128 | pred_text[pred_text <= thred] = 0 129 | pred_text[pred_text > thred] = 1 130 | pred_text = pred_text.astype(np.int32) 131 | gt_text = gt_texts.data.cpu().numpy() * training_masks 132 | gt_text = gt_text.astype(np.int32) 133 | running_metric_text.update(gt_text, pred_text) 134 | score_text, _ = running_metric_text.get_scores() 135 | return score_text 136 | 137 | 138 | def order_points_clockwise(pts): 139 | rect = np.zeros((4, 2), dtype="float32") 140 | s = pts.sum(axis=1) 141 | rect[0] = pts[np.argmin(s)] 142 | rect[2] = pts[np.argmax(s)] 143 | diff = np.diff(pts, axis=1) 144 | rect[1] = pts[np.argmin(diff)] 145 | rect[3] = pts[np.argmax(diff)] 146 | return rect 147 | 148 | 149 | def order_points_clockwise_list(pts): 150 | pts = pts.tolist() 151 | pts.sort(key=lambda x: (x[1], x[0])) 152 | pts[:2] = sorted(pts[:2], key=lambda x: x[0]) 153 | pts[2:] = sorted(pts[2:], key=lambda x: -x[0]) 154 | pts = np.array(pts) 155 | return pts 156 | 157 | 158 | def get_datalist(train_data_path): 159 | """ 160 | Build the data list for training and validation 161 | :param train_data_path: list of dataset files; inside each file every line is stored as 'path/to/img\tlabel' 162 | :return: list of existing, non-empty (img_path, label_path) pairs 163 | """ 164 | train_data = [] 165 | for p in train_data_path: 166 | with open(p, 'r', encoding='utf-8') as f: 167 | for line in f.readlines(): 168 | line = line.strip('\n').replace('.jpg ', '.jpg\t').split('\t') 169 | if len(line) > 1: 170 | img_path = pathlib.Path(line[0].strip(' ')) 171 | label_path = pathlib.Path(line[1].strip(' ')) 172 | if img_path.exists() and img_path.stat().st_size > 0 and label_path.exists() and label_path.stat().st_size > 0: 173 | train_data.append((str(img_path), str(label_path))) 174 | return train_data 175 | 176 | 177 | def parse_config(config: dict) -> dict: 178 | import anyconfig 179 | base_file_list = config.pop('base') 180 | base_config = {} 181 | for base_file in base_file_list: 182 | tmp_config = anyconfig.load(open(base_file, 'rb')) 183 | if 'base' in tmp_config: 184 | tmp_config = parse_config(tmp_config) 185 | anyconfig.merge(tmp_config, base_config) 186 | base_config = tmp_config 187 | anyconfig.merge(base_config, config) 188 | return base_config 189 | 190 |
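`order_points_clockwise` can be spot-checked on a shuffled quad: the corner with the smallest x+y sum lands in the top-left slot, the largest in the bottom-right, and the y−x difference separates top-right from bottom-left:

```python
import numpy as np
# from utils.util import order_points_clockwise  # assuming the repo layout

pts = np.array([[10, 80], [10, 10], [80, 80], [80, 10]], dtype=np.float32)
print(order_points_clockwise(pts))
# [[10. 10.]  top-left
#  [80. 10.]  top-right
#  [80. 80.]  bottom-right
#  [10. 80.]] bottom-left
```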
191 | def save_result(result_path, box_list, score_list, is_output_polygon): 192 | if is_output_polygon: 193 | with open(result_path, 'wt') as res: 194 | for i, box in enumerate(box_list): 195 | box = box.reshape(-1).tolist() 196 | result = ",".join([str(int(x)) for x in box]) 197 | score = score_list[i] 198 | res.write(result + ',' + str(score) + "\n") 199 | else: 200 | with open(result_path, 'wt') as res: 201 | for i, box in enumerate(box_list): 202 | score = score_list[i] 203 | box = box.reshape(-1).tolist() 204 | result = ",".join([str(int(x)) for x in box]) 205 | res.write(result + ',' + str(score) + "\n") 206 | 207 | 208 | def expand_polygon(polygon): 209 | """ 210 | Widen a box that contains only a single character 211 | """ 212 | (x, y), (w, h), angle = cv2.minAreaRect(np.float32(polygon)) 213 | if angle < -45: 214 | w, h = h, w 215 | angle += 90 216 | new_w = w + h 217 | box = ((x, y), (new_w, h), angle) 218 | points = cv2.boxPoints(box) 219 | return order_points_clockwise(points) 220 | 221 | 222 | if __name__ == '__main__': 223 | img = np.zeros((1, 3, 640, 640)) 224 | show_img(img[0][0]) 225 | plt.show() 226 | --------------------------------------------------------------------------------
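To close, a hypothetical end-to-end call of `save_result` with two made-up quads (the output file name is invented); each written line has the form `x1,y1,x2,y2,x3,y3,x4,y4,score`, matching the groundtruth format described in the README:

```python
import numpy as np
# from utils.util import save_result  # assuming the repo layout

boxes = [np.array([[10, 10], [100, 10], [100, 40], [10, 40]]),
         np.array([[20, 60], [90, 60], [90, 95], [20, 95]])]
scores = [0.91, 0.87]
save_result('res_img_001.txt', boxes, scores, is_output_polygon=False)
```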