├── .gitignore ├── 0.jpg ├── LICENSE ├── README.md ├── base ├── __init__.py ├── base_dataset.py └── base_trainer.py ├── config ├── image_dataset.yaml ├── imagedataset_None_VGG_RNN_Attn.yaml ├── imagedataset_None_VGG_RNN_CTC.yaml └── lmdb.yaml ├── data_loader ├── __init__.py ├── dataset.py └── modules │ ├── Text_Image_Augmentation_python │ ├── __init__.py │ ├── augment.py │ ├── demo.py │ └── warp_mls.py │ ├── __init__.py │ ├── augment.py │ └── resize.py ├── modeling ├── __init__.py ├── backbone │ ├── MobileNetV3.py │ ├── __init__.py │ ├── feature_extraction.py │ ├── resnet.py │ └── resnet_torch.py ├── basic.py ├── head │ ├── Attn.py │ ├── CTC.py │ └── __init__.py ├── losses │ ├── AttnLoss.py │ ├── CTCLoss.py │ └── __init__.py ├── model.py ├── modules │ └── seg │ │ ├── __init__.py │ │ ├── resnet.py │ │ ├── resnet_fpn.py │ │ └── unet.py ├── neck │ ├── __init__.py │ └── sequence_modeling.py └── trans │ ├── TPS.py │ └── __init__.py ├── msyh.ttc ├── predict.py ├── requirements.txt ├── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── create_lmdb_dataset.py ├── gen_img.py ├── get_keys.py ├── label_utils.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | output/ 4 | venv/ 5 | *.pyc 6 | .DS_Store 7 | *.pth 8 | *.pt 9 | *.pyc 10 | *.pyo 11 | *.log 12 | *.tmp 13 | -------------------------------------------------------------------------------- /0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/crnn.pytorch/bf7a7c62376eee93943ca7c68e88e3d563c09aa8/0.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Convolutional Recurrent Neural Network 2 | ====================================== 3 | 4 | This software implements the Convolutional Recurrent Neural Network (CRNN) in pytorch. 
5 | The original software can be found at [crnn](https://github.com/bgshih/crnn) 6 | 7 | 8 | ## Requirements 9 | * pytorch 1.3+ 10 | * torchvision 0.4+ 11 | 12 | ## Data Preparation 13 | Prepare a text file in the following format 14 | ``` 15 | /path/to/img/img.jpg label 16 | ... 17 | ``` 18 | 19 | ## Performance 20 | Data link: [baiduyun](https://pan.baidu.com/s/1w7KssjsOHbBTLtjaltLJ0w), extraction code: 9p2m. 21 | The dataset contains 100k generated images for training and 10k images for testing. 22 | For all architectures, training is stopped after 30 epochs. 23 | Environment: CUDA 9.2, torch 1.4, torchvision 0.5 24 | 25 | | arch | model size (M) | gpu mem (MB) | speed (ms, avg of 100 inferences) | acc | 26 | | ----------------------- | ------ | -------- | ------ | ------ | 27 | | CNN_lite_LSTM_CTC | 6.25 | 2731 | 6.91ms | 0.8866 | 28 | | VGG(BasicConv)_LSTM_CTC(w320) | 25.45 | 2409 | 4.02ms | 0.9874 | 29 | | VGG(BasicConv)_LSTM_CTC(w160) | 25.45 | 2409 | 4.02ms | 0.9908 | 30 | | VGG(BasicConv)_LSTM_CTC(w160_no_imagenet_mean_std) | 25.45 | 2409 | 4.02ms | 0.9927 | 31 | | VGG(BasicConv)_LSTM_CTC(w160.sub_(0.5).div_(0.5)) | 25.45 | 2409 | 4.02ms | 0.9927 | 32 | | VGG(BasicConv)_LSTM_CTC(w160 origin crnn rnn) | 25.45 | 2409 | 4.02ms | 0.9922 | 33 | | VGG(DWconv)_LSTM_CTC(w160_no_imagenet_mean_std) | 25.45 | 2409 | 4.01ms | 0.9725 | 34 | | VGG(GhostModule)_LSTM_CTC(w160_no_imagenet_mean_std) | 25.45 | 2329 | 5.46ms | 0.9878 | 35 | | ResNet(BasicBlockV2)_LSTM_CTC | 37.21 | 3161 | 5.83ms | 0.9935 | 36 | | ResNet(DWBlock_no_se)_LSTM_CTC | 19.22 | 5533 | 12ms | 0.9566 | 37 | | ResNet(DWBlock_se)_LSTM_CTC | 19.90 | 5729 | 10ms | 0.9559 | 38 | | ResNet(GhostBottleneck_se)_LSTM_CTC | 23.10 | 6291 | 13ms | 0.97 | 39 | 40 | 41 | ## Train 42 | 43 | 1. Configure `dataset['train']['dataset']['data_path']` and `dataset['validate']['dataset']['data_path']` in [config.yaml](config/icdar2015.yaml) 44 | 2. Generate the alphabet: 45 | use the following script to generate `alphabet.py` in the same folder as `train.py` 46 | ```sh 47 | python3 utils/get_keys.py 48 | ``` 49 | 3. Use the following script to start training 50 | ```sh 51 | python3 train.py --config_path config.yaml 52 | ``` 53 | 54 | ## Predict 55 | [predict.py](predict.py) is used to run inference on a single image 56 | 57 | 1. Configure `model_path` and `img_path` in [predict.py](predict.py) 58 | 2.
use following script to predict 59 | ```sh 60 | python3 predict.py 61 | ``` -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | from .base_dataset import BaseDataSet -------------------------------------------------------------------------------- /base/base_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/11/6 15:08 3 | # @Author : zhoujun 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from data_loader.modules import * 7 | 8 | 9 | class BaseDataSet(Dataset): 10 | def __init__(self, data_path: str, img_mode, num_label, ignore_chinese_punctuation, remove_blank, pre_processes, transform=None, **kwargs): 11 | """ 12 | :param ignore_chinese_punctuation: 是否转换全角为半角 13 | """ 14 | assert img_mode in ['RGB', 'BRG', 'GRAY'] 15 | self.img_mode = img_mode 16 | self.num_label = num_label 17 | self.transform = transform 18 | self.remove_blank = remove_blank 19 | self.ignore_chinese_punctuation = ignore_chinese_punctuation 20 | self.data_list = self.load_data(data_path) 21 | self._init_pre_processes(pre_processes) 22 | 23 | def _init_pre_processes(self, pre_processes): 24 | self.aug = [] 25 | if pre_processes is not None: 26 | for aug in pre_processes: 27 | if 'args' not in aug: 28 | args = {} 29 | else: 30 | args = aug['args'] 31 | if isinstance(args, dict): 32 | cls = eval(aug['type'])(**args) 33 | else: 34 | cls = eval(aug['type'])(args) 35 | self.aug.append(cls) 36 | 37 | def load_data(self, data_path: str) -> list: 38 | """ 39 | 把数据加载为一个list: 40 | :params data_path: 存储数据的文件夹或者文件 41 | return a list ,包含img_path和label 42 | """ 43 | raise NotImplementedError 44 | 45 | def apply_pre_processes(self, data): 46 | for aug in self.aug: 47 | data = aug(data) 48 | return data 49 | 50 | def get_sample(self, index): 51 | raise NotImplementedError 52 | 53 | def __getitem__(self, index): 54 | data = self.get_sample(index) 55 | data['img'] = self.apply_pre_processes(data['img']) 56 | if self.transform is not None: 57 | data['img'] = Image.fromarray(data['img']) 58 | data['img'] = self.transform(data['img']) 59 | # img.sub_(0.5).div_(0.5) 60 | return data 61 | 62 | def __len__(self): 63 | return len(self.data_list) 64 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/8/23 22:20 3 | # @Author : zhoujun 4 | 5 | import os 6 | import shutil 7 | import pathlib 8 | from pprint import pformat 9 | import traceback 10 | import torch 11 | from utils import setup_logger 12 | 13 | 14 | class BaseTrainer: 15 | def __init__(self, config, model, criterion, sample_input): 16 | # init ckpt path 17 | config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent), 18 | config['trainer']['output_dir']) 19 | config['name'] = config['name'] + '_' + model.name 20 | self.save_dir = os.path.join(config['trainer']['output_dir'], config['name']) 21 | self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') 22 | 23 | if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '': 24 | shutil.rmtree(self.save_dir, ignore_errors=True) 25 | if not os.path.exists(self.checkpoint_dir): 26 | 
os.makedirs(self.checkpoint_dir) 27 | 28 | self.global_step = 0 29 | self.start_epoch = 0 30 | self.config = config 31 | self.model = model 32 | self.criterion = criterion 33 | # logger 34 | self.tensorboard_enable = self.config['trainer']['tensorboard'] 35 | self.epochs = self.config['trainer']['epochs'] 36 | self.log_iter = self.config['trainer']['log_iter'] 37 | 38 | self.logger = setup_logger(os.path.join(self.save_dir, 'train.log')) 39 | self.logger.info(pformat(self.config)) 40 | self.logger.info(self.model) 41 | 42 | # device set 43 | torch.manual_seed(self.config['trainer']['seed']) # 为CPU设置随机种子 44 | if len(self.config['trainer']['gpus']) > 0 and torch.cuda.is_available(): 45 | self.with_cuda = True 46 | torch.backends.cudnn.benchmark = True 47 | self.logger.info(f"train with pytorch {torch.__version__} gpu {self.config['trainer']['gpus']}") 48 | self.gpus = {i: item for i, item in enumerate(self.config['trainer']['gpus'])} 49 | self.device = torch.device("cuda:0") 50 | else: 51 | self.with_cuda = False 52 | self.logger.info(f'train with pytorch {torch.__version__} and cpu') 53 | self.device = torch.device("cpu") 54 | 55 | self.optimizer = self._initialize('optimizer', torch.optim, model.parameters()) 56 | 57 | # resume or finetune 58 | if self.config['trainer']['resume_checkpoint'] != '': 59 | self._laod_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True) 60 | elif self.config['trainer']['finetune_checkpoint'] != '': 61 | self._laod_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False) 62 | self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer) 63 | 64 | self.model.to(self.device) 65 | self.batch_max_length = self.model.batch_max_length 66 | if torch.cuda.device_count() > 1: 67 | self.model = torch.nn.DataParallel(model) 68 | 69 | if self.tensorboard_enable: 70 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3' # 只显示 Error 71 | from torch.utils.tensorboard import SummaryWriter 72 | self.writer = SummaryWriter(self.save_dir) 73 | try: 74 | # add graph 75 | self.writer.add_graph(self.model, sample_input.to(self.device)) 76 | torch.cuda.empty_cache() 77 | except: 78 | import traceback 79 | self.logger.error(traceback.format_exc()) 80 | self.logger.warn('add graph to tensorboard failed') 81 | 82 | def train(self): 83 | """ 84 | Full training logic 85 | """ 86 | try: 87 | # self._testval(max_step=10) 88 | for epoch in range(self.start_epoch + 1, self.epochs + 1): 89 | self._train_epoch(epoch) 90 | self.scheduler.step() 91 | self._on_epoch_finish() 92 | except: 93 | self.logger.error(traceback.format_exc()) 94 | if self.tensorboard_enable: 95 | self.writer.close() 96 | self._on_train_finish() 97 | 98 | def _train_epoch(self, epoch): 99 | """ 100 | Training logic for an epoch 101 | :param epoch: Current epoch number 102 | """ 103 | raise NotImplementedError 104 | 105 | def _testval(self, max_step): 106 | self._eval(max_step, f'test model on eval before train with {max_step} steps') 107 | 108 | def _eval(self, max_step, desc): 109 | """ 110 | eval logic for an epoch 111 | :param epoch: Current epoch number 112 | """ 113 | raise NotImplementedError 114 | 115 | def _before_step(self, batch): 116 | raise NotImplementedError 117 | 118 | def _run_step(self, batch): 119 | raise NotImplementedError 120 | 121 | def _after_step(self, batch_out): 122 | raise NotImplementedError 123 | 124 | def _on_epoch_finish(self): 125 | raise NotImplementedError 126 | 127 | def _on_train_finish(self): 128 | raise NotImplementedError 129 | 130 | def 
_save_checkpoint(self, epoch, file_name, save_best=False): 131 | """ 132 | Saving checkpoints 133 | :param epoch: current epoch number 134 | :param log: logging information of the epoch 135 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar' 136 | """ 137 | state_dict = self.model.module.state_dict() if torch.cuda.device_count() > 1 else self.model.state_dict() 138 | state = { 139 | 'epoch': epoch, 140 | 'global_step': self.global_step, 141 | 'state_dict': state_dict, 142 | 'optimizer': self.optimizer.state_dict(), 143 | 'scheduler': self.scheduler.state_dict(), 144 | 'config': self.config, 145 | 'metrics': self.metrics 146 | } 147 | filename = os.path.join(self.checkpoint_dir, file_name) 148 | torch.save(state, filename) 149 | if save_best: 150 | shutil.copy(filename, os.path.join(self.checkpoint_dir, 'model_best.pth')) 151 | self.logger.info(f"Saving current best: {file_name}") 152 | else: 153 | self.logger.info(f"Saving checkpoint: {filename}") 154 | 155 | def _laod_checkpoint(self, checkpoint_path, resume): 156 | """ 157 | Resume from saved checkpoints 158 | :param checkpoint_path: Checkpoint path to be resumed 159 | """ 160 | self.logger.info(f"Loading checkpoint: {checkpoint_path} ...") 161 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) 162 | self.model.load_state_dict(checkpoint['state_dict'], strict=resume) 163 | if resume: 164 | self.global_step = checkpoint['global_step'] 165 | self.start_epoch = checkpoint['epoch'] 166 | self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch 167 | # self.scheduler.load_state_dict(checkpoint['scheduler']) 168 | self.optimizer.load_state_dict(checkpoint['optimizer']) 169 | if 'metrics' in checkpoint: 170 | self.metrics = checkpoint['metrics'] 171 | if self.with_cuda: 172 | for state in self.optimizer.state.values(): 173 | for k, v in state.items(): 174 | if isinstance(v, torch.Tensor): 175 | state[k] = v.to(self.device) 176 | self.logger.info(f"resume from checkpoint {checkpoint_path} (epoch {self.start_epoch})") 177 | else: 178 | self.logger.info(f"finetune from checkpoint {checkpoint_path}") 179 | 180 | def _initialize(self, name, module, *args, **kwargs): 181 | module_name = self.config[name]['type'] 182 | module_args = self.config[name]['args'] 183 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 184 | module_args.update(kwargs) 185 | return getattr(module, module_name)(*args, **module_args) 186 | -------------------------------------------------------------------------------- /config/image_dataset.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | alphabet: dict.txt 3 | train: 4 | dataset: 5 | type: ImageDataset # 数据集类型 6 | args: 7 | data_path: # [[文件,文件],[文件,文件]],每个文件格式为 img_path \t gt,每个子list按照data_ratio的比例进行采样 8 | - -'' 9 | data_ratio: 10 | - 1.0 11 | pre_processes: # 数据的预处理过程,包含augment和预处理 12 | # - type: IaaAugment # 使用imgaug进行变换 13 | - type: Resize 14 | args: 15 | img_h: 32 16 | img_w: 320 17 | pad: true 18 | random_crop: true 19 | transforms: # 对图片进行的变换方式 20 | - type: ColorJitter 21 | args: 22 | brightness: 0.5 23 | - type: ToTensor 24 | args: {} 25 | - type: Normalize 26 | args: 27 | mean: [0.485, 0.456, 0.406] 28 | std: [0.229, 0.224, 0.225] 29 | img_mode: RGB 30 | ignore_chinese_punctuation: true 31 | remove_blank: true 32 | loader: 33 | batch_size: 8 34 | shuffle: true 35 | pin_memory: false 36 | num_workers: 6 37 | validate: 38 | dataset: 39 | 
type: ImageDataset 40 | args: 41 | data_path: # [文件,文件],每个文件格式为 img_path \t gt 42 | - '' 43 | pre_processes: 44 | - type: Resize 45 | args: 46 | img_h: 32 47 | img_w: 320 48 | pad: true 49 | random_crop: false 50 | transforms: 51 | - type: ToTensor 52 | args: {} 53 | - type: Normalize 54 | args: 55 | mean: [0.485, 0.456, 0.406] 56 | std: [0.229, 0.224, 0.225] 57 | img_mode: RGB 58 | ignore_chinese_punctuation: true 59 | remove_blank: true 60 | loader: 61 | batch_size: 4 62 | shuffle: true 63 | pin_memory: false 64 | num_workers: 6 -------------------------------------------------------------------------------- /config/imagedataset_None_VGG_RNN_Attn.yaml: -------------------------------------------------------------------------------- 1 | name: crnn 2 | base: ['config/image_dataset.yaml'] 3 | arch: 4 | type: Model 5 | trans: 6 | type: None # TPS or None 7 | input_size: [32,320] 8 | num_fiducial: 20 9 | backbone: 10 | type: VGG 11 | conv_type: BasicConv 12 | neck: 13 | type: RNNDecoder # RNNDecoder or CNNDecoder or Reshape, Reshape 表示不使用decode 14 | hidden_size: 256 15 | head: 16 | type: Attn # CTC or Attn, Attn 必须和 RNNDecoder一起使用 17 | loss: 18 | type: AttnLoss 19 | 20 | optimizer: 21 | type: Adam # Adagrad 22 | args: 23 | lr: 0.001 24 | lr_scheduler: 25 | type: StepLR 26 | args: 27 | step_size: 30 28 | gamma: 0.1 29 | trainer: 30 | seed: 2 31 | gpus: 32 | - 0 33 | epochs: 10 34 | log_iter: 10 35 | resume_checkpoint: '' 36 | finetune_checkpoint: '' 37 | output_dir: output 38 | tensorboard: true 39 | dataset: 40 | alphabet: digit.txt 41 | train: 42 | dataset: 43 | type: ImageDataset # 数据集类型 44 | args: 45 | data_path: 46 | - - path/train.txt 47 | data_ratio: 48 | - 1.0 49 | pre_processes: # 数据的预处理过程,包含augment和预处理 50 | # - type: IaaAugment # 使用imgaug进行变换 51 | # - type: RandomAug # 进行变换扭曲变换 52 | - type: Resize 53 | args: 54 | img_h: 32 55 | img_w: 120 56 | pad: true 57 | random_crop: false 58 | transforms: # 对图片进行的变换方式 59 | # - type: ColorJitter 60 | # args: 61 | # brightness: 0.5 62 | - type: ToTensor 63 | args: {} 64 | # - type: Normalize 65 | # args: 66 | # mean: [0.485, 0.456, 0.406] 67 | # std: [0.229, 0.224, 0.225] 68 | img_mode: RGB 69 | ignore_chinese_punctuation: true 70 | remove_blank: true 71 | loader: 72 | batch_size: 16 73 | shuffle: true 74 | pin_memory: false 75 | num_workers: 6 76 | validate: 77 | dataset: 78 | type: ImageDataset 79 | args: 80 | data_path: # [文件,文件],每个文件格式为 img_path \t gt 81 | - path/val.txt 82 | pre_processes: 83 | - type: Resize 84 | args: 85 | img_h: 32 86 | img_w: 120 87 | pad: true 88 | random_crop: false 89 | transforms: 90 | - type: ToTensor 91 | args: {} 92 | # - type: Normalize 93 | # args: 94 | # mean: [0.485, 0.456, 0.406] 95 | # std: [0.229, 0.224, 0.225] 96 | img_mode: RGB 97 | ignore_chinese_punctuation: true 98 | remove_blank: true 99 | loader: 100 | batch_size: 4 101 | shuffle: true 102 | pin_memory: false 103 | num_workers: 6 -------------------------------------------------------------------------------- /config/imagedataset_None_VGG_RNN_CTC.yaml: -------------------------------------------------------------------------------- 1 | name: crnn 2 | base: ['config/image_dataset.yaml'] 3 | arch: 4 | type: Model 5 | trans: 6 | type: None # TPS or None 7 | input_size: [32,320] 8 | num_fiducial: 20 9 | backbone: 10 | type: VGG 11 | conv_type: BasicConv 12 | neck: 13 | type: RNNDecoder # RNNDecoder or CNNDecoder or Reshape, Reshape 表示不使用decode 14 | hidden_size: 256 15 | head: 16 | type: CTC # CTC or Attn, Attn 必须和 RNNDecoder一起使用 17 | loss: 18 | type: 
CTCLoss 19 | blank: 0 20 | 21 | optimizer: 22 | type: Adam # Adagrad 23 | args: 24 | lr: 0.001 25 | lr_scheduler: 26 | type: StepLR 27 | args: 28 | step_size: 30 29 | gamma: 0.1 30 | trainer: 31 | seed: 2 32 | gpus: 33 | - 0 34 | epochs: 10 35 | log_iter: 10 36 | resume_checkpoint: '' 37 | finetune_checkpoint: '' 38 | output_dir: output 39 | tensorboard: true 40 | dataset: 41 | alphabet: digit.txt 42 | train: 43 | dataset: 44 | type: ImageDataset # 数据集类型 45 | args: 46 | data_path: 47 | - - path/train.txt 48 | data_ratio: 49 | - 1.0 50 | pre_processes: # 数据的预处理过程,包含augment和预处理 51 | # - type: IaaAugment # 使用imgaug进行变换 52 | # - type: RandomAug # 进行变换扭曲变换 53 | - type: Resize 54 | args: 55 | img_h: 32 56 | img_w: 120 57 | pad: true 58 | random_crop: false 59 | transforms: # 对图片进行的变换方式 60 | # - type: ColorJitter 61 | # args: 62 | # brightness: 0.5 63 | - type: ToTensor 64 | args: {} 65 | # - type: Normalize 66 | # args: 67 | # mean: [0.485, 0.456, 0.406] 68 | # std: [0.229, 0.224, 0.225] 69 | img_mode: RGB 70 | ignore_chinese_punctuation: true 71 | remove_blank: true 72 | loader: 73 | batch_size: 16 74 | shuffle: true 75 | pin_memory: false 76 | num_workers: 6 77 | validate: 78 | dataset: 79 | type: ImageDataset 80 | args: 81 | data_path: # [文件,文件],每个文件格式为 img_path \t gt 82 | - path/val.txt 83 | pre_processes: 84 | - type: Resize 85 | args: 86 | img_h: 32 87 | img_w: 120 88 | pad: true 89 | random_crop: false 90 | transforms: 91 | - type: ToTensor 92 | args: {} 93 | # - type: Normalize 94 | # args: 95 | # mean: [0.485, 0.456, 0.406] 96 | # std: [0.229, 0.224, 0.225] 97 | img_mode: RGB 98 | ignore_chinese_punctuation: true 99 | remove_blank: true 100 | loader: 101 | batch_size: 4 102 | shuffle: true 103 | pin_memory: false 104 | num_workers: 6 -------------------------------------------------------------------------------- /config/lmdb.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | alphabet: dict.txt 3 | train: 4 | dataset: 5 | type: LmdbDataset # 数据集类型 6 | args: 7 | data_path: # [[文件,文件],[文件,文件]],每个文件格式为 img_path \t gt,每个子list按照data_ratio的比例进行采样 8 | - -'' 9 | data_ratio: 10 | - 1.0 11 | pre_processes: # 数据的预处理过程,包含augment和预处理 12 | # - type: IaaAugment # 使用imgaug进行变换 13 | - type: Resize 14 | args: 15 | img_h: 32 16 | img_w: 320 17 | pad: true 18 | random_crop: true 19 | transforms: # 对图片进行的变换方式 20 | - type: ColorJitter 21 | args: 22 | brightness: 0.5 23 | - type: ToTensor 24 | args: {} 25 | - type: Normalize 26 | args: 27 | mean: [0.485, 0.456, 0.406] 28 | std: [0.229, 0.224, 0.225] 29 | img_mode: RGB 30 | ignore_chinese_punctuation: true 31 | remove_blank: true 32 | loader: 33 | batch_size: 8 34 | shuffle: true 35 | pin_memory: false 36 | num_workers: 6 37 | validate: 38 | dataset: 39 | type: LmdbDataset 40 | args: 41 | data_path: # [文件,文件],每个文件格式为 img_path \t gt 42 | - '' 43 | pre_processes: 44 | - type: Resize 45 | args: 46 | img_h: 32 47 | img_w: 320 48 | pad: true 49 | random_crop: false 50 | transforms: 51 | - type: ToTensor 52 | args: {} 53 | - type: Normalize 54 | args: 55 | mean: [0.485, 0.456, 0.406] 56 | std: [0.229, 0.224, 0.225] 57 | img_mode: RGB 58 | ignore_chinese_punctuation: true 59 | remove_blank: true 60 | loader: 61 | batch_size: 4 62 | shuffle: true 63 | pin_memory: false 64 | num_workers: 6 -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 
18-11-16 5:46 PM 3 | # @Author : zhoujun 4 | import copy 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | 8 | 9 | def get_dataset(data_path, module_name, transform, dataset_args): 10 | """ 11 | Build the training dataset 12 | :param data_path: list of annotation files, each line stored in the format 'path/to/img\tlabel' 13 | :param module_name: name of the custom dataset class to use; currently only data_loaders.ImageDataset is supported 14 | :param transform: transforms applied to this dataset 15 | :param dataset_args: arguments passed to module_name 16 | :return: the corresponding Dataset object if the data_path list is not empty, otherwise None 17 | """ 18 | from . import dataset 19 | s_dataset = getattr(dataset, module_name)(transform=transform, data_path=data_path, **dataset_args) 20 | return s_dataset 21 | 22 | 23 | def get_transforms(transforms_config): 24 | tr_list = [] 25 | for item in transforms_config: 26 | if 'args' not in item: 27 | args = {} 28 | else: 29 | args = item['args'] 30 | cls = getattr(transforms, item['type'])(**args) 31 | tr_list.append(cls) 32 | tr_list = transforms.Compose(tr_list) 33 | return tr_list 34 | 35 | 36 | def get_dataloader(module_config, num_label): 37 | if module_config is None: 38 | return None 39 | config = copy.deepcopy(module_config) 40 | dataset_args = config['dataset']['args'] 41 | dataset_args['num_label'] = num_label 42 | if 'transforms' in dataset_args: 43 | img_transforms = get_transforms(dataset_args.pop('transforms')) 44 | else: 45 | img_transforms = None 46 | # create the datasets 47 | dataset_name = config['dataset']['type'] 48 | data_path_list = dataset_args.pop('data_path') 49 | if 'data_ratio' in dataset_args: 50 | data_ratio = dataset_args.pop('data_ratio') 51 | else: 52 | data_ratio = [1.0] 53 | 54 | _dataset_list = [] 55 | for data_path in data_path_list: 56 | _dataset_list.append(get_dataset(data_path=data_path, module_name=dataset_name, dataset_args=dataset_args, transform=img_transforms)) 57 | if len(data_ratio) > 1 and len(data_ratio) == len(_dataset_list): # data_ratio was popped from dataset_args above, so compare the local list 58 | from .
import dataset 59 | loader = dataset.Batch_Balanced_Dataset(dataset_list=_dataset_list, ratio_list=data_ratio, loader_args=config['loader']) 60 | else: 61 | _dataset = _dataset_list[0] 62 | loader = DataLoader(dataset=_dataset, **config['loader']) 63 | loader.dataset_len = len(_dataset) 64 | return loader 65 | -------------------------------------------------------------------------------- /data_loader/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/8/23 22:18 3 | # @Author : zhoujun 4 | import sys 5 | import six 6 | import lmdb 7 | from PIL import Image 8 | import cv2 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import DataLoader 12 | 13 | from utils import punctuation_mend, get_datalist 14 | from base import BaseDataSet 15 | 16 | 17 | class ImageDataset(BaseDataSet): 18 | def __init__(self, data_path: str, img_mode, num_label, ignore_chinese_punctuation, remove_blank, pre_processes, transform=None, **kwargs): 19 | """ 20 | Dataset initialization 21 | :param data_path: annotation file(s) storing image paths and the corresponding labels 22 | :param img_mode: image mode, one of RGB / BRG / GRAY 23 | :param num_label: maximum number of characters; should match the width of the sequence the network finally outputs 24 | :param ignore_chinese_punctuation: whether to convert full-width punctuation to half-width 25 | :param remove_blank: whether to strip spaces from the labels 26 | """ 27 | super().__init__(data_path, img_mode, num_label, ignore_chinese_punctuation, remove_blank, pre_processes, transform, **kwargs) 28 | 29 | def load_data(self, data_path: str) -> list: 30 | return get_datalist(data_path, self.num_label) 31 | 32 | def get_sample(self, index): 33 | img_path, label = self.data_list[index] 34 | img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0) 35 | if self.img_mode == 'RGB': 36 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 37 | if self.ignore_chinese_punctuation: 38 | label = punctuation_mend(label) 39 | if self.remove_blank: 40 | label = label.replace(' ', '') 41 | return {'img': img, 'label': label} 42 | 43 | 44 | class LmdbDataset(BaseDataSet): 45 | def __init__(self, data_path: str, img_mode, num_label, ignore_chinese_punctuation, remove_blank, pre_processes, transform=None, **kwargs): 46 | super().__init__(data_path, img_mode, num_label, ignore_chinese_punctuation, remove_blank, pre_processes, transform, **kwargs) 47 | 48 | def get_sample(self, index): 49 | index = self.data_list[index] 50 | with self.env.begin(write=False) as txn: 51 | label_key = 'label-%09d'.encode() % index 52 | label = txn.get(label_key).decode('utf-8') 53 | img_key = 'image-%09d'.encode() % index 54 | imgbuf = txn.get(img_key) 55 | 56 | buf = six.BytesIO() 57 | buf.write(imgbuf) 58 | buf.seek(0) 59 | if self.img_mode == 'RGB': 60 | img = Image.open(buf).convert('RGB') # for color image 61 | elif self.img_mode == "GRAY": 62 | img = Image.open(buf).convert('L') 63 | else: 64 | raise NotImplementedError 65 | # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) 66 | if self.remove_blank: 67 | label = label.replace(' ', '') 68 | if self.ignore_chinese_punctuation: 69 | label = punctuation_mend(label) 70 | img = np.array(img) 71 | return {'img': img, 'label': label} # return a dict, matching what BaseDataSet.__getitem__ expects 72 | 73 | def load_data(self, data_path: str) -> list: 74 | self.env = lmdb.open(data_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 75 | if not self.env: 76 | print('cannot create lmdb from %s' % (data_path)) 77 | sys.exit(0) 78 | 79 | filtered_index_list = [] 80 | with self.env.begin(write=False) as txn: 81 | nSamples = int(txn.get('num-samples'.encode())) 82 | self.nSamples = nSamples 83 | for index in range(self.nSamples): 84 |
index += 1 # lmdb starts with 1 85 | label_key = 'label-%09d'.encode() % index 86 | label = txn.get(label_key).decode('utf-8') 87 | if len(label) > self.num_label: 88 | # print(f'The length of the label is longer than max_length: length 89 | # {len(label)}, {label} in dataset {self.root}') 90 | continue 91 | 92 | # By default, images containing characters which are not in opt.character are filtered. 93 | # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. 94 | filtered_index_list.append(index) 95 | return filtered_index_list 96 | 97 | 98 | class Batch_Balanced_Dataset(object): 99 | def __init__(self, dataset_list: list, ratio_list: list, loader_args: dict): 100 | """ 101 | 对datasetlist里的dataset按照ratio_list里对应的比例组合,似的每个batch里的数据按按照比例采样的 102 | :param dataset_list: 数据集列表 103 | :param ratio_list: 比例列表 104 | :param loader_args: dataloader的配置 105 | """ 106 | assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list) 107 | 108 | self.dataset_len = 0 109 | self.data_loader_list = [] 110 | self.dataloader_iter_list = [] 111 | all_batch_size = loader_args.pop('batch_size') 112 | for _dataset, batch_ratio_d in zip(dataset_list, ratio_list): 113 | _batch_size = max(round(all_batch_size * float(batch_ratio_d)), 1) 114 | _data_loader = DataLoader(dataset=_dataset, batch_size=_batch_size, drop_last=True, **loader_args) 115 | self.data_loader_list.append(_data_loader) 116 | self.dataloader_iter_list.append(iter(_data_loader)) 117 | self.dataset_len += len(_dataset) 118 | 119 | def __iter__(self): 120 | return self 121 | 122 | def __len__(self): 123 | return min([len(x) for x in self.data_loader_list]) 124 | 125 | def __next__(self): 126 | balanced_batch_images = [] 127 | balanced_batch_texts = [] 128 | 129 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 130 | try: 131 | image, text = next(data_loader_iter) 132 | balanced_batch_images.append(image) 133 | balanced_batch_texts += text 134 | except StopIteration: 135 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 136 | image, text = next(self.dataloader_iter_list[i]) 137 | balanced_batch_images.append(image) 138 | balanced_batch_texts += text 139 | except ValueError: 140 | pass 141 | 142 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 143 | return balanced_batch_images, balanced_batch_texts 144 | 145 | 146 | if __name__ == '__main__': 147 | import os 148 | from tqdm import tqdm 149 | import anyconfig 150 | from torchvision import transforms 151 | from utils import parse_config 152 | 153 | train_transfroms = transforms.Compose([ 154 | transforms.ColorJitter(brightness=0.5), 155 | transforms.ToTensor() 156 | ]) 157 | config = anyconfig.load(open("config/icdar2015_win.yaml", 'rb')) 158 | if 'base' in config: 159 | config = parse_config(config) 160 | if os.path.isfile(config['dataset']['alphabet']): 161 | config['dataset']['alphabet'] = str(np.load(config['dataset']['alphabet'])) 162 | 163 | dataset_args = config['dataset']['validate']['dataset']['args'] 164 | dataset_args['num_label'] = 80 165 | dataset_args['alphabet'] = config['dataset']['alphabet'] 166 | dataset = ImageDataset(transform=train_transfroms, **dataset_args) 167 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=2) 168 | for i, (images, labels) in enumerate(tqdm(data_loader)): 169 | pass 170 | print(images.shape) 171 | print(labels) 172 | img = images[0].numpy().transpose((1, 2, 0)) 173 | from matplotlib import pyplot as plt 174 | 175 | plt.imshow(img) 176 | plt.show() 177 | 
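For orientation, the sketch below shows how the dataset classes above are typically reached through `get_dataloader` from `data_loader/__init__.py`. It is a minimal, hypothetical example: the config file name and the `num_label` value are assumptions, and the `data_path` entries inside the YAML must first be pointed at real annotation files.

```python
# Hypothetical usage sketch: build the validation loader from a config and read one batch.
import anyconfig

from data_loader import get_dataloader

config = anyconfig.load(open('config/imagedataset_None_VGG_RNN_CTC.yaml', 'rb'))
# config['dataset']['validate'] holds the {'dataset': ..., 'loader': ...} block shown in the YAML above
val_loader = get_dataloader(config['dataset']['validate'], num_label=80)

for batch in val_loader:
    # BaseDataSet.__getitem__ returns a dict, so the default collate yields:
    #   batch['img']   -> float tensor of shape [N, C, img_h, img_w] (32 x 120 with this config)
    #   batch['label'] -> list of N ground-truth strings
    print(batch['img'].shape, batch['label'][:2])
    break
```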
-------------------------------------------------------------------------------- /data_loader/modules/Text_Image_Augmentation_python/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/4/27 11:28 3 | # @Author : zhoujun 4 | import numpy as np 5 | from .augment import distort, stretch, perspective 6 | 7 | __all__ = ['RandomAug'] 8 | 9 | 10 | class RandomAug: 11 | def __init__(self): 12 | pass 13 | 14 | def __call__(self, img): 15 | if np.random.randn() > 0.3: 16 | img = distort(img, 3) 17 | elif np.random.randn() > 0.6: 18 | img = stretch(img, 3) 19 | else: 20 | img = perspective(img) 21 | return img 22 | 23 | 24 | if __name__ == '__main__': 25 | from matplotlib import pyplot as plt 26 | import cv2 27 | r = RandomAug() 28 | im = cv2.imread(r'D:\code\crnn.pytorch\0.jpg') 29 | plt.imshow(im) 30 | resize_img = r(im) 31 | plt.figure() 32 | plt.imshow(resize_img) 33 | plt.show() 34 | -------------------------------------------------------------------------------- /data_loader/modules/Text_Image_Augmentation_python/augment.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: RubanSeven 3 | 4 | # import cv2 5 | import numpy as np 6 | # from transform import get_perspective_transform, warp_perspective 7 | from .warp_mls import WarpMLS 8 | 9 | 10 | def distort(src, segment): 11 | img_h, img_w = src.shape[:2] 12 | 13 | cut = img_w // segment 14 | thresh = cut // 3 15 | # thresh = img_h // segment // 3 16 | # thresh = img_h // 5 17 | 18 | src_pts = list() 19 | dst_pts = list() 20 | 21 | src_pts.append([0, 0]) 22 | src_pts.append([img_w, 0]) 23 | src_pts.append([img_w, img_h]) 24 | src_pts.append([0, img_h]) 25 | 26 | dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)]) 27 | dst_pts.append([img_w - np.random.randint(thresh), np.random.randint(thresh)]) 28 | dst_pts.append([img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)]) 29 | dst_pts.append([np.random.randint(thresh), img_h - np.random.randint(thresh)]) 30 | 31 | half_thresh = thresh * 0.5 32 | 33 | for cut_idx in np.arange(1, segment, 1): 34 | src_pts.append([cut * cut_idx, 0]) 35 | src_pts.append([cut * cut_idx, img_h]) 36 | dst_pts.append([cut * cut_idx + np.random.randint(thresh) - half_thresh, 37 | np.random.randint(thresh) - half_thresh]) 38 | dst_pts.append([cut * cut_idx + np.random.randint(thresh) - half_thresh, 39 | img_h + np.random.randint(thresh) - half_thresh]) 40 | 41 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 42 | dst = trans.generate() 43 | 44 | return dst 45 | 46 | 47 | def stretch(src, segment): 48 | img_h, img_w = src.shape[:2] 49 | 50 | cut = img_w // segment 51 | thresh = cut * 4 // 5 52 | # thresh = img_h // segment // 3 53 | # thresh = img_h // 5 54 | 55 | src_pts = list() 56 | dst_pts = list() 57 | 58 | src_pts.append([0, 0]) 59 | src_pts.append([img_w, 0]) 60 | src_pts.append([img_w, img_h]) 61 | src_pts.append([0, img_h]) 62 | 63 | dst_pts.append([0, 0]) 64 | dst_pts.append([img_w, 0]) 65 | dst_pts.append([img_w, img_h]) 66 | dst_pts.append([0, img_h]) 67 | 68 | half_thresh = thresh * 0.5 69 | 70 | for cut_idx in np.arange(1, segment, 1): 71 | move = np.random.randint(thresh) - half_thresh 72 | src_pts.append([cut * cut_idx, 0]) 73 | src_pts.append([cut * cut_idx, img_h]) 74 | dst_pts.append([cut * cut_idx + move, 0]) 75 | dst_pts.append([cut * cut_idx + move, img_h]) 76 | 77 | trans = WarpMLS(src, src_pts, dst_pts, 
img_w, img_h) 78 | dst = trans.generate() 79 | 80 | return dst 81 | 82 | 83 | def perspective(src): 84 | img_h, img_w = src.shape[:2] 85 | 86 | thresh = img_h // 2 87 | 88 | src_pts = list() 89 | dst_pts = list() 90 | 91 | src_pts.append([0, 0]) 92 | src_pts.append([img_w, 0]) 93 | src_pts.append([img_w, img_h]) 94 | src_pts.append([0, img_h]) 95 | 96 | dst_pts.append([0, np.random.randint(thresh)]) 97 | dst_pts.append([img_w, np.random.randint(thresh)]) 98 | dst_pts.append([img_w, img_h - np.random.randint(thresh)]) 99 | dst_pts.append([0, img_h - np.random.randint(thresh)]) 100 | 101 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 102 | dst = trans.generate() 103 | 104 | return dst 105 | 106 | # def distort(src, segment): 107 | # img_h, img_w = src.shape[:2] 108 | # dst = np.zeros_like(src, dtype=np.uint8) 109 | # 110 | # cut = img_w // segment 111 | # thresh = img_h // 8 112 | # 113 | # src_pts = list() 114 | # # dst_pts = list() 115 | # 116 | # src_pts.append([-np.random.randint(thresh), -np.random.randint(thresh)]) 117 | # src_pts.append([-np.random.randint(thresh), img_h + np.random.randint(thresh)]) 118 | # 119 | # # dst_pts.append([0, 0]) 120 | # # dst_pts.append([0, img_h]) 121 | # dst_box = np.array([[0, 0], [0, img_h], [cut, 0], [cut, img_h]], dtype=np.float32) 122 | # 123 | # half_thresh = thresh * 0.5 124 | # 125 | # for cut_idx in np.arange(1, segment, 1): 126 | # src_pts.append([cut * cut_idx + np.random.randint(thresh) - half_thresh, 127 | # np.random.randint(thresh) - half_thresh]) 128 | # src_pts.append([cut * cut_idx + np.random.randint(thresh) - half_thresh, 129 | # img_h + np.random.randint(thresh) - half_thresh]) 130 | # 131 | # # dst_pts.append([cut * i, 0]) 132 | # # dst_pts.append([cut * i, img_h]) 133 | # 134 | # src_box = np.array(src_pts[-4:-2] + src_pts[-2:-1] + src_pts[-1:], dtype=np.float32) 135 | # 136 | # # mat = cv2.getPerspectiveTransform(src_box, dst_box) 137 | # # print(mat) 138 | # # dst[:, cut * (cut_idx - 1):cut * cut_idx] = cv2.warpPerspective(src, mat, (cut, img_h)) 139 | # 140 | # mat = get_perspective_transform(dst_box, src_box) 141 | # dst[:, cut * (cut_idx - 1):cut * cut_idx] = warp_perspective(src, mat, (cut, img_h)) 142 | # # print(mat) 143 | # 144 | # src_pts.append([img_w + np.random.randint(thresh) - half_thresh, 145 | # np.random.randint(thresh) - half_thresh]) 146 | # src_pts.append([img_w + np.random.randint(thresh) - half_thresh, 147 | # img_h + np.random.randint(thresh) - half_thresh]) 148 | # src_box = np.array(src_pts[-4:-2] + src_pts[-2:-1] + src_pts[-1:], dtype=np.float32) 149 | # 150 | # # mat = cv2.getPerspectiveTransform(src_box, dst_box) 151 | # # dst[:, cut * (segment - 1):] = cv2.warpPerspective(src, mat, (img_w - cut * (segment - 1), img_h)) 152 | # mat = get_perspective_transform(dst_box, src_box) 153 | # dst[:, cut * (segment - 1):] = warp_perspective(src, mat, (img_w - cut * (segment - 1), img_h)) 154 | # 155 | # return dst 156 | -------------------------------------------------------------------------------- /data_loader/modules/Text_Image_Augmentation_python/demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: RubanSeven 3 | 4 | import cv2 5 | import imageio 6 | from .augment import distort, stretch, perspective 7 | 8 | 9 | def create_gif(image_list, gif_name, duration=0.1): 10 | frames = [] 11 | for image in image_list: 12 | frames.append(image) 13 | imageio.mimsave(gif_name, frames, 'GIF', duration=duration) 14 | return 15 | 16 | 17 
| if __name__ == '__main__': 18 | im = cv2.imread("imgs/demo.png") 19 | im = cv2.resize(im, (200, 64)) 20 | cv2.imshow("im_CV", im) 21 | distort_img_list = list() 22 | stretch_img_list = list() 23 | perspective_img_list = list() 24 | for i in range(12): 25 | distort_img = distort(im, 4) 26 | distort_img_list.append(distort_img) 27 | cv2.imshow("distort_img", distort_img) 28 | 29 | stretch_img = stretch(im, 4) 30 | cv2.imshow("stretch_img", stretch_img) 31 | stretch_img_list.append(stretch_img) 32 | 33 | perspective_img = perspective(im) 34 | cv2.imshow("perspective_img", perspective_img) 35 | perspective_img_list.append(perspective_img) 36 | cv2.waitKey(100) 37 | 38 | create_gif(distort_img_list, r'imgs/distort.gif') 39 | create_gif(stretch_img_list, r'imgs/stretch.gif') 40 | create_gif(perspective_img_list, r'imgs/perspective.gif') 41 | -------------------------------------------------------------------------------- /data_loader/modules/Text_Image_Augmentation_python/warp_mls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: RubanSeven 3 | import math 4 | 5 | import numpy as np 6 | 7 | 8 | class WarpMLS: 9 | def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.): 10 | self.src = src 11 | self.src_pts = src_pts 12 | self.dst_pts = dst_pts 13 | self.pt_count = len(self.dst_pts) 14 | self.dst_w = dst_w 15 | self.dst_h = dst_h 16 | self.trans_ratio = trans_ratio 17 | self.grid_size = 100 18 | self.rdx = np.zeros((self.dst_h, self.dst_w)) 19 | self.rdy = np.zeros((self.dst_h, self.dst_w)) 20 | 21 | @staticmethod 22 | def __bilinear_interp(x, y, v11, v12, v21, v22): 23 | return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * (1 - y) + v22 * y) * x 24 | 25 | def generate(self): 26 | self.calc_delta() 27 | return self.gen_img() 28 | 29 | def calc_delta(self): 30 | w = np.zeros(self.pt_count, dtype=np.float32) 31 | 32 | if self.pt_count < 2: 33 | return 34 | 35 | i = 0 36 | while 1: 37 | if self.dst_w <= i < self.dst_w + self.grid_size - 1: 38 | i = self.dst_w - 1 39 | elif i >= self.dst_w: 40 | break 41 | 42 | j = 0 43 | while 1: 44 | if self.dst_h <= j < self.dst_h + self.grid_size - 1: 45 | j = self.dst_h - 1 46 | elif j >= self.dst_h: 47 | break 48 | 49 | sw = 0 50 | swp = np.zeros(2, dtype=np.float32) 51 | swq = np.zeros(2, dtype=np.float32) 52 | new_pt = np.zeros(2, dtype=np.float32) 53 | cur_pt = np.array([i, j], dtype=np.float32) 54 | 55 | k = 0 56 | for k in range(self.pt_count): 57 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 58 | break 59 | 60 | w[k] = 1. 
/ ((i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) + 61 | (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1])) 62 | 63 | sw += w[k] 64 | swp = swp + w[k] * np.array(self.dst_pts[k]) 65 | swq = swq + w[k] * np.array(self.src_pts[k]) 66 | 67 | if k == self.pt_count - 1: 68 | pstar = 1 / sw * swp 69 | qstar = 1 / sw * swq 70 | 71 | miu_s = 0 72 | for k in range(self.pt_count): 73 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 74 | continue 75 | pt_i = self.dst_pts[k] - pstar 76 | miu_s += w[k] * np.sum(pt_i * pt_i) 77 | 78 | cur_pt -= pstar 79 | cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) 80 | 81 | for k in range(self.pt_count): 82 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 83 | continue 84 | 85 | pt_i = self.dst_pts[k] - pstar 86 | pt_j = np.array([-pt_i[1], pt_i[0]]) 87 | 88 | tmp_pt = np.zeros(2, dtype=np.float32) 89 | tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \ 90 | np.sum(pt_j * cur_pt) * self.src_pts[k][1] 91 | tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \ 92 | np.sum(pt_j * cur_pt_j) * self.src_pts[k][1] 93 | tmp_pt *= (w[k] / miu_s) 94 | new_pt += tmp_pt 95 | 96 | new_pt += qstar 97 | else: 98 | new_pt = self.src_pts[k] 99 | 100 | self.rdx[j, i] = new_pt[0] - i 101 | self.rdy[j, i] = new_pt[1] - j 102 | 103 | j += self.grid_size 104 | i += self.grid_size 105 | 106 | def gen_img(self): 107 | src_h, src_w = self.src.shape[:2] 108 | dst = np.zeros_like(self.src, dtype=np.float32) 109 | 110 | for i in np.arange(0, self.dst_h, self.grid_size): 111 | for j in np.arange(0, self.dst_w, self.grid_size): 112 | ni = i + self.grid_size 113 | nj = j + self.grid_size 114 | w = h = self.grid_size 115 | if ni >= self.dst_h: 116 | ni = self.dst_h - 1 117 | h = ni - i + 1 118 | if nj >= self.dst_w: 119 | nj = self.dst_w - 1 120 | w = nj - j + 1 121 | 122 | di = np.reshape(np.arange(h), (-1, 1)) 123 | dj = np.reshape(np.arange(w), (1, -1)) 124 | delta_x = self.__bilinear_interp(di / h, dj / w, 125 | self.rdx[i, j], self.rdx[i, nj], 126 | self.rdx[ni, j], self.rdx[ni, nj]) 127 | delta_y = self.__bilinear_interp(di / h, dj / w, 128 | self.rdy[i, j], self.rdy[i, nj], 129 | self.rdy[ni, j], self.rdy[ni, nj]) 130 | nx = j + dj + delta_x * self.trans_ratio 131 | ny = i + di + delta_y * self.trans_ratio 132 | nx = np.clip(nx, 0, src_w - 1) 133 | ny = np.clip(ny, 0, src_h - 1) 134 | nxi = np.array(np.floor(nx), dtype=np.int32) 135 | nyi = np.array(np.floor(ny), dtype=np.int32) 136 | nxi1 = np.array(np.ceil(nx), dtype=np.int32) 137 | nyi1 = np.array(np.ceil(ny), dtype=np.int32) 138 | 139 | if len(self.src.shape) == 3: 140 | x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) 141 | y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) 142 | else: 143 | x = ny - nyi 144 | y = nx - nxi 145 | dst[i:i + h, j:j + w] = self.__bilinear_interp(x, 146 | y, 147 | self.src[nyi, nxi], 148 | self.src[nyi, nxi1], 149 | self.src[nyi1, nxi], 150 | self.src[nyi1, nxi1] 151 | ) 152 | 153 | # for di in range(h): 154 | # for dj in range(w): 155 | # # print(ni, nj, i, j) 156 | # delta_x = self.__bilinear_interp(di / h, dj / w, self.rdx[i, j], self.rdx[i, nj], 157 | # self.rdx[ni, j], self.rdx[ni, nj]) 158 | # delta_y = self.__bilinear_interp(di / h, dj / w, self.rdy[i, j], self.rdy[i, nj], 159 | # self.rdy[ni, j], self.rdy[ni, nj]) 160 | # nx = j + dj + delta_x * self.trans_ratio 161 | # ny = i + di + delta_y * self.trans_ratio 162 | # nx = min(src_w - 1, max(0, nx)) 163 | # ny = min(src_h - 1, max(0, ny)) 164 | # nxi = int(nx) 165 | # nyi = int(ny) 166 | # nxi1 
= math.ceil(nx) 167 | # nyi1 = math.ceil(ny) 168 | # 169 | # dst[i + di, j + dj] = self.__bilinear_interp(ny - nyi, nx - nxi, 170 | # self.src[nyi, nxi], 171 | # self.src[nyi, nxi1], 172 | # self.src[nyi1, nxi], 173 | # self.src[nyi1, nxi1] 174 | # ) 175 | 176 | dst = np.clip(dst, 0, 255) 177 | dst = np.array(dst, dtype=np.uint8) 178 | 179 | return dst 180 | -------------------------------------------------------------------------------- /data_loader/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zhoujun 5 | @time: 2019/12/19 下午3:18 6 | ''' 7 | from .augment import IaaAugment 8 | from .resize import Resize 9 | from .Text_Image_Augmentation_python import RandomAug -------------------------------------------------------------------------------- /data_loader/modules/augment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zhoujun 5 | @time: 2019/12/19 下午3:18 6 | ''' 7 | import imgaug.augmenters as iaa 8 | 9 | __all__ = ['IaaAugment'] 10 | 11 | 12 | class IaaAugment(): 13 | def __init__(self): 14 | self.seq = iaa.Sequential([ 15 | iaa.ChannelShuffle(0.5), 16 | iaa.Sometimes(0.5, iaa.OneOf([ 17 | iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0 18 | iaa.AverageBlur(k=(2, 7)), # blur image using local means with kernel sizes between 2 and 7 19 | iaa.MedianBlur(k=(3, 11)), # blur image using local medians with kernel sizes between 2 and 7 20 | ])), 21 | iaa.Sometimes(0.5, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5)), 22 | iaa.Sometimes(0.5, iaa.BlendAlphaFrequencyNoise( 23 | exponent=(-4, 0), 24 | foreground=iaa.Multiply((0.5, 1.5), per_channel=True), 25 | background=iaa.LinearContrast((0.5, 2.0)) 26 | )), 27 | # iaa.Sometimes(0.5, iaa.PiecewiseAffine(scale=(0.01, 0.05))), 28 | # iaa.Sometimes(0.5, iaa.PerspectiveTransform(scale=(0.01, 0.1))) 29 | ], random_order=True) 30 | 31 | def __call__(self, img): 32 | img = self.seq.augment_image(img) 33 | return img 34 | 35 | 36 | if __name__ == '__main__': 37 | import cv2 38 | from matplotlib import pyplot as plt 39 | 40 | r = IaaAugment() 41 | im = cv2.imread('0.jpg') 42 | plt.imshow(im) 43 | resize_img = r(im) 44 | plt.figure() 45 | plt.imshow(resize_img) 46 | plt.show() -------------------------------------------------------------------------------- /data_loader/modules/resize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zhoujun 5 | @time: 2019/12/19 下午3:23 6 | ''' 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | from torchvision import transforms 11 | 12 | class Resize: 13 | def __init__(self, img_h, img_w, pad=True, random_crop=False, **kwargs): 14 | self.img_h = img_h 15 | self.img_w = img_w 16 | self.pad = pad 17 | self.random_crop = random_crop 18 | 19 | def __call__(self, img: np.ndarray): 20 | """ 21 | 对图片进行处理,先按照高度进行resize,resize之后如果宽度不足指定宽度,就补黑色像素,否则就强行缩放到指定宽度 22 | :param img_path: 图片地址 23 | :return: 24 | """ 25 | img_h = self.img_h 26 | img_w = self.img_w 27 | augment = self.random_crop and np.random.rand() > 0.5 28 | if augment: 29 | img_h += 20 30 | img_w += 20 31 | h, w = img.shape[:2] 32 | ratio_h = self.img_h / h 33 | new_w = int(w * ratio_h) 34 | if new_w < img_w and self.pad: 35 | img = cv2.resize(img, (new_w, img_h)) 36 | if 
len(img.shape) == 2: 37 | img = np.expand_dims(img, 2) 38 | step = np.zeros((img_h, img_w - new_w, img.shape[-1]), dtype=img.dtype) 39 | img = np.column_stack((img, step)) 40 | else: 41 | img = cv2.resize(img, (img_w, img_h)) 42 | if len(img.shape) == 2: 43 | img = np.expand_dims(img, 2) 44 | if img.shape[-1] == 1: 45 | img = img[:, :, 0] 46 | if augment: 47 | img = transforms.RandomCrop((self.img_h, self.img_w))(Image.fromarray(img)) 48 | img = np.array(img) 49 | return img 50 | 51 | 52 | if __name__ == '__main__': 53 | from matplotlib import pyplot as plt 54 | 55 | r = Resize(32, 320,random_crop=True) 56 | im = cv2.imread('0.jpg', 1) 57 | plt.imshow(im) 58 | plt.show() 59 | resize_img = r(im) 60 | plt.imshow(resize_img) 61 | plt.show() 62 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from .model import Model 3 | from .losses import build_loss 4 | 5 | __all__ = ['build_loss', 'build_model'] 6 | support_model = ['Model'] 7 | 8 | 9 | def build_model(config): 10 | """ 11 | get architecture model class 12 | """ 13 | copy_config = copy.deepcopy(config) 14 | arch_type = copy_config.pop('type') 15 | assert arch_type in support_model, f'{arch_type} is not developed yet!, only {support_model} are support now' 16 | arch_model = eval(arch_type)(copy_config) 17 | return arch_model 18 | -------------------------------------------------------------------------------- /modeling/backbone/MobileNetV3.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | __all__ = ['MobileNetV3'] 9 | 10 | 11 | class HSwish(nn.Module): 12 | def forward(self, x): 13 | out = x * F.relu6(x + 3, inplace=True) / 6 14 | return out 15 | 16 | 17 | class ConvBNACT(nn.Module): 18 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None): 19 | super().__init__() 20 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 21 | stride=stride, padding=padding, groups=groups, 22 | bias=False) 23 | self.bn = nn.BatchNorm2d(out_channels) 24 | if act == 'relu': 25 | self.act = nn.ReLU() 26 | elif act == 'hard_swish': 27 | self.act = HSwish() 28 | elif act is None: 29 | self.act = None 30 | 31 | def forward(self, x): 32 | x = self.conv(x) 33 | x = self.bn(x) 34 | if self.act is not None: 35 | x = self.act(x) 36 | return x 37 | 38 | 39 | class HardSigmoid(nn.Module): 40 | def __init__(self, slope=.2, offset=.5): 41 | super().__init__() 42 | self.slope = slope 43 | self.offset = offset 44 | 45 | def forward(self, x): 46 | x = (self.slope * x) + self.offset 47 | x = F.threshold(-x, -1, -1) 48 | x = F.threshold(-x, 0, 0) 49 | return x 50 | 51 | 52 | class SEBlock(nn.Module): 53 | def __init__(self, in_channels, out_channels, ratio=4): 54 | super().__init__() 55 | num_mid_filter = out_channels // ratio 56 | self.pool = nn.AdaptiveAvgPool2d(1) 57 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_mid_filter, kernel_size=1, bias=True) 58 | self.relu1 = nn.ReLU() 59 | self.conv2 = nn.Conv2d(in_channels=num_mid_filter, kernel_size=1, out_channels=out_channels, bias=True) 60 | self.relu2 = HardSigmoid() 61 | 62 | def forward(self, x): 63 | attn = self.pool(x) 64 | 
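# Shape walk-through for an N x C x H x W input to this block:
#   pool  -> N x C x 1 x 1                      global average pooling (the "squeeze")
#   conv1 -> N x (out_channels // 4) x 1 x 1    1x1 bottleneck (ratio=4)
#   relu1 -> non-linearity inside the bottleneck
#   conv2 -> N x out_channels x 1 x 1           1x1 expansion back to the full channel count
#   relu2 -> HardSigmoid, i.e. per-channel gates in [0, 1]
# The gates are broadcast-multiplied back onto x below (the "excitation").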
attn = self.conv1(attn) 65 | attn = self.relu1(attn) 66 | attn = self.conv2(attn) 67 | attn = self.relu2(attn) 68 | return x * attn 69 | 70 | 71 | class ResidualUnit(nn.Module): 72 | def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False): 73 | super().__init__() 74 | self.conv0 = ConvBNACT(in_channels=num_in_filter, out_channels=num_mid_filter, kernel_size=1, stride=1, 75 | padding=0, act=act) 76 | 77 | self.conv1 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter, kernel_size=kernel_size, 78 | stride=stride, 79 | padding=int((kernel_size - 1) // 2), act=act, groups=num_mid_filter) 80 | if use_se: 81 | self.se = SEBlock(in_channels=num_mid_filter, out_channels=num_mid_filter) 82 | else: 83 | self.se = None 84 | 85 | self.conv2 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter, kernel_size=1, stride=1, 86 | padding=0) 87 | self.not_add = num_in_filter != num_out_filter or stride != 1 88 | 89 | def forward(self, x): 90 | y = self.conv0(x) 91 | y = self.conv1(y) 92 | if self.se is not None: 93 | y = self.se(y) 94 | y = self.conv2(y) 95 | if not self.not_add: 96 | y = x + y 97 | return y 98 | 99 | 100 | class MobileNetV3(nn.Module): 101 | def __init__(self, in_channels, **kwargs): 102 | super().__init__() 103 | self.scale = kwargs.get('scale', 0.5) 104 | model_name = kwargs.get('model_name', 'large') 105 | self.inplanes = 16 106 | if model_name == "large": 107 | self.cfg = [ 108 | # k, exp, c, se, nl, s, 109 | [3, 16, 16, False, 'relu', 1], 110 | [3, 64, 24, False, 'relu', (2, 1)], 111 | [3, 72, 24, False, 'relu', 1], 112 | [5, 72, 40, True, 'relu', (2, 1)], 113 | [5, 120, 40, True, 'relu', 1], 114 | [5, 120, 40, True, 'relu', 1], 115 | [3, 240, 80, False, 'hard_swish', 1], 116 | [3, 200, 80, False, 'hard_swish', 1], 117 | [3, 184, 80, False, 'hard_swish', 1], 118 | [3, 184, 80, False, 'hard_swish', 1], 119 | [3, 480, 112, True, 'hard_swish', 1], 120 | [3, 672, 112, True, 'hard_swish', 1], 121 | [5, 672, 160, True, 'hard_swish', (2, 1)], 122 | [5, 960, 160, True, 'hard_swish', 1], 123 | [5, 960, 160, True, 'hard_swish', 1], 124 | ] 125 | self.cls_ch_squeeze = 960 126 | self.cls_ch_expand = 1280 127 | elif model_name == "small": 128 | self.cfg = [ 129 | # k, exp, c, se, nl, s, 130 | [3, 16, 16, True, 'relu', (2, 1)], 131 | [3, 72, 24, False, 'relu', (2, 1)], 132 | [3, 88, 24, False, 'relu', 1], 133 | [5, 96, 40, True, 'hard_swish', (2, 1)], 134 | [5, 240, 40, True, 'hard_swish', 1], 135 | [5, 240, 40, True, 'hard_swish', 1], 136 | [5, 120, 48, True, 'hard_swish', 1], 137 | [5, 144, 48, True, 'hard_swish', 1], 138 | [5, 288, 96, True, 'hard_swish', (2, 1)], 139 | [5, 576, 96, True, 'hard_swish', 1], 140 | [5, 576, 96, True, 'hard_swish', 1], 141 | ] 142 | self.cls_ch_squeeze = 576 143 | self.cls_ch_expand = 1280 144 | else: 145 | raise NotImplementedError("mode[" + model_name + 146 | "_model] is not implemented!") 147 | 148 | supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25] 149 | assert self.scale in supported_scale, "supported scale are {} but input scale is {}".format(supported_scale, 150 | self.scale) 151 | 152 | scale = self.scale 153 | inplanes = self.inplanes 154 | cfg = self.cfg 155 | cls_ch_squeeze = self.cls_ch_squeeze 156 | # conv1 157 | self.conv1 = ConvBNACT(in_channels=in_channels, 158 | out_channels=self.make_divisible(inplanes * scale), 159 | kernel_size=3, 160 | stride=2, 161 | padding=1, 162 | groups=1, 163 | act='hard_swish') 164 | inplanes = self.make_divisible(inplanes * scale) 165 | 
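# Each row of self.cfg configures one ResidualUnit in the loop below:
#   [k, exp, c, se, nl, s] -> kernel_size=k, expansion channels=make_divisible(scale * exp),
#   output channels=make_divisible(scale * c), use_se=se, activation=nl, stride=s.
# Strides written as (2, 1) halve the feature-map height only and keep the width,
# so a text-line image retains horizontal resolution for the recognition head.
# Rough usage sketch (an assumed direct call, not taken from this repo's configs;
# shapes are approximate and assume `import torch`):
#   net = MobileNetV3(in_channels=3)          # defaults: model_name='large', scale=0.5
#   feat = net(torch.randn(1, 3, 32, 320))    # expected to come out around 1 x 480 x 1 x 80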
block_list = [] 166 | for layer_cfg in cfg: 167 | block = ResidualUnit(num_in_filter=inplanes, 168 | num_mid_filter=self.make_divisible(scale * layer_cfg[1]), 169 | num_out_filter=self.make_divisible(scale * layer_cfg[2]), 170 | act=layer_cfg[4], 171 | stride=layer_cfg[5], 172 | kernel_size=layer_cfg[0], 173 | use_se=layer_cfg[3]) 174 | block_list.append(block) 175 | inplanes = self.make_divisible(scale * layer_cfg[2]) 176 | 177 | self.block_list = nn.Sequential(*block_list) 178 | self.conv2 = ConvBNACT(in_channels=inplanes, 179 | out_channels=self.make_divisible(scale * cls_ch_squeeze), 180 | kernel_size=1, 181 | stride=1, 182 | padding=0, 183 | groups=1, 184 | act='hard_swish') 185 | 186 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 187 | self.out_channels = self.make_divisible(scale * cls_ch_squeeze) 188 | 189 | def make_divisible(self, v, divisor=8, min_value=None): 190 | if min_value is None: 191 | min_value = divisor 192 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 193 | if new_v < 0.9 * v: 194 | new_v += divisor 195 | return new_v 196 | 197 | def forward(self, x): 198 | x = self.conv1(x) 199 | x = self.block_list(x) 200 | x = self.conv2(x) 201 | x = self.pool(x) 202 | return x 203 | -------------------------------------------------------------------------------- /modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/4/25 12:06 3 | # @Author : zhoujun 4 | 5 | from .feature_extraction import CNN_lite, VGG, ResNet, DenseNet 6 | from .resnet import ResNet_FeatureExtractor, ResNet_MT 7 | from .resnet_torch import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2 8 | from .MobileNetV3 import MobileNetV3 9 | 10 | __all__ = ['build_backbone'] 11 | support_backbone = ['CNN_lite', 'VGG', 'ResNet', 'DenseNet', 12 | 'ResNet_FeatureExtractor', 'ResNet_MT', 13 | 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 14 | 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 15 | 'MobileNetV3'] 16 | 17 | 18 | def build_backbone(backbone_name, **kwargs): 19 | assert backbone_name in support_backbone, f'all support backbone is {support_backbone}' 20 | backbone = eval(backbone_name)(**kwargs) 21 | return backbone 22 | -------------------------------------------------------------------------------- /modeling/backbone/feature_extraction.py: -------------------------------------------------------------------------------- 1 | from modeling.basic import * 2 | from torchvision.models.densenet import _DenseBlock 3 | 4 | 5 | class CNN_lite(nn.Module): 6 | # a vgg like net 7 | def __init__(self, in_channels, **kwargs): 8 | super().__init__() 9 | channels = [24, 128, 256, 256, 512, 512, 512] 10 | self.out_channels = channels[-1] 11 | self.cnn = nn.Sequential( 12 | nn.Conv2d(in_channels, channels[0], kernel_size=5, stride=2, padding=2), 13 | nn.ReLU(True), 14 | DWConv(channels[0], channels[1], kernel_size=3, stride=1, padding=1), 15 | nn.MaxPool2d(2, 2), 16 | DWConv(channels[1], channels[2], kernel_size=3, stride=1, padding=1, use_bn=True), 17 | DWConv(channels[2], channels[3], kernel_size=3, stride=1, padding=1), 18 | nn.MaxPool2d((2, 2), (2, 1), (0, 1)), 19 | DWConv(channels[3], channels[4], kernel_size=3, stride=1, padding=1, use_bn=True), 20 | DWConv(channels[4], channels[5], kernel_size=3, stride=1, padding=1), 21 | nn.MaxPool2d((2, 2), (2, 1), (0, 1)), 22 | 
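# Height bookkeeping for a 32-pixel-high input: the stride-2 conv gives 16,
# MaxPool2d(2, 2) gives 8, the two MaxPool2d((2, 2), (2, 1), (0, 1)) layers give 4 and
# then 2, and the kernel_size=2 depthwise conv below collapses the last 2 rows to height 1.
# The (2, 1) pooling strides leave the width (the sequence axis) largely intact.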
DWConv(channels[5], channels[6], kernel_size=2, stride=1, padding=0, use_bn=True), 23 | ) 24 | 25 | def forward(self, x): 26 | conv = self.cnn(x) 27 | return conv 28 | 29 | 30 | class VGG(nn.Module): 31 | def __init__(self, in_channels, **kwargs): 32 | super().__init__() 33 | conv_type = kwargs.get('conv_type', 'BasicConv') 34 | assert conv_type in ['BasicConv', 'DWConv', 'GhostModule'] 35 | basic_conv = globals()[conv_type] 36 | channels = [64, 128, 256, 256, 512, 512, 512] 37 | self.features = nn.Sequential( 38 | # conv layer 39 | BasicConv(in_channels=in_channels, out_channels=channels[0], kernel_size=3, padding=1, use_bn=False), 40 | nn.MaxPool2d(kernel_size=2, stride=2), 41 | 42 | # second conv layer 43 | basic_conv(in_channels=channels[0], out_channels=channels[1], kernel_size=3, padding=1, use_bn=False), 44 | nn.MaxPool2d(kernel_size=2, stride=2), 45 | 46 | # third conv layer 47 | basic_conv(in_channels=channels[1], out_channels=channels[2], kernel_size=3, padding=1, use_bn=False), 48 | 49 | # fourth conv layer 50 | basic_conv(in_channels=channels[2], out_channels=channels[3], kernel_size=3, padding=1, use_bn=False), 51 | nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), 52 | 53 | # fifth conv layer 54 | basic_conv(in_channels=channels[3], out_channels=channels[4], kernel_size=3, padding=1, bias=False), 55 | 56 | # sixth conv layer 57 | basic_conv(in_channels=channels[4], out_channels=channels[5], kernel_size=3, padding=1, bias=False), 58 | nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), 59 | 60 | # seren conv layer 61 | BasicConv(in_channels=channels[5], out_channels=channels[6], kernel_size=2, use_bn=False, use_relu=True), 62 | ) 63 | self.out_channels = channels[-1] 64 | 65 | def forward(self, x): 66 | return self.features(x) 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, in_channels, **kwargs): 71 | super().__init__() 72 | conv_type = kwargs.get('conv_type', 'BasicBlockV2') 73 | assert conv_type in ['BasicBlockV2', 'DWBlock', 'GhostBottleneck'] 74 | 75 | BasicBlock = globals()[conv_type] 76 | 77 | channels = [64, 64, 64, 128, 128, 256, 256, 512, 512, 512] 78 | expand_size = [64, 64, 128, 128, 256] 79 | self.out_channels = channels[-1] 80 | 81 | self.features = nn.Sequential( 82 | BasicConv(in_channels=in_channels, out_channels=channels[0], kernel_size=3, padding=1, bias=False), 83 | nn.MaxPool2d(kernel_size=2, stride=2), 84 | # nn.Conv2d(in_channels=channels[0], out_channels=channels[1], kernel_size=2, stride=2, bias=False), 85 | 86 | BasicBlock(in_channels=channels[0], out_channels=channels[2], expand_size=expand_size[0], kernel_size=3, 87 | stride=1), 88 | BasicBlock(in_channels=channels[2], out_channels=channels[3], expand_size=expand_size[1], kernel_size=3, 89 | stride=1), 90 | nn.Dropout(0.2), 91 | 92 | BasicBlock(in_channels=channels[3], out_channels=channels[4], expand_size=expand_size[2], kernel_size=3, 93 | stride=2, use_se=True), 94 | BasicBlock(in_channels=channels[4], out_channels=channels[5], expand_size=expand_size[3], kernel_size=3, 95 | stride=1, use_se=True), 96 | nn.Dropout(0.2), 97 | 98 | nn.Conv2d(in_channels=channels[5], out_channels=channels[6], kernel_size=2, stride=(2, 1), padding=(0, 1), 99 | bias=False), 100 | 101 | BasicBlock(in_channels=channels[6], out_channels=channels[7], expand_size=expand_size[4], kernel_size=3, 102 | stride=1, use_se=True), 103 | nn.BatchNorm2d(512), 104 | nn.ReLU(), 105 | # nn.MaxPool2d(kernel_size=2, stride=(2,1)), 106 | BasicConv(in_channels=channels[7], out_channels=channels[8], kernel_size=3, padding=0, 
bias=False), 107 | BasicConv(in_channels=channels[8], out_channels=channels[9], kernel_size=2, padding=(0, 1), bias=False), 108 | ) 109 | 110 | def forward(self, x): 111 | return self.features(x) 112 | 113 | def _make_transition(in_channels, out_channels, pool_stride, pool_pad, dropout): 114 | out = nn.Sequential( 115 | nn.BatchNorm2d(in_channels), 116 | nn.ReLU(), 117 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 118 | ) 119 | if dropout: 120 | out.add_module('dropout', nn.Dropout(dropout)) 121 | out.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=pool_stride, padding=pool_pad)) 122 | return out 123 | 124 | 125 | class DenseNet(nn.Module): 126 | def __init__(self, in_channels, **kwargs): 127 | super(DenseNet, self).__init__() 128 | self.features = nn.Sequential( 129 | nn.Conv2d(in_channels, 64, 5, padding=2, stride=2, bias=False), 130 | _DenseBlock(8, 64, 4, 8, 0), 131 | _make_transition(128, 128, 2, 0, 0.2), 132 | 133 | _DenseBlock(8, 128, 4, 8, 0), 134 | _make_transition(192, 128, (2, 1), 0, 0.2), 135 | 136 | _DenseBlock(8, 128, 4, 8, 0), 137 | 138 | nn.BatchNorm2d(192), 139 | nn.ReLU() 140 | ) 141 | self.out_channels = 768 142 | 143 | def forward(self, x): 144 | x = self.features(x) 145 | B, C, H, W = x.shape 146 | x = x.reshape((B, C * H, 1, W)) 147 | return x 148 | 149 | 150 | if __name__ == '__main__': 151 | import torch 152 | 153 | device = torch.device('cpu') 154 | net = VGG(3).to(device) 155 | a = torch.randn(2, 3, 32, 320).to(device) 156 | import time 157 | 158 | tic = time.time() 159 | for i in range(1): 160 | b = net(a)[0] 161 | print(b.shape) 162 | print((time.time() - tic) / 1) 163 | -------------------------------------------------------------------------------- /modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/4/11 10:29 3 | # @Author : zhoujun 4 | from modeling.basic import * 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False): 11 | super(BasicBlock, self).__init__() 12 | self.se = CBAM(out_channels) if use_cbam else None 13 | self.conv1 = BasicConv(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False, use_bn=True, use_relu=True, inplace=True) 14 | self.conv2 = BasicConv(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False, use_bn=True, use_relu=False) 15 | self.downsample = downsample 16 | self.stride = stride 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | residual = x 21 | out = self.conv1(x) 22 | out = self.conv2(out) 23 | if self.downsample is not None: 24 | residual = self.downsample(x) 25 | if self.se != None: 26 | out = self.se(out) 27 | out += residual 28 | out = self.relu(out) 29 | 30 | return out 31 | 32 | 33 | class ReaNet(nn.Module): 34 | def __init__(self, out_channels=512): 35 | super().__init__() 36 | self.out_channels = out_channels 37 | 38 | def _make_layer(self, block, planes, blocks, stride=1, use_cbam=False): 39 | downsample = None 40 | if stride != 1 or self.inplanes != planes * block.expansion: 41 | downsample = BasicConv(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, 42 | use_bn=True, use_relu=False) 43 | 44 | layers = [] 45 | layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=use_cbam)) 46 | self.inplanes = planes * block.expansion 47 | for i in range(1, blocks): 48 | 
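# Only the first block of a stage (appended above) receives the stride and the 1x1
# BasicConv downsample shortcut; the remaining `blocks - 1` units appended here run
# with stride 1 on the already-expanded self.inplanes.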
layers.append(block(self.inplanes, planes, use_cbam=use_cbam)) 49 | 50 | return nn.Sequential(*layers) 51 | 52 | 53 | class ResNet_FeatureExtractor(ReaNet): 54 | """ FeatureExtractor of https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/feature_extraction.py """ 55 | 56 | def __init__(self, in_channels, out_channels=512, **kwargs): 57 | super().__init__(out_channels) 58 | layers = [1, 2, 5, 3] 59 | block = BasicBlock 60 | output_channel_block = [int(out_channels / 4), int(out_channels / 2), out_channels, out_channels] 61 | 62 | self.inplanes = int(out_channels / 8) 63 | self.conv0 = nn.Sequential( 64 | BasicConv(in_channels, int(out_channels / 16), kernel_size=3, stride=1, padding=1, bias=False, use_bn=True, 65 | use_relu=True, inplace=True), 66 | BasicConv(int(out_channels / 16), self.inplanes, kernel_size=3, stride=1, padding=1, bias=False, 67 | use_bn=True, use_relu=True, inplace=True) 68 | ) 69 | 70 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 71 | self.layer1 = self._make_layer(block, output_channel_block[0], layers[0]) 72 | self.conv1 = BasicConv(output_channel_block[0], output_channel_block[0], kernel_size=3, stride=1, padding=1, 73 | bias=False, use_bn=True, use_relu=True) 74 | 75 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 76 | self.layer2 = self._make_layer(block, output_channel_block[1], layers[1], stride=1) 77 | self.conv2 = BasicConv(output_channel_block[1], output_channel_block[1], kernel_size=3, stride=1, padding=1, 78 | bias=False, use_bn=True, use_relu=True) 79 | 80 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 81 | self.layer3 = self._make_layer(block, output_channel_block[2], layers[2], stride=1) 82 | self.conv3 = BasicConv(output_channel_block[2], output_channel_block[2], kernel_size=3, stride=1, padding=1, 83 | bias=False, use_bn=True, use_relu=True) 84 | 85 | self.layer4 = self._make_layer(block, output_channel_block[3], layers[3], stride=1) 86 | self.conv4 = nn.Sequential( 87 | BasicConv(output_channel_block[3], output_channel_block[3], kernel_size=2, stride=(2, 1), padding=(0, 1), 88 | bias=False, use_bn=True, use_relu=True), 89 | BasicConv(output_channel_block[3], output_channel_block[3], kernel_size=2, stride=1, padding=0, bias=False, 90 | use_bn=True, use_relu=True) 91 | ) 92 | 93 | def forward(self, x): 94 | x = self.conv0(x) 95 | 96 | x = self.maxpool1(x) 97 | x = self.layer1(x) 98 | x = self.conv1(x) 99 | 100 | x = self.maxpool2(x) 101 | x = self.layer2(x) 102 | x = self.conv2(x) 103 | 104 | x = self.maxpool3(x) 105 | x = self.layer3(x) 106 | x = self.conv3(x) 107 | 108 | x = self.layer4(x) 109 | x = self.conv4(x) 110 | return x 111 | 112 | 113 | class ResNet_MT(ReaNet): 114 | """ resnet of ReADS arxiv.org/pdf/2004.02070.pdf""" 115 | 116 | def __init__(self, in_channels, out_channels=512): 117 | super().__init__() 118 | layers = [3, 4, 6, 3] 119 | block = BasicBlock 120 | output_channel_block = [int(out_channels / 16), int(out_channels / 8), int(out_channels / 4), 121 | int(out_channels / 2), out_channels] 122 | 123 | self.inplanes = output_channel_block[0] 124 | self.conv0 = BasicConv(in_channels, output_channel_block[0], kernel_size=3, stride=1, padding=1, bias=False, 125 | use_bn=True, use_relu=True) 126 | 127 | self.layer1 = self._make_layer(block, output_channel_block[1], layers[0], use_cbam=True) 128 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 129 | 130 | self.layer2 = self._make_layer(block, output_channel_block[2], layers[1], stride=1, 
use_cbam=True) 131 | self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)) 132 | 133 | self.layer3 = self._make_layer(block, output_channel_block[3], layers[2], stride=1, use_cbam=True) 134 | self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)) 135 | 136 | self.layer4 = self._make_layer(block, output_channel_block[4], layers[3], stride=1, use_cbam=True) 137 | self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)) 138 | self.conv4 = BasicConv(output_channel_block[4], out_channels, kernel_size=2, stride=2, bias=False, 139 | use_bn=False, use_relu=False) 140 | 141 | def forward(self, x): 142 | x = self.conv0(x) 143 | 144 | x = self.layer1(x) 145 | x = self.maxpool1(x) 146 | 147 | x = self.layer2(x) 148 | x = self.maxpool2(x) 149 | 150 | x = self.layer3(x) 151 | x = self.maxpool3(x) 152 | 153 | x = self.layer4(x) 154 | x = self.maxpool4(x) 155 | x = self.conv4(x) 156 | return x 157 | 158 | 159 | 160 | 161 | if __name__ == '__main__': 162 | import torch 163 | import time 164 | 165 | net = ResNet_FeatureExtractor(3, 512) 166 | a = torch.rand((1, 3, 32, 320)) 167 | tic = time.time() 168 | for i in range(1): 169 | b = net(a) 170 | print(b.shape) 171 | print((time.time() - tic) / 1) 172 | # print(net) 173 | -------------------------------------------------------------------------------- /modeling/backbone/resnet_torch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/4/14 17:28 3 | # @Author : zhoujun 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.models.utils import load_state_dict_from_url 7 | 8 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 22 | } 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=dilation, groups=groups, bias=False, dilation=dilation) 29 | 30 | 31 | def conv1x1(in_planes, out_planes, stride=1): 32 | """1x1 convolution""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | __constants__ = ['downsample'] 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 41 | base_width=64, dilation=1, norm_layer=None): 42 | super(BasicBlock, self).__init__() 43 | if norm_layer is None: 44 | norm_layer = nn.BatchNorm2d 45 | if groups != 1 or base_width != 64: 46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 47 | if dilation > 1: 48 | raise 
NotImplementedError("Dilation > 1 not supported in BasicBlock") 49 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 50 | self.conv1 = conv3x3(inplanes, planes, stride) 51 | self.bn1 = norm_layer(planes) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.conv2 = conv3x3(planes, planes) 54 | self.bn2 = norm_layer(planes) 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | identity = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | identity = self.downsample(x) 70 | 71 | out += identity 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | __constants__ = ['downsample'] 80 | 81 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 82 | base_width=64, dilation=1, norm_layer=None): 83 | super(Bottleneck, self).__init__() 84 | if norm_layer is None: 85 | norm_layer = nn.BatchNorm2d 86 | width = int(planes * (base_width / 64.)) * groups 87 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 88 | self.conv1 = conv1x1(inplanes, width) 89 | self.bn1 = norm_layer(width) 90 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 91 | self.bn2 = norm_layer(width) 92 | self.conv3 = conv1x1(width, planes * self.expansion) 93 | self.bn3 = norm_layer(planes * self.expansion) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | identity = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | identity = self.downsample(x) 114 | 115 | out += identity 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class ResNet(nn.Module): 122 | 123 | def __init__(self, block, layers, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, 124 | norm_layer=None, **kwargs): 125 | super(ResNet, self).__init__() 126 | if norm_layer is None: 127 | norm_layer = nn.BatchNorm2d 128 | self._norm_layer = norm_layer 129 | in_channels = kwargs.get('in_channels', 3) 130 | self.out_channels = kwargs.get('out_channels', 512) 131 | self.inplanes = 64 132 | self.dilation = 1 133 | if replace_stride_with_dilation is None: 134 | # each element in the tuple indicates if we should replace 135 | # the 2x2 stride with a dilated convolution instead 136 | replace_stride_with_dilation = [False, False, False] 137 | if len(replace_stride_with_dilation) != 3: 138 | raise ValueError("replace_stride_with_dilation should be None " 139 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 140 | self.groups = groups 141 | self.base_width = width_per_group 142 | self.conv1 = nn.Sequential( 143 | nn.Conv2d(in_channels, self.inplanes, kernel_size=3, stride=2, padding=1, bias=False), 144 | norm_layer(self.inplanes), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), 147 | norm_layer(self.inplanes), 148 | nn.ReLU(inplace=True) 149 | ) 150 | 151 | self.layer1 = self._make_layer(block, 64, layers[0]) 152 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 153 | self.layer2 = 
self._make_layer(block, 128, layers[1], stride=1, dilate=replace_stride_with_dilation[0]) 154 | self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)) 155 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilate=replace_stride_with_dilation[1]) 156 | self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)) 157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilate=replace_stride_with_dilation[2]) 158 | 159 | self.out_conv = nn.Sequential( 160 | nn.Conv2d(512 * block.expansion, self.out_channels, kernel_size=2, stride=(2, 1), bias=False), 161 | norm_layer(self.out_channels), 162 | nn.ReLU(), 163 | ) 164 | 165 | for m in self.modules(): 166 | if isinstance(m, nn.Conv2d): 167 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 168 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 169 | nn.init.constant_(m.weight, 1) 170 | nn.init.constant_(m.bias, 0) 171 | 172 | # Zero-initialize the last BN in each residual branch, 173 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 174 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 175 | if zero_init_residual: 176 | for m in self.modules(): 177 | if isinstance(m, Bottleneck): 178 | nn.init.constant_(m.bn3.weight, 0) 179 | elif isinstance(m, BasicBlock): 180 | nn.init.constant_(m.bn2.weight, 0) 181 | 182 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 183 | norm_layer = self._norm_layer 184 | downsample = None 185 | previous_dilation = self.dilation 186 | if dilate: 187 | self.dilation *= stride 188 | stride = 1 189 | if stride != 1 or self.inplanes != planes * block.expansion: 190 | downsample = nn.Sequential( 191 | conv1x1(self.inplanes, planes * block.expansion, stride), 192 | norm_layer(planes * block.expansion), 193 | ) 194 | 195 | layers = [] 196 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 197 | self.base_width, previous_dilation, norm_layer)) 198 | self.inplanes = planes * block.expansion 199 | for _ in range(1, blocks): 200 | layers.append(block(self.inplanes, planes, groups=self.groups, 201 | base_width=self.base_width, dilation=self.dilation, 202 | norm_layer=norm_layer)) 203 | 204 | return nn.Sequential(*layers) 205 | 206 | def _forward_impl(self, x): 207 | # See note [TorchScript super()] 208 | x = self.conv1(x) 209 | 210 | x = self.layer1(x) 211 | x = self.maxpool1(x) 212 | 213 | x = self.layer2(x) 214 | x = self.maxpool2(x) 215 | 216 | x = self.layer3(x) 217 | x = self.maxpool3(x) 218 | 219 | x = self.layer4(x) 220 | x = self.out_conv(x) 221 | return x 222 | 223 | def forward(self, x): 224 | return self._forward_impl(x) 225 | 226 | 227 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 228 | model = ResNet(block, layers, **kwargs) 229 | if pretrained: 230 | state_dict = load_state_dict_from_url(model_urls[arch], 231 | progress=progress) 232 | model.load_state_dict(state_dict) 233 | return model 234 | 235 | 236 | def resnet18(pretrained=False, progress=True, **kwargs): 237 | r"""ResNet-18 model from 238 | `"Deep Residual Learning for Image Recognition" `_ 239 | 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet34(pretrained=False, progress=True, **kwargs): 249 
| r"""ResNet-34 model from 250 | `"Deep Residual Learning for Image Recognition" `_ 251 | 252 | Args: 253 | pretrained (bool): If True, returns a model pre-trained on ImageNet 254 | progress (bool): If True, displays a progress bar of the download to stderr 255 | """ 256 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 257 | **kwargs) 258 | 259 | 260 | def resnet50(pretrained=False, progress=True, **kwargs): 261 | r"""ResNet-50 model from 262 | `"Deep Residual Learning for Image Recognition" `_ 263 | 264 | Args: 265 | pretrained (bool): If True, returns a model pre-trained on ImageNet 266 | progress (bool): If True, displays a progress bar of the download to stderr 267 | """ 268 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 269 | **kwargs) 270 | 271 | 272 | def resnet101(pretrained=False, progress=True, **kwargs): 273 | r"""ResNet-101 model from 274 | `"Deep Residual Learning for Image Recognition" `_ 275 | 276 | Args: 277 | pretrained (bool): If True, returns a model pre-trained on ImageNet 278 | progress (bool): If True, displays a progress bar of the download to stderr 279 | """ 280 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 281 | **kwargs) 282 | 283 | 284 | def resnet152(pretrained=False, progress=True, **kwargs): 285 | r"""ResNet-152 model from 286 | `"Deep Residual Learning for Image Recognition" `_ 287 | 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 293 | **kwargs) 294 | 295 | 296 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 297 | r"""ResNeXt-50 32x4d model from 298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | kwargs['groups'] = 32 305 | kwargs['width_per_group'] = 4 306 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 307 | pretrained, progress, **kwargs) 308 | 309 | 310 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 311 | r"""ResNeXt-101 32x8d model from 312 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 313 | 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | """ 318 | kwargs['groups'] = 32 319 | kwargs['width_per_group'] = 8 320 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 321 | pretrained, progress, **kwargs) 322 | 323 | 324 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 325 | r"""Wide ResNet-50-2 model from 326 | `"Wide Residual Networks" `_ 327 | 328 | The model is the same as ResNet except for the bottleneck number of channels 329 | which is twice larger in every block. The number of channels in outer 1x1 330 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 331 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
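In this file the doubling comes from ``width_per_group``: Bottleneck computes
``width = int(planes * (base_width / 64.)) * groups``, so ``base_width = 128`` with
``groups = 1`` gives ``2 * planes`` channels in the 3x3 convolution of every block.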
332 | 333 | Args: 334 | pretrained (bool): If True, returns a model pre-trained on ImageNet 335 | progress (bool): If True, displays a progress bar of the download to stderr 336 | """ 337 | kwargs['width_per_group'] = 64 * 2 338 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 339 | pretrained, progress, **kwargs) 340 | 341 | 342 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 343 | r"""Wide ResNet-101-2 model from 344 | `"Wide Residual Networks" `_ 345 | 346 | The model is the same as ResNet except for the bottleneck number of channels 347 | which is twice larger in every block. The number of channels in outer 1x1 348 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 349 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 350 | 351 | Args: 352 | pretrained (bool): If True, returns a model pre-trained on ImageNet 353 | progress (bool): If True, displays a progress bar of the download to stderr 354 | """ 355 | kwargs['width_per_group'] = 64 * 2 356 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 357 | pretrained, progress, **kwargs) 358 | 359 | 360 | if __name__ == '__main__': 361 | import torch 362 | 363 | net = resnet50(in_channels=3, out_channels=512) 364 | x = torch.rand((1, 3, 32, 320)) 365 | y = net(x) 366 | print(y.shape) 367 | # print(net) 368 | -------------------------------------------------------------------------------- /modeling/basic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/12/20 11:14 3 | # @Author : zhoujun 4 | import math 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class BasicConv(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', use_bn=True, 11 | use_relu=True, inplace=True): 12 | super().__init__() 13 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 14 | groups=groups, bias=bias, padding_mode=padding_mode) 15 | self.bn = nn.BatchNorm2d(out_channels) if use_bn else None 16 | self.relu = nn.ReLU(inplace=inplace) if use_relu else None 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | if self.bn is not None: 21 | x = self.bn(x) 22 | if self.relu is not None: 23 | x = self.relu(x) 24 | return x 25 | 26 | class BasicBlockV2(nn.Module): 27 | def __init__(self, in_channels, out_channels, stride, kernel_size=3, downsample=True, use_cbam=False, **kwargs): 28 | super(BasicBlockV2, self).__init__() 29 | self.se = CBAM(out_channels) if use_cbam else None 30 | self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.9) 31 | self.relu1 = nn.ReLU() 32 | self.conv = nn.Sequential( 33 | BasicConv(in_channels, out_channels, kernel_size, stride, kernel_size // 2, bias=False), 34 | BasicConv(out_channels, out_channels, kernel_size, 1, kernel_size // 2, bias=False, use_bn=False, use_relu=False), 35 | ) 36 | if downsample: 37 | self.downsample = BasicConv(in_channels, out_channels, 1, stride, bias=False, use_bn=False, use_relu=False) 38 | else: 39 | self.downsample = None 40 | 41 | def forward(self, x): 42 | residual = x 43 | x = self.bn1(x) 44 | x = self.relu1(x) 45 | if self.downsample: 46 | residual = self.downsample(x) 47 | x = self.conv(x) 48 | if self.se != None: 49 | x = self.se(x) 50 | return x + residual 51 | 52 | 53 | class DWConv(nn.Module): 54 | def __init__(self, in_channels, out_channels, kernel_size, 
stride=1, padding=0, use_bn=False,**kwargs): 55 | super().__init__() 56 | self.use_bn = use_bn 57 | self.DWConv = BasicConv(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, use_bn=use_bn) 58 | self.conv1x1 = BasicConv(in_channels, out_channels, 1, 1, 0, use_bn=use_bn) 59 | 60 | def forward(self, x): 61 | x = self.DWConv(x) 62 | x = self.conv1x1(x) 63 | return x 64 | 65 | 66 | class DWBlock(nn.Module): 67 | '''expand + depthwise + pointwise''' 68 | 69 | def __init__(self, in_channels, out_channels, expand_size, kernel_size, stride, use_cbam=False, **kwargs): 70 | super().__init__() 71 | self.stride = stride 72 | self.se = CBAM(out_channels) if use_cbam else None 73 | # pw 74 | self.conv1 = BasicConv(in_channels, expand_size, kernel_size=1, stride=1, padding=0, bias=False) 75 | # dw 76 | self.conv2 = BasicConv(expand_size, expand_size, kernel_size, stride, padding=kernel_size // 2, groups=expand_size) 77 | # pw 78 | self.conv3 = BasicConv(expand_size, out_channels, kernel_size=1, stride=1, padding=0, bias=False, use_relu=False) 79 | 80 | self.shortcut = nn.Sequential() 81 | if stride == 1 and in_channels != out_channels: 82 | self.shortcut = BasicConv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False, use_relu=False) 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.conv2(out) 87 | out = self.conv3(out) 88 | if self.se != None: 89 | out = self.se(out) 90 | out = out + self.shortcut(x) if self.stride == 1 else out 91 | return out 92 | 93 | 94 | def _make_divisible(v, divisor, min_value=None): 95 | """ 96 | This function is taken from the original tf repo. 97 | It ensures that all layers have a channel number that is divisible by 8 98 | It can be seen here: 99 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 100 | """ 101 | if min_value is None: 102 | min_value = divisor 103 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 104 | # Make sure that round down does not go down by more than 10%. 
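# Worked example: v = 18, divisor = 8 -> int(18 + 4) // 8 * 8 = 16, and since
# 16 < 0.9 * 18 = 16.2 the guard below bumps the result to 24; for v = 60 the
# rounded value 64 already exceeds 0.9 * 60 = 54, so 64 is returned as-is.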
105 | if new_v < 0.9 * v: 106 | new_v += divisor 107 | return new_v 108 | 109 | 110 | class ChannelAttention(nn.Module): 111 | def __init__(self, channel, reduction=16, use_max_pool=True): 112 | """ 113 | 当 use_max_pool为空时,变成SELayer 114 | Args: 115 | channel: 116 | reduction: 117 | use_max_pool: 118 | """ 119 | super().__init__() 120 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 121 | self.max_pool = nn.AdaptiveMaxPool2d((1, 1)) if use_max_pool else None 122 | self.fc = nn.Sequential( 123 | BasicConv(channel, channel // reduction, kernel_size=1, bias=False,use_bn=False), 124 | BasicConv(channel // reduction, channel, kernel_size=1, bias=False,use_bn=False,use_relu=False), 125 | ) 126 | self.sigmoid = nn.Sigmoid() 127 | 128 | def forward(self, x): 129 | y1 = self.avg_pool(x) 130 | y1 = self.fc(y1) 131 | if self.max_pool is not None: 132 | y2 = self.max_pool(x) 133 | y2 = self.fc(y2) 134 | y = y1 + y2 135 | else: 136 | y = y1 137 | y = self.sigmoid(y) 138 | return x * y 139 | 140 | 141 | class SpartialAttention(nn.Module): 142 | def __init__(self, kernel_size=7): 143 | super().__init__() 144 | assert kernel_size % 2 == 1, f"kernel_size = {kernel_size}" 145 | padding = (kernel_size - 1) // 2 146 | self.layer = nn.Sequential( 147 | BasicConv(2, 1, kernel_size=kernel_size, padding=padding), 148 | nn.Sigmoid(), 149 | ) 150 | 151 | def forward(self, x): 152 | avg_mask = torch.mean(x, dim=1, keepdim=True) 153 | max_mask, _ = torch.max(x, dim=1, keepdim=True) 154 | mask = torch.cat([avg_mask, max_mask], dim=1) 155 | mask = self.layer(mask) 156 | return x * mask 157 | 158 | 159 | class CBAM(nn.Module): 160 | def __init__(self, gate_channels, reduction_ratio=16, no_spatial=False): 161 | super(CBAM, self).__init__() 162 | self.ChannelGate = ChannelAttention(gate_channels, reduction_ratio) 163 | self.no_spatial = no_spatial 164 | if not no_spatial: 165 | self.SpatialGate = SpartialAttention() 166 | 167 | def forward(self, x): 168 | x_out = self.ChannelGate(x) 169 | if not self.no_spatial: 170 | x_out = self.SpatialGate(x_out) 171 | return x_out 172 | 173 | 174 | class GhostModule(nn.Module): 175 | def __init__(self, in_channels, out_channels, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True, **kwargs): 176 | super().__init__() 177 | self.oup = out_channels 178 | init_channels = math.ceil(out_channels / ratio) 179 | new_channels = init_channels * (ratio - 1) 180 | 181 | self.primary_conv = BasicConv(in_channels, init_channels, kernel_size, stride, kernel_size // 2, use_relu=relu) 182 | 183 | self.cheap_operation = BasicConv(init_channels, new_channels, dw_size, 1, dw_size // 2, groups=init_channels, bias=False, use_relu=relu) 184 | 185 | def forward(self, x): 186 | x1 = self.primary_conv(x) 187 | x2 = self.cheap_operation(x1) 188 | out = torch.cat([x1, x2], dim=1) 189 | return out[:, :self.oup, :, :] 190 | 191 | 192 | class GhostBottleneck(nn.Module): 193 | def __init__(self, in_channels, out_channels, expand_size, kernel_size, stride, use_cbam=False): 194 | super().__init__() 195 | assert stride in [1, 2] 196 | 197 | self.conv = nn.Sequential( 198 | # pw 199 | GhostModule(in_channels, expand_size, kernel_size=1, relu=True), 200 | # dw 201 | BasicConv(expand_size, expand_size, kernel_size, stride, kernel_size // 2, use_relu=False) if stride == 2 else nn.Sequential(), 202 | # Squeeze-and-Excite 203 | CBAM(expand_size) if use_cbam else nn.Sequential(), 204 | # pw-linear 205 | GhostModule(expand_size, out_channels, kernel_size=1, relu=False), 206 | ) 207 | 208 | if stride == 1 and in_channels == 
out_channels: 209 | self.shortcut = nn.Sequential() 210 | else: 211 | self.shortcut = nn.Sequential( 212 | BasicConv(in_channels, in_channels, 3, stride, kernel_size // 2, use_relu=True), 213 | BasicConv(in_channels, out_channels, 1, 1, 0, bias=False, use_relu=False) 214 | ) 215 | 216 | def forward(self, x): 217 | return self.conv(x) + self.shortcut(x) 218 | -------------------------------------------------------------------------------- /modeling/head/Attn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zhoujun 5 | @time: 2019/12/19 2:56 PM 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class Attn(nn.Module): 14 | 15 | def __init__(self, in_channels, hidden_size, n_class, **kwargs): 16 | super(Attn, self).__init__() 17 | self.attention_cell = AttentionCell(in_channels, hidden_size, n_class) 18 | self.hidden_size = hidden_size 19 | self.num_classes = n_class 20 | self.generator = nn.Linear(hidden_size, n_class) 21 | 22 | def _char_to_onehot(self, input_char, onehot_dim=38): 23 | input_char = input_char.unsqueeze(1) 24 | batch_size = input_char.size(0) 25 | one_hot = torch.zeros(batch_size, onehot_dim).to(input_char.device) 26 | one_hot = one_hot.scatter_(1, input_char, 1) 27 | return one_hot 28 | 29 | def forward(self, batch_H, text, batch_max_length=25): 30 | """ 31 | input: 32 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_channels] 33 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 34 | output: probability distribution at each step [batch_size x num_steps x num_classes] 35 | """ 36 | batch_size = batch_H.size(0) 37 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 38 | 39 | output_hiddens = torch.zeros(batch_size, num_steps, self.hidden_size).to(batch_H.device) 40 | hidden = (torch.zeros(batch_size, self.hidden_size).to(batch_H.device), 41 | torch.zeros(batch_size, self.hidden_size).to(batch_H.device)) 42 | 43 | if self.training: 44 | for i in range(num_steps): 45 | # one-hot vectors for the i-th char.
in a batch 46 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 47 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 48 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 49 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 50 | probs = self.generator(output_hiddens) 51 | 52 | else: 53 | targets = torch.zeros(batch_size, dtype=torch.long).to(batch_H.device) # [GO] token 54 | probs = torch.zeros(batch_size, num_steps, self.num_classes).to(batch_H.device) 55 | 56 | for i in range(num_steps): 57 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 58 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 59 | probs_step = self.generator(hidden[0]) 60 | probs[:, i, :] = probs_step 61 | _, next_input = probs_step.max(1) 62 | if next_input[0] == 1: # meet end-of-sentence token 63 | break 64 | targets = next_input 65 | 66 | return probs # batch_size x num_steps x num_classes 67 | 68 | 69 | class AttentionCell(nn.Module): 70 | 71 | def __init__(self, input_size, hidden_size, num_embeddings): 72 | super(AttentionCell, self).__init__() 73 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 74 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 75 | self.score = nn.Linear(hidden_size, 1, bias=False) 76 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 77 | self.hidden_size = hidden_size 78 | 79 | def forward(self, prev_hidden, batch_H, char_onehots): 80 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 81 | batch_H_proj = self.i2h(batch_H) 82 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 83 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 84 | 85 | alpha = F.softmax(e, dim=1) 86 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 87 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 88 | cur_hidden = self.rnn(concat_context, prev_hidden) 89 | return cur_hidden, alpha 90 | -------------------------------------------------------------------------------- /modeling/head/CTC.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/17 10:57 3 | # @Author : zhoujun 4 | from torch import nn 5 | 6 | 7 | class CTC(nn.Module): 8 | def __init__(self, in_channels, n_class, **kwargs): 9 | super().__init__() 10 | self.fc = nn.Linear(in_channels, n_class) 11 | 12 | def forward(self, x): 13 | return self.fc(x) 14 | -------------------------------------------------------------------------------- /modeling/head/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/17 10:57 3 | # @Author : zhoujun 4 | 5 | from .CTC import CTC 6 | from .Attn import Attn 7 | 8 | __all__ = ['build_head'] 9 | support_head = ['CTC', 'Attn'] 10 | 11 | 12 | def build_head(head_name, **kwargs): 13 | assert head_name in support_head, f'all support head is {support_head}' 14 | head = eval(head_name)(**kwargs) 15 | return head 16 | -------------------------------------------------------------------------------- /modeling/losses/AttnLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 
2020/6/17 15:03 3 | # @Author : zhoujun 4 | 5 | from torch import nn 6 | 7 | 8 | class AttnLoss(nn.Module): 9 | def __init__(self, ignore_index=0): 10 | super().__init__() 11 | self.func = nn.CrossEntropyLoss(ignore_index=ignore_index) 12 | 13 | def forward(self, preds, batch_data): 14 | target = batch_data['targets'][:, 1:] # without [GO] Symbol 15 | loss = self.func(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 16 | return loss 17 | -------------------------------------------------------------------------------- /modeling/losses/CTCLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/17 11:17 3 | # @Author : zhoujun 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class CTCLoss(nn.Module): 9 | def __init__(self, blank=0, zero_infinity=True): 10 | super().__init__() 11 | self.func = nn.CTCLoss(blank=blank, zero_infinity=zero_infinity) 12 | 13 | def forward(self, preds, batch_data): 14 | cur_batch_size = batch_data['img'].shape[0] 15 | targets = batch_data['targets'] 16 | targets_lengths = batch_data['targets_lengths'] 17 | preds = preds.log_softmax(2) 18 | preds_lengths = torch.tensor([preds.size(1)] * cur_batch_size, dtype=torch.long) 19 | loss = self.func(preds.permute(1, 0, 2), targets, preds_lengths, targets_lengths) 20 | return loss 21 | -------------------------------------------------------------------------------- /modeling/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/17 11:17 3 | # @Author : zhoujun 4 | import copy 5 | from .CTCLoss import CTCLoss 6 | from .AttnLoss import AttnLoss 7 | 8 | __all__ = ['build_loss'] 9 | support_loss = ['CTCLoss', 'AttnLoss'] 10 | 11 | 12 | def build_loss(config): 13 | copy_config = copy.deepcopy(config) 14 | loss_type = copy_config.pop('type') 15 | assert loss_type in support_loss, f'all support loss is {support_loss}' 16 | criterion = eval(loss_type)(**copy_config) 17 | return criterion 18 | -------------------------------------------------------------------------------- /modeling/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from addict import Dict 3 | import torch 4 | from torch import nn 5 | 6 | from modeling.trans import build_trans 7 | from modeling.backbone import build_backbone 8 | from modeling.neck import build_neck 9 | from modeling.head import build_head 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, config): 14 | super(Model, self).__init__() 15 | model_config = copy.deepcopy(config) 16 | model_config = Dict(model_config) 17 | 18 | trans_type = model_config.trans.pop('type') 19 | backbone_type = model_config.backbone.pop('type') 20 | neck_type = model_config.neck.pop('type') 21 | self.head_type = model_config.head.pop('type') 22 | if self.head_type == 'Attention': 23 | assert neck_type == 'RNNDecoder' 24 | self.trans = build_trans(trans_type, **model_config.backbone) 25 | self.backbone = build_backbone(backbone_type, **model_config.backbone) 26 | self.neck = build_neck(neck_type, in_channels=self.backbone.out_channels, **model_config.neck) 27 | self.head = build_head(self.head_type, in_channels=self.neck.out_channels, **model_config.head) 28 | 29 | self.name = f'RecModel_{trans_type}_{backbone_type}_{neck_type}_{self.head_type}' 30 | self.batch_max_length = -1 31 | self.init() 32 | 33 | def get_batch_max_length(self, x): 34 | # 特征提取阶段 35 | 36 | if 
self.trans is not None: 37 | x = self.trans(x) 38 | x = self.backbone(x) 39 | self.batch_max_length = x.shape[-1] 40 | return self.batch_max_length 41 | 42 | def init(self): 43 | import torch.nn.init as init 44 | # weight initialization 45 | for name, param in self.named_parameters(): 46 | if 'localization_fc2' in name: 47 | print(f'Skip {name} as it is already initialized') 48 | continue 49 | try: 50 | if 'bias' in name: 51 | init.constant_(param, 0.0) 52 | elif 'weight' in name: 53 | init.kaiming_normal_(param) 54 | except Exception as e: # for batchnorm. 55 | if 'weight' in name: 56 | param.data.fill_(1) 57 | continue 58 | 59 | def forward(self, x, text=None): 60 | if self.trans is not None: 61 | x = self.trans(x) 62 | y = self.backbone(x) 63 | y = self.neck(y) 64 | # 预测阶段 65 | if self.head_type == 'CTC': 66 | y = self.head(y) 67 | elif self.head_type == 'Attn': 68 | y = self.head(y, text, self.batch_max_length) 69 | else: 70 | raise NotImplementedError 71 | return y, x 72 | 73 | 74 | if __name__ == '__main__': 75 | import os 76 | import anyconfig 77 | from utils import parse_config, load, get_parameter_number 78 | 79 | config = anyconfig.load(open("config/imagedataset_None_VGG_RNN_CTC.yaml", 'rb')) 80 | if 'base' in config: 81 | config = parse_config(config) 82 | if os.path.isfile(config['dataset']['alphabet']): 83 | config['dataset']['alphabet'] = load(config['dataset']['alphabet']) 84 | 85 | device = torch.device('cpu') 86 | config['arch']['backbone']['in_channels'] = 3 87 | config['arch']['head']['n_class'] = 95 88 | net = Model(config['arch']).to(device) 89 | print(net.name, len(config['dataset']['alphabet'])) 90 | a = torch.randn(2, 3, 32, 320).to(device) 91 | 92 | import time 93 | 94 | text_for_pred = torch.LongTensor(2, 25 + 1).fill_(0) 95 | tic = time.time() 96 | for i in range(1): 97 | b = net(a, text_for_pred)[0] 98 | print(b.shape) 99 | print((time.time() - tic) / 1) 100 | print(get_parameter_number(net)) 101 | -------------------------------------------------------------------------------- /modeling/modules/seg/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zhoujun 5 | @time: 2019/12/17 下午1:51 6 | ''' 7 | from modeling.modules.seg.resnet_fpn import ResNetFPN 8 | from modeling.modules.seg.unet import UNet -------------------------------------------------------------------------------- /modeling/modules/seg/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/12/20 14:47 3 | # @Author : zhoujun 4 | 5 | import torch.nn as nn 6 | import math 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | BatchNorm2d = nn.BatchNorm2d 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'deformable_resnet18', 'deformable_resnet50', 12 | 'resnet152'] 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | } 21 | 22 | 23 | def constant_init(module, constant, bias=0): 24 | nn.init.constant_(module.weight, constant) 25 | if hasattr(module, 'bias'): 26 | nn.init.constant_(module.bias, bias) 27 | 28 | 29 | def 
conv3x3(in_planes, out_planes, stride=1): 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=1, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 39 | super(BasicBlock, self).__init__() 40 | self.with_dcn = dcn is not None 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = BatchNorm2d(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.with_modulated_dcn = False 45 | if not self.with_dcn: 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 47 | else: 48 | from torchvision.ops import DeformConv2d 49 | deformable_groups = dcn.get('deformable_groups', 1) 50 | offset_channels = 18 51 | self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1) 52 | self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, bias=False) 53 | self.bn2 = BatchNorm2d(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | residual = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | # out = self.conv2(out) 65 | if not self.with_dcn: 66 | out = self.conv2(out) 67 | else: 68 | offset = self.conv2_offset(out) 69 | out = self.conv2(out, offset) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 85 | super(Bottleneck, self).__init__() 86 | self.with_dcn = dcn is not None 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 88 | self.bn1 = BatchNorm2d(planes) 89 | self.with_modulated_dcn = False 90 | if not self.with_dcn: 91 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 92 | else: 93 | deformable_groups = dcn.get('deformable_groups', 1) 94 | from torchvision.ops import DeformConv2d 95 | offset_channels = 18 96 | self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1) 97 | self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False) 98 | self.bn2 = BatchNorm2d(planes) 99 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 100 | self.bn3 = BatchNorm2d(planes * 4) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.downsample = downsample 103 | self.stride = stride 104 | self.dcn = dcn 105 | self.with_dcn = dcn is not None 106 | 107 | def forward(self, x): 108 | residual = x 109 | 110 | out = self.conv1(x) 111 | out = self.bn1(out) 112 | out = self.relu(out) 113 | 114 | # out = self.conv2(out) 115 | if not self.with_dcn: 116 | out = self.conv2(out) 117 | else: 118 | offset = self.conv2_offset(out) 119 | out = self.conv2(out, offset) 120 | out = self.bn2(out) 121 | out = self.relu(out) 122 | 123 | out = self.conv3(out) 124 | out = self.bn3(out) 125 | 126 | if self.downsample is not None: 127 | residual = self.downsample(x) 128 | 129 | out += residual 130 | out = self.relu(out) 131 | 132 | return out 133 | 134 | 135 | class ResNet(nn.Module): 136 | def __init__(self, block, layers, num_classes=1000, dcn=None): 137 | self.dcn = dcn 138 | self.inplanes = 64 139 | super(ResNet, self).__init__() 140 | self.conv1 = nn.Conv2d(3, 
64, kernel_size=7, stride=2, padding=3, 141 | bias=False) 142 | self.bn1 = BatchNorm2d(64) 143 | self.relu = nn.ReLU(inplace=True) 144 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 145 | self.layer1 = self._make_layer(block, 64, layers[0]) 146 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dcn=dcn) 147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dcn=dcn) 148 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dcn=dcn) 149 | self.avgpool = nn.AvgPool2d(7, stride=1) 150 | self.fc = nn.Linear(512 * block.expansion, num_classes) 151 | 152 | self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1) 153 | 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 157 | m.weight.data.normal_(0, math.sqrt(2. / n)) 158 | elif isinstance(m, BatchNorm2d): 159 | m.weight.data.fill_(1) 160 | m.bias.data.zero_() 161 | if self.dcn is not None: 162 | for m in self.modules(): 163 | if isinstance(m, Bottleneck) or isinstance(m, BasicBlock): 164 | if hasattr(m, 'conv2_offset'): 165 | constant_init(m.conv2_offset, 0) 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dcn=None): 168 | downsample = None 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | nn.Conv2d(self.inplanes, planes * block.expansion, 172 | kernel_size=1, stride=stride, bias=False), 173 | BatchNorm2d(planes * block.expansion), 174 | ) 175 | 176 | layers = [] 177 | layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn)) 178 | self.inplanes = planes * block.expansion 179 | for i in range(1, blocks): 180 | layers.append(block(self.inplanes, planes, dcn=dcn)) 181 | 182 | return nn.Sequential(*layers) 183 | 184 | def forward(self, x): 185 | x = self.conv1(x) 186 | x = self.bn1(x) 187 | x = self.relu(x) 188 | x = self.maxpool(x) 189 | 190 | x2 = self.layer1(x) 191 | x3 = self.layer2(x2) 192 | x4 = self.layer3(x3) 193 | x5 = self.layer4(x4) 194 | 195 | return x2, x3, x4, x5 196 | 197 | 198 | def resnet18(pretrained=True, **kwargs): 199 | """Constructs a ResNet-18 model. 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 204 | if pretrained: 205 | print('load from imagenet') 206 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 207 | return model 208 | 209 | 210 | def deformable_resnet18(pretrained=True, **kwargs): 211 | """Constructs a ResNet-18 model. 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1)) 216 | if pretrained: 217 | print('load from imagenet') 218 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 219 | return model 220 | 221 | 222 | def resnet34(pretrained=True, **kwargs): 223 | """Constructs a ResNet-34 model. 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | """ 227 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 228 | if pretrained: 229 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False) 230 | return model 231 | 232 | 233 | def resnet50(pretrained=True, **kwargs): 234 | """Constructs a ResNet-50 model. 
235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | """ 238 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 239 | if pretrained: 240 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 241 | return model 242 | 243 | 244 | def deformable_resnet50(pretrained=True, **kwargs): 245 | """Constructs a ResNet-50 model with deformable conv. 246 | Args: 247 | pretrained (bool): If True, returns a model pre-trained on ImageNet 248 | """ 249 | model = ResNet(Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs) 250 | if pretrained: 251 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 252 | return model 253 | 254 | 255 | def resnet101(pretrained=True, **kwargs): 256 | """Constructs a ResNet-101 model. 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | """ 260 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 261 | if pretrained: 262 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False) 263 | return model 264 | 265 | 266 | def resnet152(pretrained=True, **kwargs): 267 | """Constructs a ResNet-152 model. 268 | Args: 269 | pretrained (bool): If True, returns a model pre-trained on ImageNet 270 | """ 271 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 272 | if pretrained: 273 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False) 274 | return model -------------------------------------------------------------------------------- /modeling/modules/seg/resnet_fpn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zhoujun 5 | @time: 2019/12/17 11:02 AM 6 | ''' 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from modeling.basic import BasicConv 11 | 12 | from modeling.modules.seg.resnet import * 13 | 14 | 15 | class FPN(nn.Module): 16 | def __init__(self, backbone_out_channels, inner_channels=256): 17 | """ 18 | :param backbone_out_channels: number of output channels of each backbone stage 19 | :param kwargs: 20 | """ 21 | super().__init__() 22 | inplace = True 23 | self.conv_out_channels = inner_channels 24 | inner_channels = inner_channels // 4 25 | # reduce layers 26 | self.reduce_conv_c2 = BasicConv(backbone_out_channels[0], inner_channels, kernel_size=1, inplace=inplace) 27 | self.reduce_conv_c3 = BasicConv(backbone_out_channels[1], inner_channels, kernel_size=1, inplace=inplace) 28 | self.reduce_conv_c4 = BasicConv(backbone_out_channels[2], inner_channels, kernel_size=1, inplace=inplace) 29 | self.reduce_conv_c5 = BasicConv(backbone_out_channels[3], inner_channels, kernel_size=1, inplace=inplace) 30 | # Smooth layers 31 | self.smooth_p4 = BasicConv(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace) 32 | self.smooth_p3 = BasicConv(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace) 33 | self.smooth_p2 = BasicConv(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace) 34 | 35 | self.conv = nn.Sequential( 36 | nn.Conv2d(self.conv_out_channels, self.conv_out_channels, kernel_size=3, padding=1, stride=1), 37 | nn.BatchNorm2d(self.conv_out_channels), 38 | nn.ReLU(inplace=inplace) 39 | ) 40 | self.out_conv = nn.Conv2d(in_channels=self.conv_out_channels, out_channels=1, kernel_size=1) 41 | 42 | def forward(self, x): 43 | c2, c3, c4, c5 = x 44 | # Top-down 45 | p5 = self.reduce_conv_c5(c5) 46 | p4 = self._upsample_add(p5, 
self.reduce_conv_c4(c4)) 47 | p4 = self.smooth_p4(p4) 48 | p3 = self._upsample_add(p4, self.reduce_conv_c3(c3)) 49 | p3 = self.smooth_p3(p3) 50 | p2 = self._upsample_add(p3, self.reduce_conv_c2(c2)) 51 | p2 = self.smooth_p2(p2) 52 | 53 | x = self._upsample_cat(p2, p3, p4, p5) 54 | x = self.conv(x) 55 | x = self.out_conv(x) 56 | return x 57 | 58 | def _upsample_add(self, x, y): 59 | return F.interpolate(x, size=y.size()[2:]) + y 60 | 61 | def _upsample_cat(self, p2, p3, p4, p5): 62 | h, w = p2.size()[2:] 63 | p3 = F.interpolate(p3, size=(h, w)) 64 | p4 = F.interpolate(p4, size=(h, w)) 65 | p5 = F.interpolate(p5, size=(h, w)) 66 | return torch.cat([p2, p3, p4, p5], dim=1) 67 | 68 | 69 | class ResNetFPN(nn.Module): 70 | def __init__(self, backbone, pretrained, **kwargs): 71 | """ 72 | PANnet 73 | :param model_config: 模型配置 74 | """ 75 | super().__init__() 76 | self.k = kwargs.get('k', 1) 77 | backbone_dict = { 78 | 'resnet18': {'modeling': resnet18, 'out': [64, 128, 256, 512]}, 79 | 'deformable_resnet18': {'modeling': deformable_resnet18, 'out': [64, 128, 256, 512]}, 80 | 'resnet34': {'modeling': resnet34, 'out': [64, 128, 256, 512]}, 81 | 'resnet50': {'modeling': resnet50, 'out': [256, 512, 1024, 2048]}, 82 | 'deformable_resnet50': {'modeling': deformable_resnet50, 'out': [256, 512, 1024, 2048]}, 83 | 'resnet101': {'modeling': resnet101, 'out': [256, 512, 1024, 2048]}, 84 | 'resnet152': {'modeling': resnet152, 'out': [256, 512, 1024, 2048]}, 85 | } 86 | assert backbone in backbone_dict, f'backbone must in: {backbone_dict}' 87 | backbone_model, backbone_out = backbone_dict[backbone]['modeling'], backbone_dict[backbone]['out'] 88 | self.backbone = backbone_model(pretrained=pretrained) 89 | self.segmentation_head = FPN(backbone_out) 90 | self.out_channels = 1 91 | 92 | def forward(self, x): 93 | _, _, H, W = x.size() 94 | backbone_out = self.backbone(x) 95 | y = self.segmentation_head(backbone_out) 96 | y = torch.sigmoid(y * self.k) 97 | y = F.interpolate(y, size=(H, W)) 98 | return y 99 | -------------------------------------------------------------------------------- /modeling/modules/seg/unet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: zhoujun 5 | @time: 2019/12/17 下午1:51 6 | ''' 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class ConvBlock(nn.Module): 13 | def __init__(self, in_channels, out_channels, kernel_size): 14 | super().__init__() 15 | self.conv = nn.Sequential( 16 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=1, bias=False), 17 | nn.BatchNorm2d(out_channels), 18 | nn.LeakyReLU(0.1) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.conv(x) 23 | 24 | 25 | class DownBlock(nn.Module): 26 | def __init__(self, in_channels, out_channels): 27 | super().__init__() 28 | self.conv = nn.Sequential( 29 | ConvBlock(in_channels, out_channels, 3), 30 | ConvBlock(out_channels, out_channels, 3) 31 | ) 32 | 33 | def forward(self, x): 34 | return self.conv(x) 35 | 36 | 37 | class UpBlock(nn.Module): 38 | def __init__(self, in_channels, out_channels, shrink=True, **kwargs): 39 | super().__init__() 40 | self.conv3_0 = ConvBlock(in_channels, out_channels, 3) 41 | if shrink: 42 | self.conv3_1 = ConvBlock(out_channels, int(out_channels / 2), 3) 43 | else: 44 | self.conv3_1 = ConvBlock(out_channels, out_channels, 3) 45 | 46 | def forward(self, x, s): 47 | x = F.interpolate(x, scale_factor=2) 48 | 49 | x = torch.cat([s, x], dim=1) 50 | 
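# The decoder feature map is upsampled 2x and concatenated with the encoder skip connection `s` along the channel axis, so conv3_0 sees in_channels = skip channels + upsampled channels, and conv3_1 halves the channel count again when shrink=True.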
x = self.conv3_0(x) 51 | x = self.conv3_1(x) 52 | return x 53 | 54 | 55 | class UNet(nn.Module): 56 | def __init__(self, in_channels, **kwargs): 57 | super().__init__() 58 | self.k = kwargs.get('k', 1) 59 | self.stage_channels = [32, 64, 128, 256, 512] 60 | self.d0 = DownBlock(in_channels, self.stage_channels[0]) 61 | 62 | self.d1 = nn.Sequential(nn.MaxPool2d(2, 2, ceil_mode=True), DownBlock(self.stage_channels[0], self.stage_channels[1])) 63 | 64 | self.d2 = nn.Sequential(nn.MaxPool2d(2, 2, ceil_mode=True), DownBlock(self.stage_channels[1], self.stage_channels[2])) 65 | 66 | self.d3 = nn.Sequential(nn.MaxPool2d(2, 2, ceil_mode=True), DownBlock(self.stage_channels[2], self.stage_channels[3])) 67 | 68 | self.d4 = nn.Sequential(nn.MaxPool2d(2, 2, ceil_mode=True), DownBlock(self.stage_channels[3], self.stage_channels[4])) 69 | 70 | self.u3 = UpBlock(self.stage_channels[3] + self.stage_channels[4], self.stage_channels[3], shrink=True) 71 | self.u2 = UpBlock(self.stage_channels[3], self.stage_channels[2], shrink=True) 72 | self.u1 = UpBlock(self.stage_channels[2], self.stage_channels[1], shrink=True) 73 | self.u0 = UpBlock(self.stage_channels[1], self.stage_channels[0], shrink=False) 74 | 75 | self.conv = nn.Conv2d(self.stage_channels[0], 1, 1, bias=False) 76 | self.out_channels = 1 77 | 78 | def forward(self, x): 79 | x0 = self.d0(x) 80 | x1 = self.d1(x0) 81 | x2 = self.d2(x1) 82 | x3 = self.d3(x2) 83 | x4 = self.d4(x3) 84 | 85 | y3 = self.u3(x4, x3) 86 | y2 = self.u2(y3, x2) 87 | y1 = self.u1(y2, x1) 88 | y0 = self.u0(y1, x0) 89 | out = self.conv(y0) 90 | out = torch.sigmoid(out * self.k) 91 | return out 92 | 93 | 94 | if __name__ == '__main__': 95 | 96 | x = torch.zeros(1, 3, 32, 320) 97 | net = UNet(3) 98 | y = net(x) 99 | print(y.shape) 100 | -------------------------------------------------------------------------------- /modeling/neck/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/17 10:49 3 | # @Author : zhoujun 4 | 5 | from .sequence_modeling import RNNDecoder, CNNDecoder, Reshape 6 | 7 | __all__ = ['build_neck'] 8 | 9 | support_neck = ['RNNDecoder', 'CNNDecoder', 'Reshape'] 10 | 11 | 12 | def build_neck(neck_name, **kwargs): 13 | assert neck_name in support_neck, f'all support neck is {support_neck}' 14 | neck = eval(neck_name)(**kwargs) 15 | return neck 16 | -------------------------------------------------------------------------------- /modeling/neck/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from modeling.basic import BasicConv 4 | 5 | __all__ = ['RNNDecoder', 'CNNDecoder', 'Reshape'] 6 | 7 | 8 | class BidirectionalGRU(nn.Module): 9 | def __init__(self, input_size, hidden_size, num_layers, nOut): 10 | super(BidirectionalGRU, self).__init__() 11 | self.rnn = nn.GRU(input_size, hidden_size, num_layers, bidirectional=True, batch_first=True) 12 | self.fc = nn.Linear(hidden_size * 2, nOut) 13 | 14 | def forward(self, x): 15 | x, _ = self.rnn(x) 16 | x = self.fc(x) # [T * b, nOut] 17 | return x 18 | 19 | 20 | class BidirectionalLSTM(nn.Module): 21 | def __init__(self, input_size, hidden_size, use_fc=True): 22 | super(BidirectionalLSTM, self).__init__() 23 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 24 | if use_fc: 25 | self.fc = nn.Linear(hidden_size * 2, hidden_size) 26 | else: 27 | self.fc = None 28 | 29 | def forward(self, x): 30 | x, _ = self.rnn(x) 31 | if 
self.fc is not None: 32 | x = self.fc(x) 33 | return x 34 | 35 | 36 | class RNNDecoder(nn.Module): 37 | def __init__(self, in_channels, hidden_size=256, **kwargs): 38 | super(RNNDecoder, self).__init__() 39 | self.lstm = nn.Sequential( 40 | BidirectionalLSTM(in_channels, hidden_size, True), 41 | BidirectionalLSTM(hidden_size, hidden_size, True) 42 | ) 43 | self.out_channels = hidden_size 44 | 45 | def forward(self, x): 46 | x = x.squeeze(axis=2) 47 | x = x.permute((0, 2, 1)) # (NTC)(batch, width, channel)s 48 | x = self.lstm(x) 49 | return x 50 | 51 | 52 | class CNNDecoder(nn.Module): 53 | def __init__(self, in_channels, hidden_size=256): 54 | super().__init__() 55 | self.cnn_decoder = nn.Sequential( 56 | BasicConv(in_channels=in_channels, out_channels=hidden_size, kernel_size=3, padding=1, stride=(2, 1), bias=False), 57 | BasicConv(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1, stride=(2, 1), bias=False), 58 | BasicConv(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1, stride=(2, 1), bias=False), 59 | BasicConv(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding=1, stride=(2, 1), bias=False) 60 | ) 61 | self.out_channels = hidden_size 62 | 63 | def forward(self, x): 64 | x = self.cnn_decoder(x) 65 | x = x.squeeze(dim=2) 66 | x = x.permute(0, 2, 1) 67 | return x 68 | 69 | 70 | class Reshape(nn.Module): 71 | def __init__(self, in_channels, **kwargs): 72 | super().__init__() 73 | self.out_channels = in_channels 74 | 75 | def forward(self, x): 76 | B, C, H, W = x.shape 77 | x = x.reshape(B, C, H * W) 78 | x = x.permute((0, 2, 1)) # (NTC)(batch, width, channel)s 79 | return x 80 | -------------------------------------------------------------------------------- /modeling/trans/TPS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class TPS(nn.Module): 8 | """ Rectification Network of RARE, namely TPS based STN """ 9 | 10 | def __init__(self, num_fiducial, input_size, I_r_size=None, in_channels=1, **kwargs): 11 | """ Based on RARE TPS 12 | input: 13 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 14 | I_size : (height, width) of the input image I 15 | I_r_size : (height, width) of the rectified image I_r 16 | I_channel_num : the number of channels of the input image I 17 | output: 18 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 19 | """ 20 | super().__init__() 21 | self.F = num_fiducial 22 | self.I_size = input_size 23 | self.I_r_size = I_r_size if I_r_size is not None else input_size # = (I_r_height, I_r_width) 24 | self.I_channel_num = in_channels 25 | self.out_channels = in_channels 26 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 27 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 28 | 29 | def forward(self, batch_I): 30 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 31 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 32 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 33 | 34 | if torch.__version__ > "1.2.0": 35 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 36 | else: 37 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, 
padding_mode='border') 38 | 39 | return batch_I_r 40 | 41 | 42 | class LocalizationNetwork(nn.Module): 43 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 44 | 45 | def __init__(self, F, I_channel_num): 46 | super(LocalizationNetwork, self).__init__() 47 | self.F = F 48 | self.I_channel_num = I_channel_num 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 51 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 52 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 53 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 54 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 55 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 56 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 57 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 58 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 59 | ) 60 | 61 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 62 | self.localization_fc2 = nn.Linear(256, self.F * 2) 63 | 64 | # Init fc2 in LocalizationNetwork 65 | self.localization_fc2.weight.data.fill_(0) 66 | """ see RARE paper Fig. 6 (a) """ 67 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 68 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 69 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 70 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 71 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 72 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 73 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 74 | 75 | def forward(self, batch_I): 76 | """ 77 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 78 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 79 | """ 80 | batch_size = batch_I.size(0) 81 | features = self.conv(batch_I).view(batch_size, -1) 82 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 83 | return batch_C_prime 84 | 85 | 86 | class GridGenerator(nn.Module): 87 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 88 | 89 | def __init__(self, F, I_r_size): 90 | """ Generate P_hat and inv_delta_C for later """ 91 | super(GridGenerator, self).__init__() 92 | self.eps = 1e-6 93 | self.I_r_height, self.I_r_width = I_r_size 94 | self.F = F 95 | self.C = self._build_C(self.F) # F x 2 96 | self.P = self._build_P(self.I_r_width, self.I_r_height) 97 | ## for multi-gpu, you need register buffer 98 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 99 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 100 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 101 | # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 102 | # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 103 | 104 | def _build_C(self, F): 105 | """ Return coordinates of fiducial points in I_r; C """ 106 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 107 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 108 | ctrl_pts_y_bottom = 
np.ones(int(F / 2)) 109 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 110 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 111 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 112 | return C # F x 2 113 | 114 | def _build_inv_delta_C(self, F, C): 115 | """ Return inv_delta_C which is needed to calculate T """ 116 | hat_C = np.zeros((F, F), dtype=float) # F x F 117 | for i in range(0, F): 118 | for j in range(i, F): 119 | r = np.linalg.norm(C[i] - C[j]) 120 | hat_C[i, j] = r 121 | hat_C[j, i] = r 122 | np.fill_diagonal(hat_C, 1) 123 | hat_C = (hat_C ** 2) * np.log(hat_C) 124 | # print(C.shape, hat_C.shape) 125 | delta_C = np.concatenate( # F+3 x F+3 126 | [ 127 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 128 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 129 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 130 | ], 131 | axis=0 132 | ) 133 | inv_delta_C = np.linalg.inv(delta_C) 134 | return inv_delta_C # F+3 x F+3 135 | 136 | def _build_P(self, I_r_width, I_r_height): 137 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 138 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 139 | P = np.stack( # self.I_r_width x self.I_r_height x 2 140 | np.meshgrid(I_r_grid_x, I_r_grid_y), 141 | axis=2 142 | ) 143 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 144 | 145 | def _build_P_hat(self, F, C, P): 146 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 147 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 148 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 149 | P_diff = P_tile - C_tile # n x F x 2 150 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 151 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 152 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 153 | return P_hat # n x F+3 154 | 155 | def build_P_prime(self, batch_C_prime): 156 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 157 | batch_size = batch_C_prime.size(0) 158 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 159 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 160 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 161 | batch_size, 3, 2).float().to(batch_C_prime.device)), dim=1) # batch_size x F+3 x 2 162 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 163 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 164 | return batch_P_prime # batch_size x n x 2 165 | -------------------------------------------------------------------------------- /modeling/trans/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/17 10:59 3 | # @Author : zhoujun 4 | 5 | from .TPS import TPS 6 | 7 | __all__ = ['build_trans'] 8 | support_trans = ['TPS', 'None'] 9 | 10 | 11 | def build_trans(trans_name, **kwargs): 12 | assert trans_name in support_trans, f'all support head is {support_trans}' 13 | if trans_name == 'None': 14 | return None 15 | head = eval(trans_name)(**kwargs) 16 | return head 17 | -------------------------------------------------------------------------------- /msyh.ttc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WenmuZhou/crnn.pytorch/bf7a7c62376eee93943ca7c68e88e3d563c09aa8/msyh.ttc -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/8/23 22:21 3 | # @Author : zhoujun 4 | import os 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from utils import CTCLabelConverter,AttnLabelConverter 9 | 10 | from data_loader import get_transforms 11 | 12 | class PytorchNet: 13 | def __init__(self, model_path, gpu_id=None): 14 | """ 15 | 初始化模型 16 | :param model_path: 模型地址 17 | :param gpu_id: 在哪一块gpu上运行 18 | """ 19 | checkpoint = torch.load(model_path) 20 | print(f"load {checkpoint['epoch']} epoch params") 21 | config = checkpoint['config'] 22 | alphabet = config['dataset']['alphabet'] 23 | if gpu_id is not None and isinstance(gpu_id, int) and torch.cuda.is_available(): 24 | self.device = torch.device("cuda:%s" % gpu_id) 25 | else: 26 | self.device = torch.device("cpu") 27 | print('device:', self.device) 28 | 29 | self.transform = [] 30 | for t in config['dataset']['train']['dataset']['args']['transforms']: 31 | if t['type'] in ['ToTensor', 'Normalize']: 32 | self.transform.append(t) 33 | self.transform = get_transforms(self.transform) 34 | 35 | self.gpu_id = gpu_id 36 | img_h, img_w = 32, 100 37 | for process in config['dataset']['train']['dataset']['args']['pre_processes']: 38 | if process['type'] == "Resize": 39 | img_h = process['args']['img_h'] 40 | img_w = process['args']['img_w'] 41 | break 42 | self.img_w = img_w 43 | self.img_h = img_h 44 | self.img_mode = config['dataset']['train']['dataset']['args']['img_mode'] 45 | self.alphabet = alphabet 46 | img_channel = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1 47 | 48 | if config['arch']['args']['prediction']['type'] == 'CTC': 49 | self.converter = CTCLabelConverter(config['dataset']['alphabet']) 50 | elif config['arch']['args']['prediction']['type'] == 'Attn': 51 | self.converter = AttnLabelConverter(config['dataset']['alphabet']) 52 | self.net = get_model(img_channel, len(self.converter.character), config['arch']['args']) 53 | self.net.load_state_dict(checkpoint['state_dict']) 54 | # self.net = torch.jit.load('crnn_lite_gpu.pt') 55 | self.net.to(self.device) 56 | self.net.eval() 57 | sample_input = torch.zeros((2, img_channel, img_h, img_w)).to(self.device) 58 | self.net.get_batch_max_length(sample_input) 59 | 60 | def predict(self, img_path, model_save_path=None): 61 | """ 62 | 对传入的图像进行预测,支持图像地址和numpy数组 63 | :param img_path: 图像地址 64 | :return: 65 | """ 66 | assert os.path.exists(img_path), 'file is not exists' 67 | img = self.pre_processing(img_path) 68 | tensor = self.transform(img) 69 | tensor = tensor.unsqueeze(dim=0) 70 | 71 | tensor = tensor.to(self.device) 72 | preds, tensor_img = self.net(tensor) 73 | 74 | preds = preds.softmax(dim=2).detach().cpu().numpy() 75 | # result = decode(preds, self.alphabet, raw=True) 76 | # print(result) 77 | result = self.converter.decode(preds) 78 | if model_save_path is not None: 79 | # 输出用于部署的模型 80 | save(self.net, tensor, model_save_path) 81 | return result, tensor_img 82 | 83 | def pre_processing(self, img_path): 84 | """ 85 | 对图片进行处理,先按照高度进行resize,resize之后如果宽度不足指定宽度,就补黑色像素,否则就强行缩放到指定宽度 86 | :param img_path: 图片地址 87 | :return: 88 | """ 89 | img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0) 90 | if self.img_mode == 'RGB': 91 | img = cv2.cvtColor(img, 
cv2.COLOR_BGR2RGB) 92 | h, w = img.shape[:2] 93 | ratio_h = float(self.img_h) / h 94 | new_w = int(w * ratio_h) 95 | 96 | if new_w < self.img_w: 97 | img = cv2.resize(img, (new_w, self.img_h)) 98 | step = np.zeros((self.img_h, self.img_w - new_w, img.shape[-1]), dtype=img.dtype) 99 | img = np.column_stack((img, step)) 100 | else: 101 | img = cv2.resize(img, (self.img_w, self.img_h)) 102 | return img 103 | 104 | 105 | def save(net, input, save_path): 106 | # 在gpu导出的模型只能在gpu使用,cpu导出的只能在cpu使用 107 | net.eval() 108 | traced_script_module = torch.jit.trace(net, input) 109 | traced_script_module.save(save_path) 110 | 111 | 112 | if __name__ == '__main__': 113 | from modeling import get_model 114 | import time 115 | from matplotlib import pyplot as plt 116 | from matplotlib.font_manager import FontProperties 117 | 118 | font = FontProperties(fname=r"msyh.ttc", size=14) 119 | 120 | img_path = '0.jpg' 121 | model_path = 'crnn_None_VGG_RNN_Attn/checkpoint/model_latest.pth' 122 | 123 | crnn_net = PytorchNet(model_path=model_path, gpu_id=0) 124 | start = time.time() 125 | for i in range(1): 126 | result, img = crnn_net.predict(img_path) 127 | break 128 | print((time.time() - start) *1000/ 1) 129 | 130 | label = result[0][0] 131 | print(result) 132 | # plt.title(label, fontproperties=font) 133 | # plt.imshow(img.detach().cpu().numpy().squeeze().transpose((1, 2, 0)), cmap='gray') 134 | # plt.show() 135 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyconfig==0.9.10 2 | imgaug==0.4.0 3 | lmdb==0.98 4 | matplotlib==3.2.1 5 | numpy==1.16.6 6 | opencv-python==4.2.0.34 7 | Pillow==7.0.0 8 | python-Levenshtein==0.12.0 9 | PyYAML==5.3.1 10 | torch==1.4.0 11 | torchvision==0.5.0 12 | tqdm==4.45.0 13 | trdg==1.5.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/8/23 22:20 3 | # @Author : zhoujun 4 | import os 5 | 6 | 7 | def main(config): 8 | import torch 9 | 10 | from modeling import build_model, build_loss 11 | from data_loader import get_dataloader 12 | from trainer import Trainer 13 | from utils import CTCLabelConverter, AttnLabelConverter, load 14 | if os.path.isfile(config['dataset']['alphabet']): 15 | config['dataset']['alphabet'] = ''.join(load(config['dataset']['alphabet'])) 16 | 17 | prediction_type = config['arch']['head']['type'] 18 | 19 | # loss 设置 20 | criterion = build_loss(config['loss']) 21 | if prediction_type == 'CTC': 22 | converter = CTCLabelConverter(config['dataset']['alphabet']) 23 | elif prediction_type == 'Attn': 24 | converter = AttnLabelConverter(config['dataset']['alphabet']) 25 | else: 26 | raise NotImplementedError 27 | img_channel = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1 28 | config['arch']['backbone']['in_channels'] = img_channel 29 | config['arch']['head']['n_class'] = len(converter.character) 30 | # model = get_model(img_channel, len(converter.character), config['arch']['args']) 31 | model = build_model(config['arch']) 32 | img_h, img_w = 32, 100 33 | for process in config['dataset']['train']['dataset']['args']['pre_processes']: 34 | if process['type'] == "Resize": 35 | img_h = process['args']['img_h'] 36 | img_w = process['args']['img_w'] 37 | break 38 | sample_input = torch.zeros((2, img_channel, img_h, img_w)) 39 | num_label = 
model.get_batch_max_length(sample_input) 40 | train_loader = get_dataloader(config['dataset']['train'], num_label) 41 | assert train_loader is not None 42 | if 'validate' in config['dataset'] and config['dataset']['validate']['dataset']['args']['data_path'][0] is not None: 43 | validate_loader = get_dataloader(config['dataset']['validate'], num_label) 44 | else: 45 | validate_loader = None 46 | 47 | trainer = Trainer(config=config, model=model, criterion=criterion, train_loader=train_loader, validate_loader=validate_loader, sample_input=sample_input, 48 | converter=converter) 49 | trainer.train() 50 | 51 | 52 | def init_args(): 53 | import argparse 54 | parser = argparse.ArgumentParser(description='crnn.pytorch') 55 | parser.add_argument('--config_file', default='config/imagedataset_None_VGG_RNN_CTC_local.yaml', type=str) 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | if __name__ == '__main__': 61 | import sys 62 | import anyconfig 63 | 64 | project = 'crnn.pytorch' # 工作项目根目录 65 | sys.path.append(os.getcwd().split(project)[0] + project) 66 | 67 | from utils import parse_config 68 | 69 | args = init_args() 70 | assert os.path.exists(args.config_file) 71 | config = anyconfig.load(open(args.config_file, 'rb')) 72 | if 'base' in config: 73 | config = parse_config(config) 74 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in config['trainer']['gpus']]) 75 | main(config) 76 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/8/23 22:20 3 | # @Author : zhoujun 4 | import os 5 | import shutil 6 | import time 7 | import Levenshtein 8 | from tqdm import tqdm 9 | import torch 10 | 11 | from base import BaseTrainer 12 | from utils import save 13 | 14 | 15 | class Trainer(BaseTrainer): 16 | def __init__(self, config, model, criterion, train_loader, sample_input, converter, validate_loader=None): 17 | super().__init__(config, model, criterion, sample_input) 18 | self.train_loader = train_loader 19 | self.train_loader_len = len(train_loader) 20 | self.validate_loader = validate_loader 21 | self.converter = converter 22 | if self.validate_loader is not None: 23 | self.logger.info(f'train dataset has {self.train_loader.dataset_len} samples,{len(train_loader)} in dataloader,' 24 | f'validate dataset has {self.validate_loader.dataset_len} samples,{len(self.validate_loader)} in dataloader') 25 | else: 26 | self.logger.info(f'train dataset has {len(self.train_loader.dataset)} samples,{len(self.train_loader)} in dataloader') 27 | 28 | self.run_time_dict = {} 29 | # 保存本次实验的alphabet 到模型保存的地方 30 | self.alphabet = config['dataset']['alphabet'] 31 | save(self.alphabet, os.path.join(self.save_dir, 'dict.txt')) 32 | self.metrics = {'val_acc': 0, 33 | 'train_loss': float('inf'), 34 | 'best_acc_epoch': 0, 35 | 'train_acc': 0, 36 | 'norm_edit_dis': 0, 37 | 'best_ned_epoch': 0} 38 | 39 | def _train_epoch(self, epoch): 40 | self.model.train() 41 | epoch_start = time.time() 42 | self.run_time_dict['batch_start'] = time.time() 43 | self.run_time_dict['epoch'] = epoch 44 | self.run_time_dict['n_correct'] = 0 45 | self.run_time_dict['train_num'] = 0 46 | self.run_time_dict['train_loss'] = 0 47 | 
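# Per-epoch running counters: n_correct / train_num gives train_acc at the end of the epoch, train_loss accumulates the per-batch losses, and norm_edit_dis accumulates the per-batch normalized Levenshtein distances reported in _after_step.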
self.run_time_dict['norm_edit_dis'] = 0 48 | 49 | for i, batch in enumerate(self.train_loader): 50 | if i >= self.train_loader_len: 51 | break 52 | self.global_step += 1 53 | self.run_time_dict['lr'] = self.optimizer.param_groups[0]['lr'] 54 | self.run_time_dict['iter'] = i 55 | batch = self._before_step(batch) 56 | batch_out = self._run_step(batch) 57 | self._after_step(batch_out) 58 | epoch_time = time.time() - epoch_start 59 | self.run_time_dict['train_acc'] = self.run_time_dict['n_correct'] / self.run_time_dict['train_num'] 60 | self.logger.info(f"[{self.run_time_dict['epoch']}/{self.epochs}], " 61 | f"train_acc: {self.run_time_dict['train_acc']:.4f}, " 62 | f"train_loss: {self.run_time_dict['train_loss'] / self.train_loader_len:.4f}, " 63 | f"time: {epoch_time:.4f}, " 64 | f"lr: {self.run_time_dict['lr']}") 65 | 66 | def _before_step(self, batch): 67 | targets, targets_lengths = self.converter.encode(batch['label'], self.batch_max_length) 68 | batch['img'] = batch['img'].to(self.device) 69 | batch['targets'] = targets.to(self.device) 70 | batch['targets_lengths'] = targets_lengths 71 | return batch 72 | 73 | def _run_step(self, batch): 74 | # forward 75 | cur_batch_size = batch['img'].shape[0] 76 | targets = batch['targets'] 77 | if self.model.head_type == 'CTC': 78 | preds = self.model(batch['img'])[0] 79 | loss = self.criterion(preds, batch) 80 | elif self.model.head_type == 'Attn': 81 | preds = self.model(batch['img'], targets[:, :-1])[0] 82 | loss = self.criterion(preds, batch) 83 | else: 84 | raise NotImplementedError 85 | # backward 86 | self.optimizer.zero_grad() 87 | loss.backward() 88 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5) 89 | self.optimizer.step() 90 | batch_dict = self.accuracy_batch(preds, batch['label']) 91 | batch_dict['loss'] = loss.item() 92 | batch_dict['batch_size'] = cur_batch_size 93 | return batch_dict 94 | 95 | def _after_step(self, batch_out): 96 | # loss 和 acc 记录到日志 97 | self.run_time_dict['train_num'] += batch_out['batch_size'] 98 | self.run_time_dict['train_loss'] += batch_out['loss'] 99 | self.run_time_dict['n_correct'] += batch_out['n_correct'] 100 | self.run_time_dict['norm_edit_dis'] += batch_out['norm_edit_dis'] 101 | 102 | acc = batch_out['n_correct'] / batch_out['batch_size'] 103 | norm_edit_dis = 1 - batch_out['norm_edit_dis'] / batch_out['batch_size'] 104 | if self.tensorboard_enable: 105 | # write tensorboard 106 | self.writer.add_scalar('TRAIN/loss', batch_out['loss'], self.global_step) 107 | self.writer.add_scalar('TRAIN/acc', acc, self.global_step) 108 | self.writer.add_scalar('TRAIN/norm_edit_dis', norm_edit_dis, self.global_step) 109 | self.writer.add_scalar('TRAIN/lr', self.run_time_dict['lr'], self.global_step) 110 | self.writer.add_text('Train/pred_gt', ' || '.join(batch_out['show_str'][:10]), self.global_step) 111 | 112 | if self.global_step % self.log_iter == 0: 113 | batch_time = time.time() - self.run_time_dict['batch_start'] 114 | speed = self.log_iter * batch_out['batch_size'] / batch_time 115 | self.logger.info(f"[{self.run_time_dict['epoch']}/{self.epochs}], " 116 | f"[{self.run_time_dict['iter'] + 1}/{self.train_loader_len}], global_step: {self.global_step}, " 117 | f"Speed: {speed:.1f} samples/sec, loss:{batch_out['loss']:.4f}, " 118 | f"acc:{acc:.4f}, norm_edit_dis:{norm_edit_dis:.4f} lr:{self.run_time_dict['lr']}, " 119 | f"time:{batch_time:.2f}") 120 | self.run_time_dict['batch_start'] = time.time() 121 | 122 | def _eval(self, max_step=None, dest='test model'): 123 | self.model.eval() 124 | n_correct = 0 
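# n_correct and norm_edit_dis are accumulated over the whole validation set and are normalized by validate_loader.dataset_len in _on_epoch_finish.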
125 | norm_edit_dis = 0 126 | show_str = [] 127 | for i, (images, labels) in enumerate(tqdm(self.validate_loader, desc=dest)): 128 | if max_step is not None and i >= max_step: 129 | break 130 | images = images.to(self.device) 131 | with torch.no_grad(): 132 | preds = self.model(images)[0] 133 | batch_dict = self.accuracy_batch(preds, labels) 134 | n_correct += batch_dict['n_correct'] 135 | norm_edit_dis += batch_dict['norm_edit_dis'] 136 | show_str.extend(batch_dict['show_str']) 137 | return {'n_correct': n_correct, 'norm_edit_dis': norm_edit_dis, 'show_str': show_str} 138 | 139 | def _on_epoch_finish(self): 140 | net_save_path = f'{self.checkpoint_dir}/model_latest.pth' 141 | self._save_checkpoint(self.run_time_dict['epoch'], net_save_path) 142 | 143 | if self.validate_loader is not None: 144 | epoch_eval_dict = self._eval() 145 | val_acc = epoch_eval_dict['n_correct'] / self.validate_loader.dataset_len 146 | norm_edit_dis = 1 - epoch_eval_dict['norm_edit_dis'] / self.validate_loader.dataset_len 147 | 148 | if self.tensorboard_enable: 149 | self.writer.add_scalar('EVAL/acc', val_acc, self.global_step) 150 | self.writer.add_scalar('EVAL/edit_distance', norm_edit_dis, self.global_step) 151 | self.writer.add_text('EVAL/pred_gt', ' || '.join(epoch_eval_dict['show_str'][:10]), self.global_step) 152 | 153 | self.logger.info(f"[{self.run_time_dict['epoch']}/{self.epochs}], val_acc: {val_acc:.6f}, " 154 | f"norm_edit_dis: {norm_edit_dis}") 155 | 156 | if val_acc >= self.metrics['val_acc']: 157 | self.metrics['val_acc'] = val_acc 158 | self.metrics['train_loss'] = self.run_time_dict['train_loss'] 159 | self.metrics['train_acc'] = self.run_time_dict['train_acc'] 160 | self.metrics['best_acc_epoch'] = self.run_time_dict['epoch'] 161 | best_save_path = f'{self.checkpoint_dir}/model_best_acc.pth' 162 | shutil.copy(net_save_path, best_save_path) 163 | self.logger.info(f"Saving current best acc : {best_save_path}") 164 | if norm_edit_dis >= self.metrics['norm_edit_dis']: 165 | self.metrics['norm_edit_dis'] = norm_edit_dis 166 | self.metrics['train_loss'] = self.run_time_dict['train_loss'] 167 | self.metrics['train_acc'] = self.run_time_dict['train_acc'] 168 | self.metrics['best_ned_epoch'] = self.run_time_dict['epoch'] 169 | best_save_path = f'{self.checkpoint_dir}/model_best_ned.pth' 170 | shutil.copy(net_save_path, best_save_path) 171 | self.logger.info(f"Saving current best norm_edit_dis : {best_save_path}") 172 | else: 173 | if self.run_time_dict['train_acc'] > self.metrics['train_acc']: 174 | self.metrics['train_loss'] = self.run_time_dict['train_loss'] 175 | self.metrics['train_acc'] = self.run_time_dict['train_acc'] 176 | self.metrics['best_model_epoch'] = self.run_time_dict['epoch'] 177 | best_save_path = f'{self.checkpoint_dir}/model_best_loss.pth' 178 | shutil.copy(net_save_path, best_save_path) 179 | self.logger.info(f"Saving current best loss : {best_save_path}") 180 | best_str = 'current best, ' 181 | for k, v in self.metrics.items(): 182 | best_str += '{}: {}, '.format(k, v) 183 | self.logger.info(best_str) 184 | 185 | def accuracy_batch(self, predictions, labels): 186 | n_correct = 0 187 | norm_edit_dis = 0.0 188 | predictions = predictions.softmax(dim=2).detach().cpu().numpy() 189 | preds_str = self.converter.decode(predictions) 190 | show_str = [] 191 | for (pred, pred_conf), target in zip(preds_str, labels): 192 | norm_edit_dis += Levenshtein.distance(pred, target) / max(len(pred), len(target)) 193 | show_str.append(f'{pred} -> {target}') 194 | if pred == target: 195 | n_correct += 1 
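# Worked example for the two metrics above (strings are illustrative): with pred = 'hell0' and target = 'hello',
# Levenshtein.distance(pred, target) == 1, so norm_edit_dis grows by 1 / max(5, 5) = 0.2 and n_correct is unchanged;
# an exact match adds 0 distance and one correct sample. The trainer reports 1 - (accumulated distance / sample count), so higher is better.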
196 | return {'n_correct': n_correct, 'norm_edit_dis': norm_edit_dis, 'show_str': show_str} 197 | 198 | def _on_train_finish(self): 199 | for k, v in self.metrics.items(): 200 | self.logger.info(f'{k}:{v}') 201 | self.logger.info('finish train') 202 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from .label_utils import * 3 | -------------------------------------------------------------------------------- /utils/create_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/11/6 15:31 3 | # @Author : zhoujun 4 | 5 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ 6 | 7 | import os 8 | import lmdb 9 | import cv2 10 | from tqdm import tqdm 11 | import numpy as np 12 | 13 | from utils import punctuation_mend, get_datalist 14 | 15 | 16 | def checkImageIsValid(imageBin): 17 | if imageBin is None: 18 | return False 19 | imageBuf = np.frombuffer(imageBin, dtype=np.uint8) 20 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 21 | imgH, imgW = img.shape[0], img.shape[1] 22 | if imgH * imgW == 0: 23 | return False 24 | return True 25 | 26 | 27 | def writeCache(env, cache): 28 | with env.begin(write=True) as txn: 29 | for k, v in cache.items(): 30 | txn.put(k, v) 31 | 32 | 33 | def createDataset(data_list, outputPath, checkValid=True): 34 | """ 35 | Create LMDB dataset for training and evaluation. 36 | ARGS: 37 | inputPath : input folder path where starts imagePath 38 | outputPath : LMDB output path 39 | gtFile : list of image path and label 40 | checkValid : if true, check the validity of every image 41 | """ 42 | os.makedirs(outputPath, exist_ok=True) 43 | env = lmdb.open(outputPath, map_size=1099511627776) 44 | cache = {} 45 | cnt = 1 46 | for imagePath, label in tqdm(data_list, desc=f'make dataset, save to {outputPath}'): 47 | with open(imagePath, 'rb') as f: 48 | imageBin = f.read() 49 | if checkValid: 50 | try: 51 | if not checkImageIsValid(imageBin): 52 | print('%s is not a valid image' % imagePath) 53 | continue 54 | except: 55 | continue 56 | 57 | imageKey = 'image-%09d'.encode() % cnt 58 | labelKey = 'label-%09d'.encode() % cnt 59 | cache[imageKey] = imageBin 60 | cache[labelKey] = label.encode() 61 | 62 | if cnt % 1000 == 0: 63 | writeCache(env, cache) 64 | cache = {} 65 | cnt += 1 66 | nSamples = cnt - 1 67 | cache['num-samples'.encode()] = str(nSamples).encode() 68 | writeCache(env, cache) 69 | print('Created dataset with %d samples' % nSamples) 70 | 71 | 72 | if __name__ == '__main__': 73 | data_list = [["train.txt"]] 74 | save_path = 'lmdb/train' 75 | os.makedirs(save_path, exist_ok=True) 76 | train_data_list = get_datalist(data_list, 800) 77 | 78 | createDataset(train_data_list, save_path) 79 | -------------------------------------------------------------------------------- /utils/gen_img.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/4/28 15:21 3 | # @Author : zhoujun 4 | 5 | import argparse 6 | import os, errno 7 | import sys 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | 11 | import random as rnd 12 | import sys 13 | 14 | from tqdm import tqdm 15 | from trdg.string_generator import ( 16 | create_strings_from_file, 17 | create_strings_from_wikipedia, 
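# the create_strings_from_* helpers are provided by the trdg package (TextRecognitionDataGenerator), pinned as trdg==1.5.0 in requirements.txt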
18 | create_strings_randomly, 19 | ) 20 | from trdg.utils import load_dict, load_fonts 21 | from trdg.data_generator import FakeTextDataGenerator 22 | from multiprocessing import Pool 23 | 24 | 25 | def create_strings_from_dict(length, allow_variable, count, lang_dict,add_blank=False): 26 | """ 27 | Create all strings by picking X random word in the dictionnary 28 | """ 29 | 30 | dict_len = len(lang_dict) 31 | strings = [] 32 | for _ in range(0, count): 33 | current_string = "" 34 | for _ in range(0, rnd.randint(1, length) if allow_variable else length): 35 | current_string += lang_dict[rnd.randrange(dict_len)] 36 | if add_blank: 37 | current_string += " " 38 | strings.append(current_string[:-1]) 39 | return strings 40 | 41 | 42 | def margins(margin): 43 | margins = margin.split(",") 44 | if len(margins) == 1: 45 | return [int(margins[0])] * 4 46 | return [int(m) for m in margins] 47 | 48 | 49 | def parse_arguments(): 50 | """ 51 | Parse the command line arguments of the program. 52 | """ 53 | 54 | parser = argparse.ArgumentParser( 55 | description="Generate synthetic text data for text recognition." 56 | ) 57 | parser.add_argument( 58 | "--output_dir", type=str, nargs="?", help="The output directory", default="out/" 59 | ) 60 | parser.add_argument( 61 | "-i", 62 | "--input_file", 63 | type=str, 64 | nargs="?", 65 | help="When set, this argument uses a specified text file as source for the text", 66 | default="", 67 | ) 68 | parser.add_argument( 69 | "-l", 70 | "--language", 71 | type=str, 72 | nargs="?", 73 | help="The language to use, should be fr (French), en (English), es (Spanish), de (German), ar (Arabic), cn (Chinese), or hi (Hindi)", 74 | default="en" 75 | ) 76 | parser.add_argument( 77 | "-c", 78 | "--count", 79 | type=int, 80 | nargs="?", 81 | help="The number of images to be created.", 82 | required=True, 83 | ) 84 | parser.add_argument( 85 | "-rs", 86 | "--random_sequences", 87 | action="store_true", 88 | help="Use random sequences as the source text for the generation. Set '-let','-num','-sym' to use letters/numbers/symbols. If none specified, using all three.", 89 | default=False, 90 | ) 91 | parser.add_argument( 92 | "-let", 93 | "--include_letters", 94 | action="store_true", 95 | help="Define if random sequences should contain letters. Only works with -rs", 96 | default=False, 97 | ) 98 | parser.add_argument( 99 | "-num", 100 | "--include_numbers", 101 | action="store_true", 102 | help="Define if random sequences should contain numbers. Only works with -rs", 103 | default=False, 104 | ) 105 | parser.add_argument( 106 | "-sym", 107 | "--include_symbols", 108 | action="store_true", 109 | help="Define if random sequences should contain symbols. Only works with -rs", 110 | default=False, 111 | ) 112 | parser.add_argument( 113 | "-w", 114 | "--length", 115 | type=int, 116 | nargs="?", 117 | help="Define how many words should be included in each generated sample. 
If the text source is Wikipedia, this is the MINIMUM length", 118 | default=1, 119 | ) 120 | parser.add_argument( 121 | "-r", 122 | "--random", 123 | action="store_true", 124 | help="Define if the produced string will have variable word count (with --length being the maximum)", 125 | default=False, 126 | ) 127 | parser.add_argument( 128 | "-f", 129 | "--format", 130 | type=int, 131 | nargs="?", 132 | help="Define the height of the produced images if horizontal, else the width", 133 | default=32, 134 | ) 135 | parser.add_argument( 136 | "-t", 137 | "--thread_count", 138 | type=int, 139 | nargs="?", 140 | help="Define the number of thread to use for image generation", 141 | default=1, 142 | ) 143 | parser.add_argument( 144 | "-e", 145 | "--extension", 146 | type=str, 147 | nargs="?", 148 | help="Define the extension to save the image with", 149 | default="jpg", 150 | ) 151 | parser.add_argument( 152 | "-k", 153 | "--skew_angle", 154 | type=int, 155 | nargs="?", 156 | help="Define skewing angle of the generated text. In positive degrees", 157 | default=0, 158 | ) 159 | parser.add_argument( 160 | "-rk", 161 | "--random_skew", 162 | action="store_true", 163 | help="When set, the skew angle will be randomized between the value set with -k and it's opposite", 164 | default=False, 165 | ) 166 | parser.add_argument( 167 | "-wk", 168 | "--use_wikipedia", 169 | action="store_true", 170 | help="Use Wikipedia as the source text for the generation, using this paremeter ignores -r, -n, -s", 171 | default=False, 172 | ) 173 | parser.add_argument( 174 | "-bl", 175 | "--blur", 176 | type=int, 177 | nargs="?", 178 | help="Apply gaussian blur to the resulting sample. Should be an integer defining the blur radius", 179 | default=0, 180 | ) 181 | parser.add_argument( 182 | "-rbl", 183 | "--random_blur", 184 | action="store_true", 185 | help="When set, the blur radius will be randomized between 0 and -bl.", 186 | default=False, 187 | ) 188 | parser.add_argument( 189 | "-b", 190 | "--background", 191 | type=int, 192 | nargs="?", 193 | help="Define what kind of background to use. 0: Gaussian Noise, 1: Plain white, 2: Quasicrystal, 3: Image", 194 | default=0, 195 | ) 196 | parser.add_argument( 197 | "-hw", 198 | "--handwritten", 199 | action="store_true", 200 | help='Define if the data will be "handwritten" by an RNN', 201 | ) 202 | parser.add_argument( 203 | "-na", 204 | "--name_format", 205 | type=int, 206 | help="Define how the produced files will be named. 0: [TEXT]_[ID].[EXT], 1: [ID]_[TEXT].[EXT] 2: [ID].[EXT] + one file labels.txt containing id-to-label mappings", 207 | default=0, 208 | ) 209 | parser.add_argument( 210 | "-om", 211 | "--output_mask", 212 | type=int, 213 | help="Define if the generator will return masks for the text", 214 | default=0, 215 | ) 216 | parser.add_argument( 217 | "-d", 218 | "--distorsion", 219 | type=int, 220 | nargs="?", 221 | help="Define a distorsion applied to the resulting image. 0: None (Default), 1: Sine wave, 2: Cosine wave, 3: Random", 222 | default=0, 223 | ) 224 | parser.add_argument( 225 | "-do", 226 | "--distorsion_orientation", 227 | type=int, 228 | nargs="?", 229 | help="Define the distorsion's orientation. Only used if -d is specified. 0: Vertical (Up and down), 1: Horizontal (Left and Right), 2: Both", 230 | default=0, 231 | ) 232 | parser.add_argument( 233 | "-wd", 234 | "--width", 235 | type=int, 236 | nargs="?", 237 | help="Define the width of the resulting image. If not set it will be the width of the text + 10. 
If the width of the generated text is bigger that number will be used", 238 | default=-1, 239 | ) 240 | parser.add_argument( 241 | "-al", 242 | "--alignment", 243 | type=int, 244 | nargs="?", 245 | help="Define the alignment of the text in the image. Only used if the width parameter is set. 0: left, 1: center, 2: right", 246 | default=1, 247 | ) 248 | parser.add_argument( 249 | "-or", 250 | "--orientation", 251 | type=int, 252 | nargs="?", 253 | help="Define the orientation of the text. 0: Horizontal, 1: Vertical", 254 | default=0, 255 | ) 256 | parser.add_argument( 257 | "-tc", 258 | "--text_color", 259 | type=str, 260 | nargs="?", 261 | help="Define the text's color, should be either a single hex color or a range in the ?,? format.", 262 | default="#282828", 263 | ) 264 | parser.add_argument( 265 | "-sw", 266 | "--space_width", 267 | type=float, 268 | nargs="?", 269 | help="Define the width of the spaces between words. 2.0 means twice the normal space width", 270 | default=1.0, 271 | ) 272 | parser.add_argument( 273 | "-cs", 274 | "--character_spacing", 275 | type=int, 276 | nargs="?", 277 | help="Define the width of the spaces between characters. 2 means two pixels", 278 | default=0, 279 | ) 280 | parser.add_argument( 281 | "-m", 282 | "--margins", 283 | type=margins, 284 | nargs="?", 285 | help="Define the margins around the text when rendered. In pixels", 286 | default=(5, 5, 5, 5), 287 | ) 288 | parser.add_argument( 289 | "-fi", 290 | "--fit", 291 | action="store_true", 292 | help="Apply a tight crop around the rendered text", 293 | default=False, 294 | ) 295 | parser.add_argument( 296 | "-ft", "--font", type=str, nargs="?", help="Define font to be used" 297 | ) 298 | parser.add_argument( 299 | "-fd", 300 | "--font_dir", 301 | type=str, 302 | nargs="?", 303 | help="Define a font directory to be used", 304 | ) 305 | parser.add_argument( 306 | "-id", 307 | "--image_dir", 308 | type=str, 309 | nargs="?", 310 | help="Define an image directory to use when background is set to image", 311 | default=os.path.join(os.path.split(os.path.realpath(__file__))[0], "images") 312 | ) 313 | parser.add_argument( 314 | "-ca", 315 | "--case", 316 | type=str, 317 | nargs="?", 318 | help="Generate upper or lowercase only. arguments: upper or lower. Example: --case upper", 319 | ) 320 | parser.add_argument( 321 | "-dt", "--dict", type=str, nargs="?", help="Define the dictionary to be used" 322 | ) 323 | parser.add_argument( 324 | "-ws", "--word_split", 325 | action="store_true", 326 | help="Split on words instead of on characters (preserves ligatures, no character spacing)", 327 | default=False, 328 | ) 329 | return parser.parse_args() 330 | 331 | 332 | def main(): 333 | """ 334 | Description: Main function 335 | """ 336 | 337 | # Argument parsing 338 | args = parse_arguments() 339 | 340 | # Create the directory if it does not exist. 
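# A typical invocation is sketched below (output path, dictionary file and font directory are hypothetical; every flag used is defined in parse_arguments() above):
#   python utils/gen_img.py --output_dir out/ -c 10000 -w 5 -r -f 32 -b 0 -dt keys.txt -fd fonts/
# This renders 10000 images, each containing 1-5 random dictionary words at a height of 32 px on a Gaussian-noise background, using the .ttf fonts found in fonts/.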
341 | try: 342 | os.makedirs(args.output_dir) 343 | except OSError as e: 344 | if e.errno != errno.EEXIST: 345 | raise 346 | 347 | # Creating word list 348 | if args.dict: 349 | lang_dict = [] 350 | if os.path.isfile(args.dict): 351 | with open(args.dict, "r", encoding="utf8", errors="ignore") as d: 352 | lang_dict = [l for l in d.read().splitlines() if len(l) > 0] 353 | else: 354 | sys.exit("Cannot open dict") 355 | else: 356 | lang_dict = load_dict(args.language) 357 | 358 | # Create font (path) list 359 | if args.font_dir: 360 | fonts = [ 361 | os.path.join(args.font_dir, p) 362 | for p in os.listdir(args.font_dir) 363 | if os.path.splitext(p)[1] == ".ttf" 364 | ] 365 | elif args.font: 366 | if os.path.isfile(args.font): 367 | fonts = [args.font] 368 | else: 369 | sys.exit("Cannot open font") 370 | else: 371 | fonts = load_fonts(args.language) 372 | 373 | # Creating synthetic sentences (or word) 374 | strings = [] 375 | 376 | if args.use_wikipedia: 377 | strings = create_strings_from_wikipedia(args.length, args.count, args.language) 378 | elif args.input_file != "": 379 | strings = create_strings_from_file(args.input_file, args.count) 380 | elif args.random_sequences: 381 | strings = create_strings_randomly( 382 | args.length, 383 | args.random, 384 | args.count, 385 | args.include_letters, 386 | args.include_numbers, 387 | args.include_symbols, 388 | args.language, 389 | ) 390 | # Set a name format compatible with special characters automatically if they are used 391 | if args.include_symbols or True not in ( 392 | args.include_letters, 393 | args.include_numbers, 394 | args.include_symbols, 395 | ): 396 | args.name_format = 2 397 | else: 398 | strings = create_strings_from_dict( 399 | args.length, args.random, args.count, lang_dict 400 | ) 401 | 402 | if args.case == "upper": 403 | strings = [x.upper() for x in strings] 404 | if args.case == "lower": 405 | strings = [x.lower() for x in strings] 406 | 407 | string_count = len(strings) 408 | 409 | p = Pool(args.thread_count) 410 | for _ in tqdm( 411 | p.imap_unordered( 412 | FakeTextDataGenerator.generate_from_tuple, 413 | zip( 414 | [i for i in range(0, string_count)], 415 | strings, 416 | [fonts[rnd.randrange(0, len(fonts))] for _ in range(0, string_count)], 417 | [args.output_dir] * string_count, 418 | [args.format] * string_count, 419 | [args.extension] * string_count, 420 | [args.skew_angle] * string_count, 421 | [args.random_skew] * string_count, 422 | [args.blur] * string_count, 423 | [args.random_blur] * string_count, 424 | [args.background] * string_count, 425 | [args.distorsion] * string_count, 426 | [args.distorsion_orientation] * string_count, 427 | [args.handwritten] * string_count, 428 | [args.name_format] * string_count, 429 | [args.width] * string_count, 430 | [args.alignment] * string_count, 431 | [args.text_color] * string_count, 432 | [args.orientation] * string_count, 433 | [args.space_width] * string_count, 434 | [args.character_spacing] * string_count, 435 | [args.margins] * string_count, 436 | [args.fit] * string_count, 437 | [args.output_mask] * string_count, 438 | [args.word_split] * string_count, 439 | [args.image_dir] * string_count, 440 | ), 441 | ), 442 | total=args.count, 443 | ): 444 | pass 445 | p.terminate() 446 | 447 | if args.name_format == 2: 448 | # Create file with filename-to-label connections 449 | with open( 450 | os.path.join(args.output_dir, "labels.txt"), "w", encoding="utf8" 451 | ) as f: 452 | for i in range(string_count): 453 | file_name = str(i) + "." 
+ args.extension 454 | f.write("{} {}\n".format(file_name, strings[i])) 455 | 456 | 457 | if __name__ == "__main__": 458 | main() 459 | -------------------------------------------------------------------------------- /utils/get_keys.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/8/24 10:20 3 | # @Author : zhoujun 4 | import os 5 | import sys 6 | import pathlib 7 | 8 | sys.path.append(str(pathlib.Path(os.path.abspath(__name__)).parent)) 9 | import argparse 10 | import cv2 11 | from tqdm import tqdm 12 | import numpy as np 13 | from collections import defaultdict 14 | from itertools import groupby 15 | from utils import parse_config, punctuation_mend 16 | 17 | 18 | def split(data_dict, num=10): 19 | if num == 1: 20 | x = [x[0] for x in data_dict] 21 | y = [x[1] for x in data_dict] 22 | else: 23 | x = [] 24 | y = [] 25 | for k, g in groupby(data_dict, key=lambda item: item[0] // num): 26 | cur_sum = sum([x[1] for x in g]) 27 | print('{}-{}: {}'.format(k * num, (k + 1) * num - 1, cur_sum)) 28 | x.append((k + 1) * num - 1) 29 | y.append(cur_sum) 30 | return x, y 31 | 32 | 33 | def show_dict(data_dict: dict, num, title): 34 | from matplotlib import pyplot as plt 35 | data_dict = sorted(data_dict.items(), key=lambda item: item[0]) 36 | x, y = split(data_dict, num) 37 | y = np.array(y) 38 | y = y / y.sum() 39 | print(x[y.argmax()]) 40 | plt.figure() 41 | plt.title(title) 42 | plt.plot(x, y) 43 | plt.savefig('1.jpg') 44 | # plt.show() 45 | 46 | 47 | def get_key(label_file_list, ignore_chinese_punctuation, show_max_img=False): 48 | data_list = [] 49 | label_list = [] 50 | len_dict = defaultdict(int) 51 | h_dict = defaultdict(int) 52 | w_dict = defaultdict(int) 53 | for label_path in label_file_list: 54 | with open(label_path, 'r', encoding='utf-8') as f: 55 | for line in tqdm(f.readlines(), desc=label_path): 56 | line = line.strip('\n').replace('.jpg ', '.jpg\t').replace('.png ', '.png\t').split('\t') 57 | if len(line) > 1 and os.path.exists(line[0]): 58 | data_list.append(line[0]) 59 | label = line[1] 60 | if ignore_chinese_punctuation: 61 | label = punctuation_mend(label) 62 | label_list.append(label) 63 | len_dict[len(line[1])] += 1 64 | if show_max_img: 65 | img = cv2.imread(line[0]) 66 | h, w = img.shape[:2] 67 | h_dict[h] += 1 68 | w_dict[w] += 1 69 | if show_max_img: 70 | print('****************** analyzing width ******************') 71 | show_dict(w_dict, 10, 'w') 72 | print('****************** analyzing height ******************') 73 | show_dict(h_dict, 1, 'h') 74 | print('****************** analyzing label length ******************') 75 | show_dict(len_dict, 1, 'label') 76 | a = ''.join(sorted(set((''.join(label_list))))) 77 | return a 78 | 79 | 80 | if __name__ == '__main__': 81 | # generate the key (character set) file from the label files 82 | import anyconfig 83 | from utils import save 84 | 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--label_file', nargs='+', help='label file', default=[""]) 87 | args = parser.parse_args() 88 | 89 | config_path = 'config/imagedataset_None_VGG_RNN_CTC.yaml' 90 | if os.path.exists(config_path): 91 | config = anyconfig.load(open(config_path, 'rb')) 92 | if 'base' in config: 93 | config = parse_config(config) 94 | label_file = [] 95 | for train_file in config['dataset']['train']['dataset']['args']['data_path']: 96 | if isinstance(train_file, list): 97 | label_file.extend(train_file) 98 | else: 99 | label_file.append(train_file) 100 | label_file.extend(config['dataset']['validate']['dataset']['args']['data_path']) 101 | 
ignore_chinese_punctuation = config['dataset']['train']['dataset']['args']['ignore_chinese_punctuation'] 102 | else: 103 | ignore_chinese_punctuation = True 104 | label_file = args.label_file 105 | alphabet = get_key(label_file, ignore_chinese_punctuation, show_max_img=False).replace(' ', '') 106 | save(list(alphabet), 'dict.txt') 107 | print(alphabet) 108 | -------------------------------------------------------------------------------- /utils/label_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import torch 5 | 6 | 7 | class CTCLabelConverter(object): 8 | """ Convert between text-label and text-index """ 9 | 10 | def __init__(self, character): 11 | # character (str): set of the possible characters. 12 | dict_character = list(character) 13 | 14 | self.dict = {} 15 | for i, char in enumerate(dict_character): 16 | # NOTE: 0 is reserved for 'blank' token required by CTCLoss 17 | self.dict[char] = i + 1 18 | 19 | self.character = ['[blank]'] + dict_character # dummy '[blank]' token for CTCLoss (index 0) 20 | 21 | def encode(self, text, batch_max_length=None): 22 | """convert text-label into text-index. 23 | input: 24 | text: text labels of each image. [batch_size] 25 | 26 | output: 27 | text: concatenated text index for CTCLoss. 28 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 29 | length: length of each text. [batch_size] 30 | """ 31 | length = [len(s) for s in text] 32 | # text = ''.join(text) 33 | # text = [self.dict[char] for char in text] 34 | d = [] 35 | for s in text: 36 | t = [self.dict[char] for char in s] 37 | t.extend([0] * (batch_max_length - len(s))) 38 | d.append(t) 39 | return (torch.tensor(d, dtype=torch.long), torch.tensor(length, dtype=torch.long)) 40 | 41 | def decode(self, preds, raw=False): 42 | """ convert text-index into text-label. """ 43 | preds_idx = preds.argmax(axis=2) 44 | preds_prob = preds.max(axis=2) 45 | result_list = [] 46 | for word, prob in zip(preds_idx, preds_prob): 47 | if raw: 48 | result_list.append((''.join([self.character[int(i)] for i in word]), prob)) 49 | else: 50 | result = [] 51 | conf = [] 52 | for i, index in enumerate(word): 53 | if word[i] != 0 and (not (i > 0 and word[i - 1] == word[i])): 54 | result.append(self.character[int(index)]) 55 | conf.append(prob[i]) 56 | result_list.append((''.join(result), conf)) 57 | return result_list 58 | 59 | 60 | class AttnLabelConverter(object): 61 | """ Convert between text-label and text-index """ 62 | 63 | def __init__(self, character): 64 | # character (str): set of the possible characters. 65 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 66 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 67 | list_character = list(character) 68 | self.character = list_token + list_character 69 | 70 | self.dict = {} 71 | for i, char in enumerate(self.character): 72 | # print(i, char) 73 | self.dict[char] = i 74 | 75 | def encode(self, text, batch_max_length): 76 | """ convert text-label into text-index. 77 | input: 78 | text: text labels of each image. [batch_size] 79 | batch_max_length: max length of text label in the batch. 25 by default 80 | 81 | output: 82 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 83 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 84 | length : the length of output of attention decoder, which count [s] token also. 
[3, 7, ....] [batch_size] 85 | """ 86 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 87 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 88 | batch_max_length += 1 89 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 90 | batch_text = torch.zeros(len(text), batch_max_length + 1, dtype=torch.long) 91 | for i, t in enumerate(text): 92 | text = list(t) 93 | text.append('[s]') 94 | text = [self.dict[char] for char in text] 95 | batch_text[i][1:1 + len(text)] = torch.Tensor(text).long() # batch_text[:, 0] = [GO] token 96 | return (batch_text, torch.Tensor(length).int()) 97 | 98 | def decode(self, preds): 99 | """ convert text-index into text-label. """ 100 | preds_idx = preds.argmax(axis=2) 101 | preds_prob = preds.max(axis=2) 102 | result_list = [] 103 | for word, prob in zip(preds_idx, preds_prob): 104 | text = ''.join([self.character[i] for i in word]) 105 | end_idx = text.find('[s]') 106 | text = text[:end_idx] 107 | conf = prob[:end_idx].tolist() 108 | result_list.append((text, conf)) 109 | return result_list 110 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 18-5-20 20:07 3 | # @Author : zhoujun 4 | import time 5 | import json 6 | import pathlib 7 | from tqdm import tqdm 8 | 9 | 10 | def setup_logger(log_file_path: str = None): 11 | import logging 12 | logger = logging.getLogger('crnn.pytorch') 13 | logger.setLevel(logging.DEBUG) 14 | formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s') 15 | ch = logging.StreamHandler() 16 | ch.setFormatter(formatter) 17 | logger.addHandler(ch) 18 | if log_file_path is not None: 19 | file_handle = logging.FileHandler(log_file_path) 20 | file_handle.setFormatter(formatter) 21 | logger.addHandler(file_handle) 22 | logger.info('logger init finished') 23 | return logger 24 | 25 | 26 | # --exeTime 27 | def exe_time(func): 28 | def newFunc(*args, **args2): 29 | t0 = time.time() 30 | back = func(*args, **args2) 31 | print(f"{func.__name__} cost {time.time() - t0:.3f}s") 32 | return back 33 | 34 | return newFunc 35 | 36 | 37 | def load(file_path: str): 38 | file_path = pathlib.Path(file_path) 39 | func_dict = {'.txt': _load_txt, '.json': _load_json, '.list': _load_txt} 40 | assert file_path.suffix in func_dict 41 | return func_dict[file_path.suffix](file_path) 42 | 43 | 44 | def _load_txt(file_path: str): 45 | with open(file_path, 'r', encoding='utf8') as f: 46 | content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()] 47 | return content 48 | 49 | 50 | def _load_json(file_path: str): 51 | with open(file_path, 'r', encoding='utf8') as f: 52 | content = json.load(f) 53 | return content 54 | 55 | 56 | def save(data, file_path): 57 | file_path = pathlib.Path(file_path) 58 | func_dict = {'.txt': _save_txt, '.json': _save_json} 59 | assert file_path.suffix in func_dict 60 | return func_dict[file_path.suffix](data, file_path) 61 | 62 | 63 | def _save_txt(data, file_path): 64 | """ 65 | Write a list of items to a txt file, one item per line 66 | :param data: 67 | :param file_path: 68 | :return: 69 | """ 70 | if not isinstance(data, list): 71 | data = [data] 72 | with open(file_path, mode='w', encoding='utf8') as f: 73 | f.write('\n'.join(data)) 74 | 75 | 76 | def _save_json(data, file_path): 77 | with open(file_path, 'w', encoding='utf-8') as json_file: 78 | json.dump(data, 
json_file, ensure_ascii=False, indent=4) 79 | 80 | 81 | def punctuation_mend(string): 82 | # input: a string or the path of a txt file 83 | import unicodedata 84 | import pathlib 85 | 86 | table = {ord(f): ord(t) for f, t in zip( 87 | u',。!?【】()%#@&1234567890“”‘’', 88 | u',.!?[]()%#@&1234567890""\'\'')} # other custom symbols that need replacing can be added here 89 | res = unicodedata.normalize('NFKC', string) 90 | res = res.translate(table) 91 | return res 92 | 93 | 94 | def get_datalist(data_path, max_len): 95 | """ 96 | Get the data list for training and validation 97 | :param data_path: list of dataset label files; each line in a file is stored as 'path/to/img\tlabel' 98 | :return: 99 | """ 100 | train_data = [] 101 | if isinstance(data_path, list): 102 | for p in data_path: 103 | train_data.extend(get_datalist(p, max_len)) 104 | else: 105 | with open(data_path, 'r', encoding='utf-8') as f: 106 | for line in tqdm(f.readlines(), desc=f'load data from {data_path}'): 107 | line = line.strip('\n').replace('.jpg ', '.jpg\t').replace('.png ', '.png\t').split('\t') 108 | if len(line) > 1: 109 | img_path = pathlib.Path(line[0].strip(' ')) 110 | label = line[1] 111 | if len(label) > max_len: 112 | continue 113 | if img_path.exists() and img_path.stat().st_size > 0: 114 | train_data.append((str(img_path), label)) 115 | return train_data 116 | 117 | 118 | def parse_config(config: dict) -> dict: 119 | import anyconfig 120 | base_file_list = config.pop('base') 121 | base_config = {} 122 | for base_file in base_file_list: 123 | tmp_config = anyconfig.load(open(base_file, 'rb')) 124 | if 'base' in tmp_config: 125 | tmp_config = parse_config(tmp_config) 126 | anyconfig.merge(tmp_config, base_config) 127 | base_config = tmp_config 128 | anyconfig.merge(base_config, config) 129 | return base_config 130 | 131 | 132 | # number of network parameters 133 | def get_parameter_number(net): 134 | total_num = sum(p.numel() for p in net.parameters()) 135 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 136 | return {'Total': total_num, 'Trainable': trainable_num} 137 | 138 | 139 | class Averager(object): 140 | """Compute average for torch.Tensor, used for loss average.""" 141 | 142 | def __init__(self): 143 | self.reset() 144 | 145 | def add(self, v): 146 | count = v.data.numel() 147 | v = v.data.sum() 148 | self.n_count += count 149 | self.sum += v 150 | 151 | def reset(self): 152 | self.n_count = 0 153 | self.sum = 0 154 | 155 | def val(self): 156 | res = 0 157 | if self.n_count != 0: 158 | res = self.sum / float(self.n_count) 159 | return res 160 | 161 | 162 | if __name__ == '__main__': 163 | print(punctuation_mend('anufacturingcolt')) 164 | --------------------------------------------------------------------------------
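
A minimal usage sketch for the label converters defined in utils/label_utils.py above. This snippet is not part of the repository; the toy alphabet, the dummy prediction array and the direct module import are illustrative assumptions:

    import numpy as np
    from utils.label_utils import CTCLabelConverter

    # hypothetical character set; index 0 is reserved for the CTC blank token
    converter = CTCLabelConverter('abc')

    # encode() pads every label to batch_max_length and returns (indices, lengths)
    text, length = converter.encode(['ab', 'c'], batch_max_length=4)
    # text   -> tensor([[1, 2, 0, 0], [3, 0, 0, 0]])
    # length -> tensor([2, 1])

    # decode() expects per-timestep class scores of shape [batch, T, num_classes]
    # as a numpy array (it relies on ndarray.max/argmax returning plain values)
    preds = np.zeros((1, 3, 4), dtype=np.float32)
    preds[0, 0, 1] = 1.0  # 'a'
    preds[0, 1, 0] = 1.0  # blank
    preds[0, 2, 2] = 1.0  # 'b'
    print(converter.decode(preds))  # [('ab', [1.0, 1.0])]

CTC decoding collapses repeated indices and drops the blank (index 0), which is why the blank timestep between 'a' and 'b' disappears from the decoded text.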