├── .gitignore
├── LICENSE.md
├── README.MD
├── __init__.py
├── base
│   ├── __init__.py
│   └── base_trainer.py
├── config.json
├── config
│   ├── __init__.py
│   └── default.py
├── data_loader
│   ├── __init__.py
│   ├── augment.py
│   ├── data_utils.py
│   └── dataset.py
├── eval.py
├── imgs
│   ├── example
│   │   ├── img_10.jpg
│   │   ├── img_2.jpg
│   │   ├── img_29.jpg
│   │   ├── img_75.jpg
│   │   └── img_91.jpg
│   └── paper
│       └── PAN.jpg
├── models
│   ├── __init__.py
│   ├── loss.py
│   ├── model.py
│   └── modules
│       ├── __init__.py
│       ├── resnet.py
│       ├── segmentation_head.py
│       └── shufflenetv2.py
├── post_processing
│   ├── Makefile
│   ├── __init__.py
│   ├── include
│   │   └── pybind11
│   │       ├── attr.h
│   │       ├── buffer_info.h
│   │       ├── cast.h
│   │       ├── chrono.h
│   │       ├── class_support.h
│   │       ├── common.h
│   │       ├── complex.h
│   │       ├── descr.h
│   │       ├── detail
│   │       │   ├── class.h
│   │       │   ├── common.h
│   │       │   ├── descr.h
│   │       │   ├── init.h
│   │       │   ├── internals.h
│   │       │   └── typeid.h
│   │       ├── eigen.h
│   │       ├── embed.h
│   │       ├── eval.h
│   │       ├── functional.h
│   │       ├── iostream.h
│   │       ├── numpy.h
│   │       ├── operators.h
│   │       ├── options.h
│   │       ├── pybind11.h
│   │       ├── pytypes.h
│   │       ├── stl.h
│   │       ├── stl_bind.h
│   │       └── typeid.h
│   ├── kmeans.py
│   ├── pse.cpp
│   ├── pse.so
│   └── pypse.py
├── predict.py
├── train.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
└── utils
    ├── __init__.py
    ├── cal_recall
    │   ├── __init__.py
    │   ├── rrc_evaluation_funcs.py
    │   └── script.py
    ├── make_trainfile.py
    ├── metrics.py
    ├── schedulers.py
    └── util.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pth 3 | *.pyc 4 | *.pyo 5 | *.log 6 | *.tmp 7 | *.pkl 8 | __pycache__/ 9 | .idea/ 10 | output/ -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below).
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network 2 | 3 | ![](imgs/paper/PAN.jpg) 4 | 5 | ## Requirements 6 | * pytorch 1.1+ 7 | * torchvision 0.3+ 8 | * pyclipper 9 | * opencv 3 10 | * gcc 4.9+ 11 | 12 | ## Download 13 | 14 | `PAN_resnet18_FPEM_FFM` and `PAN_shufflenetv2_FPEM_FFM` trained on ICDAR 2015: 15 | 16 | the updated models (resnet18: 78.8, shufflenetv2: 72.4, lr: 1e-3) are not the best models 17 | 18 | [google drive](https://drive.google.com/drive/folders/1bKPQEEOJ5kgSSRMpnDB8HIRecnD_s4bR?usp=sharing) 19 | 20 | ## Data Preparation 21 | 22 | train: prepare a text file in the following format, using '\t' as the separator (a generation sketch is given in the "Generate the train list" section below) 23 | ```bash 24 | /path/to/img.jpg path/to/label.txt 25 | ... 26 | ``` 27 | val: 28 | use a folder with the following layout 29 | ```bash 30 | img/ stores the images 31 | gt/ stores the gt files 32 | ``` 33 | 34 | ## Train 35 | 1. set `train_data_path` and `val_data_path` in [config.json](config.json) 36 | 2. run the following script 37 | ```sh 38 | python3 train.py 39 | ``` 40 | 41 | ## Test 42 | 43 | [eval.py](eval.py) is used to evaluate the model on a test dataset 44 | 45 | 1. set `model_path`, `img_path`, `gt_path` and `save_path` in [eval.py](eval.py) 46 | 2. run the following script 47 | ```sh 48 | python3 eval.py 49 | ``` 50 | 51 | ## Predict 52 | [predict.py](predict.py) is used to run inference on a single image 53 | 54 | 1. set `model_path` and `img_path` in [predict.py](predict.py) 55 | 2. run the following script 56 | ```sh 57 | python3 predict.py 58 | ``` 59 | 60 | The project is still under development. 61 | 62 |
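## Generate the train list

The train list described in Data Preparation can be produced by a small script. Below is a minimal sketch, assuming an ICDAR-style layout (`train/img` plus `train/gt` with `gt_<name>.txt` files; adjust the paths to your data). `utils/make_trainfile.py` in this repo plays a similar role:
```python
import os

img_dir = 'icdar2015/train/img'  # assumed location of the training images
gt_dir = 'icdar2015/train/gt'    # assumed location of the label files

with open('train.txt', 'w', encoding='utf-8') as f:
    for name in sorted(os.listdir(img_dir)):
        stem, ext = os.path.splitext(name)
        if ext.lower() not in ('.jpg', '.jpeg', '.png'):
            continue
        gt_path = os.path.join(gt_dir, 'gt_{}.txt'.format(stem))
        if os.path.exists(gt_path):
            # one sample per line: image path, a tab, then the label path
            f.write('{}\t{}\n'.format(os.path.join(img_dir, name), gt_path))
```
The resulting `train.txt` can then be listed in `train_data_path` in config.json.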

## Performance

63 | 64 | ### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4) 65 | trained only on the ICDAR 2015 dataset 66 | 67 | | Method | image size (short size) |learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS | 68 | |:--------------------------:|:-------:|:--------:|:--------:|:------------:|:---------------:|:-----:| 69 | | paper(resnet18) | 736 |x | x | x | 80.4 | 26.1 | 70 | | my (ShuffleNetV2+FPEM_FFM+pse expand) |736 |1e-3| 81.72 | 66.73 | 73.47 | 24.71 (P100)| 71 | | my (resnet18+FPEM_FFM+pse expand) |736 |1e-3| 84.93 | 74.09 | 79.14 | 21.31 (P100)| 72 | | my (resnet50+FPEM_FFM+pse expand) |736 |1e-3| 84.23 | 76.12 | 79.96 | 14.22 (P100)| 73 | | my (ShuffleNetV2+FPEM_FFM+pse expand) |736 |1e-4| 75.14 | 57.34 | 65.04 | 24.71 (P100)| 74 | | my (resnet18+FPEM_FFM+pse expand) |736 |1e-4| 83.89 | 69.23 | 75.86 | 21.31 (P100)| 75 | | my (resnet50+FPEM_FFM+pse expand) |736 |1e-4| 85.29 | 75.1 | 79.87 | 14.22 (P100)| 76 | | my (resnet18+FPN+pse expand) | 736 |1e-3| 76.50 | 74.70 | 75.59 | 14.47 (P100)| 77 | | my (resnet50+FPN+pse expand) | 736 |1e-3| 71.82 | 75.73 | 73.72 | 10.67 (P100)| 78 | | my (resnet18+FPN+pse expand) | 736 |1e-4| 74.19 | 72.34 | 73.25 | 14.47 (P100)| 79 | | my (resnet50+FPN+pse expand) | 736 |1e-4| 78.96 | 76.27 | 77.59 | 10.67 (P100)| 80 | 81 | ### examples 82 | ![](imgs/example/img_2.jpg) 83 | 84 | ![](imgs/example/img_10.jpg) 85 | 86 | ![](imgs/example/img_29.jpg) 87 | 88 | ![](imgs/example/img_75.jpg) 89 | 90 | ![](imgs/example/img_91.jpg) 91 | 92 | ### todo 93 | - [ ] MobileNet backbone 94 | 95 | - [x] ShuffleNet backbone 96 | ### reference 97 | 1. https://arxiv.org/pdf/1908.05900.pdf 98 | 2. https://github.com/WenmuZhou/PSENet.pytorch 99 | 100 | **If this repository helps you, please star it. Thanks.** -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/__init__.py -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:50 3 | # @Author : zhoujun 4 | 5 | import os 6 | import shutil 7 | import pathlib 8 | from pprint import pformat 9 | import torch 10 | from torch import nn 11 | 12 | from utils import setup_logger 13 | 14 | 15 | class BaseTrainer: 16 | def __init__(self, config, model, criterion, weights_init): 17 | config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent), 18 | config['trainer']['output_dir']) 19 | config['name'] = config['name'] + '_' + model.name 20 | self.save_dir = os.path.join(config['trainer']['output_dir'], config['name']) 21 | self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') 22 | 23 | if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '': 24 | shutil.rmtree(self.save_dir, ignore_errors=True) 25 | if not os.path.exists(self.checkpoint_dir): 26 | os.makedirs(self.checkpoint_dir) 27 | 28 | self.global_step = 0 29 | self.start_epoch = 1 30 | self.config = config 31 | 32 | self.model = model 33 | self.criterion = criterion 34 | # logger and tensorboard 35 | self.tensorboard_enable =
self.config['trainer']['tensorboard'] 36 | self.epochs = self.config['trainer']['epochs'] 37 | self.display_interval = self.config['trainer']['display_interval'] 38 | if self.tensorboard_enable: 39 | from torch.utils.tensorboard import SummaryWriter 40 | self.writer = SummaryWriter(self.save_dir) 41 | 42 | self.logger = setup_logger(os.path.join(self.save_dir, 'train_log')) 43 | self.logger.info(pformat(self.config)) 44 | 45 | # device 46 | torch.manual_seed(self.config['trainer']['seed']) # set the random seed for the CPU 47 | if len(self.config['trainer']['gpus']) > 0 and torch.cuda.is_available(): 48 | self.with_cuda = True 49 | torch.backends.cudnn.benchmark = True 50 | self.logger.info( 51 | 'train with gpu {} and pytorch {}'.format(self.config['trainer']['gpus'], torch.__version__)) 52 | self.gpus = {i: item for i, item in enumerate(self.config['trainer']['gpus'])} 53 | self.device = torch.device("cuda:0") 54 | torch.cuda.manual_seed(self.config['trainer']['seed']) # set the random seed for the current GPU 55 | torch.cuda.manual_seed_all(self.config['trainer']['seed']) # set the random seed for all GPUs 56 | else: 57 | self.with_cuda = False 58 | self.logger.info('train with cpu and pytorch {}'.format(torch.__version__)) 59 | self.device = torch.device("cpu") 60 | self.logger.info('device {}'.format(self.device)) 61 | self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'), 'best_model': ''} 62 | 63 | self.optimizer = self._initialize('optimizer', torch.optim, model.parameters()) 64 | 65 | if self.config['trainer']['resume_checkpoint'] != '': 66 | self._load_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True) 67 | elif self.config['trainer']['finetune_checkpoint'] != '': 68 | self._load_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False) 69 | else: 70 | if weights_init is not None: 71 | model.apply(weights_init) 72 | if self.config['lr_scheduler']['type'] != 'PolynomialLR': 73 | self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer) 74 | 75 | # single machine, multiple GPUs 76 | num_gpus = torch.cuda.device_count() 77 | if num_gpus > 1: 78 | self.model = nn.DataParallel(self.model) 79 | 80 | self.model.to(self.device) 81 | 82 | if self.tensorboard_enable: 83 | try: 84 | # add graph 85 | dummy_input = torch.zeros(1, self.config['data_loader']['args']['dataset']['img_channel'], 86 | self.config['data_loader']['args']['dataset']['input_size'], 87 | self.config['data_loader']['args']['dataset']['input_size']).to(self.device) 88 | self.writer.add_graph(model, dummy_input) 89 | except: 90 | import traceback 91 | # self.logger.error(traceback.format_exc()) 92 | self.logger.warning('add graph to tensorboard failed') 93 | 94 | def train(self): 95 | """ 96 | Full training logic 97 | """ 98 | for epoch in range(self.start_epoch, self.epochs + 1): 99 | try: 100 | self.epoch_result = self._train_epoch(epoch) 101 | if self.config['lr_scheduler']['type'] != 'PolynomialLR': 102 | self.scheduler.step() 103 | self._on_epoch_finish() 104 | except torch.cuda.CudaError: 105 | self._log_memory_usage() 106 | if self.tensorboard_enable: 107 | self.writer.close() 108 | self._on_train_finish() 109 | 110 | def _train_epoch(self, epoch): 111 | """ 112 | Training logic for an epoch 113 | 114 | :param epoch: Current epoch number 115 | """ 116 | raise NotImplementedError 117 | 118 | def _eval(self): 119 | """ 120 | eval logic for an epoch 121 | 122 | :param epoch: Current epoch number 123 | """ 124 | raise NotImplementedError 125 | 126 | def _on_epoch_finish(self): 127 | raise NotImplementedError 128 | 129 |
def _on_train_finish(self): 130 | raise NotImplementedError 131 | 132 | def _log_memory_usage(self): 133 | if not self.with_cuda: 134 | return 135 | 136 | template = """Memory Usage: \n{}""" 137 | usage = [] 138 | for deviceID, device in self.gpus.items(): 139 | deviceID = int(deviceID) 140 | allocated = torch.cuda.memory_allocated(deviceID) / (1024 * 1024) 141 | cached = torch.cuda.memory_cached(deviceID) / (1024 * 1024) 142 | 143 | usage.append(' CUDA: {} Allocated: {} MB Cached: {} MB \n'.format(device, allocated, cached)) 144 | 145 | content = ''.join(usage) 146 | content = template.format(content) 147 | 148 | self.logger.debug(content) 149 | 150 | def _save_checkpoint(self, epoch, file_name, save_best=False): 151 | """ 152 | Saving checkpoints 153 | 154 | :param epoch: current epoch number 155 | :param file_name: checkpoint file name 156 | :param save_best: if True, copy the saved checkpoint to 'model_best.pth' 157 | """ 158 | state = { 159 | 'epoch': epoch, 160 | 'global_step': self.global_step, 161 | 'state_dict': self.model.state_dict(), 162 | 'optimizer': self.optimizer.state_dict(), 163 | 'scheduler': self.scheduler.state_dict() if hasattr(self, 'scheduler') else None, # no scheduler is created when lr_scheduler type is PolynomialLR 164 | 'config': self.config, 165 | 'metrics': self.metrics 166 | } 167 | filename = os.path.join(self.checkpoint_dir, file_name) 168 | torch.save(state, filename) 169 | if save_best: 170 | shutil.copy(filename, os.path.join(self.checkpoint_dir, 'model_best.pth')) 171 | self.logger.info("Saving current best: {}".format(file_name)) 172 | else: 173 | self.logger.info("Saving checkpoint: {}".format(filename)) 174 | 175 | def _load_checkpoint(self, checkpoint_path, resume): 176 | """ 177 | Resume from saved checkpoints 178 | :param checkpoint_path: checkpoint path to be resumed; resume=True also restores the step/epoch/optimizer state 179 | """ 180 | self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path)) 181 | checkpoint = torch.load(checkpoint_path) 182 | self.model.load_state_dict(checkpoint['state_dict']) 183 | if resume: 184 | self.global_step = checkpoint['global_step'] 185 | self.start_epoch = checkpoint['epoch'] + 1 186 | self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch 187 | # self.scheduler.load_state_dict(checkpoint['scheduler']) 188 | self.optimizer.load_state_dict(checkpoint['optimizer']) 189 | if 'metrics' in checkpoint: 190 | self.metrics = checkpoint['metrics'] 191 | if self.with_cuda: 192 | for state in self.optimizer.state.values(): 193 | for k, v in state.items(): 194 | if isinstance(v, torch.Tensor): 195 | state[k] = v.to(self.device) 196 | self.logger.info("resume from checkpoint {} (epoch {})".format(checkpoint_path, self.start_epoch)) 197 | else: 198 | self.logger.info("finetune from checkpoint {}".format(checkpoint_path)) 199 | 200 | def _initialize(self, name, module, *args, **kwargs): 201 | module_name = self.config[name]['type'] 202 | module_args = self.config[name]['args'] 203 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 204 | module_args.update(kwargs) 205 | return getattr(module, module_name)(*args, **module_args) 206 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PAN", 3 | "data_loader": { 4 | "type": "ImageDataset", 5 | "args": { 6 | "dataset": { 7 | "train_data_path": [ 8 | [ 9 | "dataset1.txt1", 10 | "dataset1.txt2" 11 | ], 12 | [ 13 | "dataset2.txt1", 14 | "dataset2.txt2" 15 | ] 16 | ], 17 | "train_data_ratio": [ 18 | 0.5,
19 | 0.5 20 | ], 21 | "val_data_path": "path/to/test/", 22 | "input_size": 640, 23 | "img_channel": 3, 24 | "shrink_ratio": 0.5 25 | }, 26 | "loader": { 27 | "validation_split": 0.1, 28 | "train_batch_size": 16, 29 | "shuffle": true, 30 | "pin_memory": false, 31 | "num_workers": 6 32 | } 33 | } 34 | }, 35 | "arch": { 36 | "type": "PANModel", 37 | "args": { 38 | "backbone": "resnet18", 39 | "fpem_repeat": 2, 40 | "pretrained": true, 41 | "segmentation_head": "FPEM_FFM" 42 | } 43 | }, 44 | "loss": { 45 | "type": "PANLoss", 46 | "args": { 47 | "alpha": 0.5, 48 | "beta": 0.25, 49 | "delta_agg": 0.5, 50 | "delta_dis": 3, 51 | "ohem_ratio": 3 52 | } 53 | }, 54 | "optimizer": { 55 | "type": "Adam", 56 | "args": { 57 | "lr": 0.001, 58 | "weight_decay": 0, 59 | "amsgrad": true 60 | } 61 | }, 62 | "lr_scheduler": { 63 | "type": "StepLR", 64 | "args": { 65 | "step_size": 200, 66 | "gamma": 0.1 67 | } 68 | }, 69 | "trainer": { 70 | "seed": 2, 71 | "gpus": [ 72 | 0 73 | ], 74 | "epochs": 600, 75 | "display_interval": 10, 76 | "show_images_interval": 50, 77 | "resume_checkpoint": "", 78 | "finetune_checkpoint": "", 79 | "output_dir": "output", 80 | "tensorboard": true, 81 | "metrics": "hmean" 82 | } 83 | } -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:50 3 | # @Author : zhoujun -------------------------------------------------------------------------------- /config/default.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:51 3 | # @Author : zhoujun 4 | 5 | name = 'PAN' 6 | arch = { 7 | "type": "PANModel", # name of model architecture to train 8 | "args": { 9 | 'backbone': 'resnet18', 10 | 'fpem_repeat': 2, # number of times the FPEM module is repeated 11 | 'pretrained': True, # whether the backbone uses ImageNet-pretrained weights 12 | 'segmentation_head': 'FPN' # segmentation head: FPN or FPEM_FFM 13 | } 14 | } 15 | 16 | 17 | data_loader = { 18 | "type": "ImageDataset", # selecting data loader 19 | "args": { 20 | 'dataset': { 21 | 'train_data_path': [['dataset1.txt1', 'dataset1.txt2'], ['dataset2.txt1', 'dataset2.txt2']], 22 | 'train_data_ratio': [0.5, 0.5], 23 | 'val_data_path': ['path/to/test/'], 24 | 'input_size': 640, 25 | 'img_channel': 3, 26 | 'shrink_ratio': 0.5 # gt shrink ratio 27 | }, 28 | 'loader': { 29 | 'validation_split': 0.1, 30 | 'train_batch_size': 16, 31 | 'val_batch_size': 4, 32 | 'shuffle': True, 33 | 'pin_memory': False, 34 | 'num_workers': 6 35 | } 36 | } 37 | } 38 | loss = { 39 | "type": "PANLoss", # name of the loss to use 40 | "args": { 41 | 'alpha': 0.5, 42 | 'beta': 0.25, 43 | 'delta_agg': 0.5, 44 | 'delta_dis': 3, 45 | 'ohem_ratio': 3 46 | } 47 | } 48 | 49 | optimizer = { 50 | "type": "Adam", 51 | "args": { 52 | "lr": 0.001, 53 | "weight_decay": 0, 54 | "amsgrad": True 55 | } 56 | } 57 | 58 | lr_scheduler = { 59 | "type": "StepLR", 60 | "args": { 61 | "step_size": 200, 62 | "gamma": 0.1 63 | } 64 | } 65 | 66 | resume = { 67 | 'restart_training': True, 68 | 'checkpoint': '' 69 | } 70 | 71 | trainer = { 72 | # random seed 73 | 'seed': 2, 74 | 'gpus': [0], 75 | 'epochs': 600, 76 | 'display_interval': 10, 77 | 'show_images_interval': 50, 78 | 'resume': resume, 79 | 'output_dir': 'output', 80 | 'tensorboard': True 81 | } 82 | 83 | config_dict = {} 84 | config_dict['name'] = name 85 | config_dict['data_loader'] = data_loader 86 | config_dict['arch'] = arch 87 |
config_dict['loss'] = loss 88 | config_dict['optimizer'] = optimizer 89 | config_dict['lr_scheduler'] = lr_scheduler 90 | config_dict['trainer'] = trainer 91 | 92 | from utils import save_json 93 | 94 | save_json(config_dict, '../config.json') 95 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:52 3 | # @Author : zhoujun 4 | 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | import copy 8 | import pathlib 9 | from . import dataset 10 | 11 | 12 | def get_datalist(train_data_path, validation_split=0.1): 13 | """ 14 | Build the lists of training samples 15 | :param train_data_path: list of dataset files; each file stores one sample per line as 'path/to/img\tlabel' 16 | :param validation_split: ratio of the validation split, used when val_data_path is empty (not used inside this function yet) 17 | :return: 18 | """ 19 | train_data_list = [] 20 | for train_path in train_data_path: 21 | train_data = [] 22 | for p in train_path: 23 | with open(p, 'r', encoding='utf-8') as f: 24 | for line in f.readlines(): 25 | line = line.strip('\n').replace('.jpg ', '.jpg\t').split('\t') 26 | if len(line) > 1: 27 | img_path = pathlib.Path(line[0].strip(' ')) 28 | label_path = pathlib.Path(line[1].strip(' ')) 29 | if img_path.exists() and img_path.stat().st_size > 0 and label_path.exists() and label_path.stat().st_size > 0: 30 | train_data.append((str(img_path), str(label_path))) 31 | train_data_list.append(train_data) 32 | return train_data_list 33 | 34 | 35 | def get_dataset(data_list, module_name, transform, dataset_args): 36 | """ 37 | Build a training dataset 38 | :param data_list: list of (img_path, label_path) tuples, as produced by get_datalist 39 | :param module_name: name of the custom dataset class to use; currently only data_loaders.ImageDataset is supported 40 | :param transform: transforms applied to this dataset 41 | :param dataset_args: arguments for module_name 42 | :return: the dataset object 43 | """ 44 | s_dataset = getattr(dataset, module_name)(transform=transform, data_list=data_list, 45 | **dataset_args) 46 | return s_dataset 47 | 48 | 49 | def get_dataloader(module_name, module_args): 50 | train_transforms = transforms.Compose([ 51 | transforms.ColorJitter(brightness=0.5), 52 | transforms.ToTensor() 53 | ]) 54 | 55 | # build the datasets 56 | dataset_args = copy.deepcopy(module_args['dataset']) 57 | train_data_path = dataset_args.pop('train_data_path') 58 | train_data_ratio = dataset_args.pop('train_data_ratio') 59 | dataset_args.pop('val_data_path') 60 | train_data_list = get_datalist(train_data_path, module_args['loader']['validation_split']) 61 | train_dataset_list = [] 62 | for train_data in train_data_list: 63 | train_dataset_list.append(get_dataset(data_list=train_data, 64 | module_name=module_name, 65 | transform=train_transforms, 66 | dataset_args=dataset_args)) 67 | 68 | if len(train_dataset_list) > 1: 69 | train_loader = dataset.Batch_Balanced_Dataset(dataset_list=train_dataset_list, 70 | ratio_list=train_data_ratio, 71 | module_args=module_args, 72 | phase='train') 73 | elif len(train_dataset_list) == 1: 74 | train_loader = DataLoader(dataset=train_dataset_list[0], 75 | batch_size=module_args['loader']['train_batch_size'], 76 | shuffle=module_args['loader']['shuffle'], 77 | num_workers=module_args['loader']['num_workers']) 78 | train_loader.dataset_len = len(train_dataset_list[0]) 79 | else: 80 | raise Exception('no images found') 81 | return train_loader 82 | -------------------------------------------------------------------------------- /data_loader/augment.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:52 3 | # @Author : zhoujun 4 | 5 | import cv2 6 | import numbers 7 | import math 8 | import random 9 | import numpy as np 10 | from skimage.util import random_noise 11 | 12 | 13 | def show_pic(img, bboxes=None, name='pic'): 14 | ''' 15 | Inputs: 16 | img: image array 17 | bboxes: all bounding boxes of the image, as a list of 4-point polygons [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] 18 | name: window title 19 | ''' 20 | show_img = img.copy() 21 | if not isinstance(bboxes, np.ndarray): 22 | bboxes = np.array(bboxes) 23 | for point in bboxes.astype(int): 24 | cv2.line(show_img, tuple(point[0]), tuple(point[1]), (255, 0, 0), 2) 25 | cv2.line(show_img, tuple(point[1]), tuple(point[2]), (255, 0, 0), 2) 26 | cv2.line(show_img, tuple(point[2]), tuple(point[3]), (255, 0, 0), 2) 27 | cv2.line(show_img, tuple(point[3]), tuple(point[0]), (255, 0, 0), 2) 28 | # cv2.namedWindow(name, 0) # 1 would show the image at its original size 29 | # cv2.moveWindow(name, 0, 0) 30 | # cv2.resizeWindow(name, 1200, 800) # display size 31 | cv2.imshow(name, show_img) 32 | 33 | 34 | # all images are read with cv2 35 | class DataAugment(): 36 | def __init__(self): 37 | pass 38 | 39 | def add_noise(self, im: np.ndarray): 40 | """ 41 | Add noise to an image 42 | :param im: image array 43 | :return: noisy image array; random_noise outputs pixels in [0, 1], hence the multiplication by 255 44 | """ 45 | return (random_noise(im, mode='gaussian', clip=True) * 255).astype(im.dtype) 46 | 47 | def random_scale(self, im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray or list) -> tuple: 48 | """ 49 | Pick a random scale from scales and rescale the image and text boxes 50 | :param im: original image 51 | :param text_polys: text boxes 52 | :param scales: scales to choose from 53 | :return: rescaled image and text boxes 54 | """ 55 | tmp_text_polys = text_polys.copy() 56 | rd_scale = float(np.random.choice(scales)) 57 | im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) 58 | tmp_text_polys *= rd_scale 59 | return im, tmp_text_polys 60 | 61 | def random_rotate_img_bbox(self, img, text_polys, degrees: numbers.Number or list or tuple or np.ndarray, 62 | same_size=False): 63 | """ 64 | Pick a random angle from the given range and rotate the image and text boxes 65 | :param img: image 66 | :param text_polys: text boxes 67 | :param degrees: angle, a single number or a sequence 68 | :param same_size: whether to keep the original image size 69 | :return: rotated image and text boxes 70 | """ 71 | if isinstance(degrees, numbers.Number): 72 | if degrees < 0: 73 | raise ValueError("If degrees is a single number, it must be positive.") 74 | degrees = (-degrees, degrees) 75 | elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray): 76 | if len(degrees) != 2: 77 | raise ValueError("If degrees is a sequence, it must be of len 2.") 78 | degrees = degrees 79 | else: 80 | raise Exception('degrees must be a Number, list, tuple or np.ndarray') 81 | # ---------------------- rotate the image ---------------------- 82 | w = img.shape[1] 83 | h = img.shape[0] 84 | angle = np.random.uniform(degrees[0], degrees[1]) 85 | 86 | if same_size: 87 | nw = w 88 | nh = h 89 | else: 90 | # degrees to radians 91 | rangle = np.deg2rad(angle) 92 | # compute the w, h of the rotated image 93 | nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) 94 | nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) 95 | # build the affine matrix 96 | rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1) 97 | # offset from the original image centre to the new image centre 98 | rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) 99 | # update the affine matrix 100 | rot_mat[0, 2] += rot_move[0] 101 | rot_mat[1, 2] += rot_move[1] 102 | # affine transform 103 | rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4) 104 | 105 | # ---------------------- correct the bbox coordinates ----------------------
106 | # rot_mat is the final rotation matrix 107 | # transform the four corner points of each original bbox into the rotated coordinate frame 108 | rot_text_polys = list() 109 | for bbox in text_polys: 110 | point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) 111 | point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) 112 | point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) 113 | point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) 114 | rot_text_polys.append([point1, point2, point3, point4]) 115 | return rot_img, np.array(rot_text_polys, dtype=np.float32) 116 | 117 | def random_crop(self, imgs, img_size): 118 | h, w = imgs[0].shape[0:2] 119 | th, tw = img_size 120 | if w == tw and h == th: 121 | return imgs 122 | 123 | # the label contains text instances: crop around them with some probability 124 | if np.max(imgs[1][:, :, 0]) > 0 and random.random() > 3.0 / 8.0: 125 | # top-left point of the text instances 126 | tl = np.min(np.where(imgs[1][:, :, 0] > 0), axis=1) - img_size 127 | tl[tl < 0] = 0 128 | # bottom-right point of the text instances 129 | br = np.max(np.where(imgs[1][:, :, 0] > 0), axis=1) - img_size 130 | br[br < 0] = 0 131 | # make sure there is enough room to crop when the bottom-right point is picked 132 | br[0] = min(br[0], h - th) 133 | br[1] = min(br[1], w - tw) 134 | for _ in range(50000): 135 | i = random.randint(tl[0], br[0]) 136 | j = random.randint(tl[1], br[1]) 137 | # make sure the smallest map contains text 138 | if imgs[1][:, :, -1][i:i + th, j:j + tw].sum() <= 0: 139 | continue 140 | else: 141 | break 142 | i = random.randint(tl[0], br[0]) 143 | j = random.randint(tl[1], br[1]) 144 | else: 145 | i = random.randint(0, h - th) 146 | j = random.randint(0, w - tw) 147 | 148 | # return i, j, th, tw 149 | for idx in range(len(imgs)): 150 | if len(imgs[idx].shape) == 3: 151 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :] 152 | else: 153 | imgs[idx] = imgs[idx][i:i + th, j:j + tw] 154 | return imgs 155 | 156 | def resize(self, im: np.ndarray, text_polys: np.ndarray, 157 | input_size: numbers.Number or list or tuple or np.ndarray, keep_ratio: bool = False) -> tuple: 158 | """ 159 | Resize the image and text boxes 160 | :param im: image 161 | :param text_polys: text boxes 162 | :param input_size: target size, a single number or a sequence [w, h] 163 | :param keep_ratio: whether to keep the aspect ratio 164 | :return: resized image and text boxes 165 | """ 166 | if isinstance(input_size, numbers.Number): 167 | if input_size < 0: 168 | raise ValueError("If input_size is a single number, it must be positive.") 169 | input_size = (input_size, input_size) 170 | elif isinstance(input_size, list) or isinstance(input_size, tuple) or isinstance(input_size, np.ndarray): 171 | if len(input_size) != 2: 172 | raise ValueError("If input_size is a sequence, it must be of len 2.") 173 | input_size = (input_size[0], input_size[1]) 174 | else: 175 | raise Exception('input_size must be a Number, list, tuple or np.ndarray') 176 | if keep_ratio: 177 | # pad the short side of the image to match the long side 178 | h, w, c = im.shape 179 | max_h = max(h, input_size[0]) 180 | max_w = max(w, input_size[1]) 181 | im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8) 182 | im_padded[:h, :w] = im.copy() 183 | im = im_padded 184 | text_polys = text_polys.astype(np.float32) 185 | h, w, _ = im.shape 186 | im = cv2.resize(im, input_size) 187 | w_scale = input_size[0] / float(w) 188 | h_scale = input_size[1] / float(h) 189 | text_polys[:, :, 0] *= w_scale 190 | text_polys[:, :, 1] *= h_scale 191 | return im, text_polys 192 | 193 | def horizontal_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple: 194 | """ 195 | Horizontally flip the image and text boxes 196 | :param im: image 197 | :param text_polys: text boxes 198 | :return: horizontally flipped image and text boxes 199 | """ 200 | flip_text_polys = text_polys.copy()
201 | flip_im = cv2.flip(im, 1) 202 | h, w, _ = flip_im.shape 203 | flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0] 204 | return flip_im, flip_text_polys 205 | 206 | def vertical_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple: 207 | """ 208 | Vertically flip the image and text boxes 209 | :param im: image 210 | :param text_polys: text boxes 211 | :return: vertically flipped image and text boxes 212 | """ 213 | flip_text_polys = text_polys.copy() 214 | flip_im = cv2.flip(im, 0) 215 | h, w, _ = flip_im.shape 216 | flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1] 217 | return flip_im, flip_text_polys 218 | 219 | def test(self, im: np.ndarray, text_polys: np.ndarray): 220 | print('random scale') 221 | t_im, t_text_polys = self.random_scale(im, text_polys, [0.5, 1, 2, 3]) 222 | print(t_im.shape, t_text_polys.dtype) 223 | show_pic(t_im, t_text_polys, 'random_scale') 224 | 225 | print('random rotate') 226 | t_im, t_text_polys = self.random_rotate_img_bbox(im, text_polys, 10) 227 | print(t_im.shape, t_text_polys.dtype) 228 | show_pic(t_im, t_text_polys, 'random_rotate_img_bbox') 229 | 230 | # random crop is not exercised here: random_crop() operates on an [img, score_map, training_mask] list, 231 | # and the random_crop_img_bboxes helper this smoke test used to call no longer exists 232 | 233 | 234 | 235 | print('horizontal flip') 236 | t_im, t_text_polys = self.horizontal_flip(im, text_polys) 237 | print(t_im.shape, t_text_polys.dtype) 238 | show_pic(t_im, t_text_polys, 'horizontal_flip') 239 | 240 | print('vertical flip') 241 | t_im, t_text_polys = self.vertical_flip(im, text_polys) 242 | print(t_im.shape, t_text_polys.dtype) 243 | show_pic(t_im, t_text_polys, 'vertical_flip') 244 | show_pic(im, text_polys, 'vertical_flip_ori') 245 | 246 | print('add noise') 247 | t_im = self.add_noise(im) 248 | print(t_im.shape) 249 | show_pic(t_im, text_polys, 'add_noise') 250 | show_pic(im, text_polys, 'add_noise_ori') 251 | -------------------------------------------------------------------------------- /data_loader/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:53 3 | # @Author : zhoujun 4 | import math 5 | import random 6 | import pyclipper 7 | import numpy as np 8 | import cv2 9 | from data_loader.augment import DataAugment 10 | 11 | data_aug = DataAugment() 12 | 13 | 14 | def check_and_validate_polys(polys, im_size): 15 | ''' 16 | check that the text polys are in the same direction, 17 | and filter out some invalid polygons 18 | :param polys: text polygons 19 | :param im_size: (h, w) of the image 20 | :return: 21 | ''' 22 | (h, w) = im_size 23 | if polys.shape[0] == 0: 24 | return polys 25 | polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) # x coord not max w-1, and not min 0 26 | polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) # y coord not max h-1, and not min 0 27 | 28 | validated_polys = [] 29 | for poly in polys: 30 | p_area = cv2.contourArea(poly) 31 | if abs(p_area) < 1: 32 | continue 33 | validated_polys.append(poly) 34 | return np.array(validated_polys) 35 | 36 | def unshrink_offset(poly, ratio): 37 | area = cv2.contourArea(poly) 38 | peri = cv2.arcLength(poly, True) 39 | a = 8 40 | b = peri - 4 41 | c = 1 - 0.5 * peri - area / ratio 42 | return quadratic(a, b, c) 43 | 44 | def quadratic(a, b, c): 45 | if (b * b - 4 * a * c) < 0: 46 | return None 47 | delta = math.sqrt(b * b - 4 * a * c) 48 | if delta > 0: 49 | x = (- b + delta) / (2 * a) 50 | y = (- b - delta) / (2 * a) 51 | return x, y 52 | else: 53 | x = (- b) / (2 * a) 54 | return x 55 | 56 | def generate_rbox(im_size, text_polys,
text_tags, training_mask, shrink_ratio): 57 | """ 58 | Generate the mask: white pixels are text, black pixels are background 59 | :param im_size: h, w of the image 60 | :param text_polys: box coordinates 61 | :param text_tags: whether each box takes part in training (False = DO NOT CARE) 62 | :param training_mask: mask that zeros out the boxes labelled DO NOT CARE 63 | :return: the generated mask 64 | """ 65 | h, w = im_size 66 | score_map = np.zeros((h, w), dtype=np.uint8) 67 | for i, (poly, tag) in enumerate(zip(text_polys, text_tags)): 68 | try: 69 | poly = poly.astype(int) 70 | # d_i = cv2.contourArea(poly) * (1 - shrink_ratio * shrink_ratio) / cv2.arcLength(poly, True) 71 | d_i = cv2.contourArea(poly) * (1 - shrink_ratio) / cv2.arcLength(poly, True) + 0.5 72 | pco = pyclipper.PyclipperOffset() 73 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 74 | shrinked_poly = np.array(pco.Execute(-d_i)) 75 | cv2.fillPoly(score_map, shrinked_poly, i + 1) 76 | if not tag: 77 | cv2.fillPoly(training_mask, shrinked_poly, 0) 78 | except: 79 | print(poly) 80 | return score_map, training_mask 81 | 82 | 83 | def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int) -> tuple: 84 | # the images are rescaled with ratio {0.5, 1.0, 2.0, 3.0} randomly 85 | im, text_polys = data_aug.random_scale(im, text_polys, scales) 86 | # the images are horizontally fliped and rotated in range [−10◦, 10◦] randomly 87 | if random.random() < 0.5: 88 | im, text_polys = data_aug.horizontal_flip(im, text_polys) 89 | if random.random() < 0.5: 90 | im, text_polys = data_aug.random_rotate_img_bbox(im, text_polys, degrees) 91 | return im, text_polys 92 | 93 | 94 | def image_label(im: np.ndarray, text_polys: np.ndarray, text_tags: list, input_size: int = 640, 95 | shrink_ratio: float = 0.5, degrees: int = 10, 96 | scales: np.ndarray = np.array([0.5, 1, 2.0, 3.0])) -> tuple: 97 | """ 98 | Read an image and generate its labels 99 | :param im: image 100 | :param text_polys: annotated text boxes 101 | :param text_tags: per-box flag: True = use the box for training, False = DO NOT CARE (masked out) 102 | :param input_size: output image size 103 | :param shrink_ratio: gt shrink ratio 104 | :param degrees: random rotation range 105 | :param scales: random rescale factors 106 | :return: 107 | """ 108 | h, w, _ = im.shape 109 | # clip out-of-image coordinates 110 | text_polys = check_and_validate_polys(text_polys, (h, w)) 111 | im, text_polys = augmentation(im, text_polys, scales, degrees) 112 | 113 | h, w, _ = im.shape 114 | short_edge = min(h, w) 115 | if short_edge < input_size: 116 | # make sure the short edge >= input_size 117 | scale = input_size / short_edge 118 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale) 119 | text_polys *= scale 120 | 121 | h, w, _ = im.shape 122 | training_mask = np.ones((h, w), dtype=np.uint8) 123 | score_maps = [] 124 | for i in (1, shrink_ratio): 125 | score_map, training_mask = generate_rbox((h, w), text_polys, text_tags, training_mask, i) 126 | score_maps.append(score_map) 127 | score_maps = np.array(score_maps, dtype=np.float32) 128 | imgs = data_aug.random_crop([im, score_maps.transpose((1, 2, 0)), training_mask], (input_size, input_size)) 129 | return imgs[0], imgs[1].transpose((2, 0, 1)), imgs[2] # im, score_maps, training_mask 130 | 131 | if __name__ == '__main__': 132 | poly = np.array([377, 117, 463, 117, 465, 130, 378, 130]).reshape(-1, 2) 133 | shrink_ratio = 0.5 134 | d_i = cv2.contourArea(poly) * (1 - shrink_ratio) / cv2.arcLength(poly, True) + 0.5 135 | pco = pyclipper.PyclipperOffset() 136 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 137 | shrinked_poly = np.array(pco.Execute(-d_i)) 138 | print(d_i) 139 | print(cv2.contourArea(shrinked_poly.astype(int)) / cv2.contourArea(poly)) 140 | print(unshrink_offset(shrinked_poly, shrink_ratio)) 141 |
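A minimal usage sketch for `image_label`, the main entry point of this module (the image path and polygon below are made-up placeholders): it returns the augmented image, the stacked score maps (index 0 ≈ full text regions, index 1 = shrunk kernels, both filled with per-instance indices 1, 2, ...) and the training mask, all cropped to input_size × input_size.
```python
import cv2
import numpy as np
from data_loader.data_utils import image_label

im = cv2.imread('demo.jpg')  # hypothetical image, HxWx3 BGR
polys = np.array([[[377, 117], [463, 117], [465, 130], [378, 130]]], dtype=np.float32)
tags = [True]  # True = train on this box, False = DO NOT CARE

img, score_maps, mask = image_label(im, polys, tags, input_size=640, shrink_ratio=0.5)
print(img.shape)         # (640, 640, 3)
print(score_maps.shape)  # (2, 640, 640)
print(mask.shape)        # (640, 640)
```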
-------------------------------------------------------------------------------- /data_loader/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:54 3 | # @Author : zhoujun 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils.data import Dataset, DataLoader 8 | from data_loader.data_utils import image_label 9 | from utils import order_points_clockwise 10 | import torch  # torch.cat is used by Batch_Balanced_Dataset below; previously torch was only imported under __main__ 11 | 12 | class ImageDataset(Dataset): 13 | def __init__(self, data_list: list, input_size: int, img_channel: int, shrink_ratio: float, transform=None, 14 | target_transform=None): 15 | self.data_list = self.load_data(data_list) 16 | self.input_size = input_size 17 | self.img_channel = img_channel 18 | self.transform = transform 19 | self.target_transform = target_transform 20 | self.shrink_ratio = shrink_ratio 21 | 22 | def __getitem__(self, index): 23 | img_path, text_polys, text_tags = self.data_list[index] 24 | im = cv2.imread(img_path, 1 if self.img_channel == 3 else 0) 25 | if self.img_channel == 3: 26 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 27 | img, score_map, training_mask = image_label(im, text_polys, text_tags, self.input_size, 28 | self.shrink_ratio) 29 | # img = draw_bbox(img,text_polys) 30 | img = Image.fromarray(img) 31 | if self.transform: 32 | img = self.transform(img) 33 | if self.target_transform: 34 | score_map = self.target_transform(score_map) 35 | training_mask = self.target_transform(training_mask) 36 | return img, score_map, training_mask 37 | 38 | def load_data(self, data_list: list) -> list: 39 | t_data_list = [] 40 | for img_path, label_path in data_list: 41 | bboxs, text_tags = self._get_annotation(label_path) 42 | if len(bboxs) > 0: 43 | t_data_list.append((img_path, bboxs, text_tags)) 44 | else: 45 | print('there is no suitable bbox in {}'.format(label_path)) 46 | return t_data_list 47 | 48 | def _get_annotation(self, label_path: str) -> tuple: 49 | boxes = [] 50 | text_tags = [] 51 | with open(label_path, encoding='utf-8', mode='r') as f: 52 | for line in f.readlines(): 53 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',') 54 | try: 55 | box = order_points_clockwise(np.array(list(map(float, params[:8]))).reshape(-1, 2)) 56 | if cv2.arcLength(box, True) > 0: 57 | boxes.append(box) 58 | label = params[8] 59 | if label == '*' or label == '###': 60 | text_tags.append(False) 61 | else: 62 | text_tags.append(True) 63 | except: 64 | print('load label failed on {}'.format(label_path)) 65 | return np.array(boxes, dtype=np.float32), np.array(text_tags, dtype=bool) 66 | 67 | def __len__(self): 68 | return len(self.data_list) 69 | 70 | 71 | class Batch_Balanced_Dataset(object): 72 | def __init__(self, dataset_list: list, ratio_list: list, module_args: dict, 73 | phase: str = 'train'): 74 | """ 75 | Combine the datasets in dataset_list according to the ratios in ratio_list, so that each batch is sampled with those proportions 76 | :param dataset_list: list of datasets 77 | :param ratio_list: list of sampling ratios 78 | :param module_args: dataloader config 79 | :param phase: train or validation 80 | """ 81 | assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list) 82 | 83 | self.dataset_len = 0 84 | self.data_loader_list = [] 85 | self.dataloader_iter_list = [] 86 | all_batch_size = module_args['loader']['train_batch_size'] if phase == 'train' else module_args['loader'][ 87 | 'val_batch_size'] 88 | for _dataset, batch_ratio_d in zip(dataset_list, ratio_list): 89 | _batch_size = max(round(all_batch_size * float(batch_ratio_d)), 1) 90 | 91 | _data_loader =
DataLoader(dataset=_dataset, 92 | batch_size=_batch_size, 93 | shuffle=module_args['loader']['shuffle'], 94 | num_workers=module_args['loader']['num_workers']) 95 | 96 | self.data_loader_list.append(_data_loader) 97 | self.dataloader_iter_list.append(iter(_data_loader)) 98 | self.dataset_len += len(_dataset) 99 | 100 | def __iter__(self): 101 | return self 102 | 103 | def __len__(self): 104 | return min([len(x) for x in self.data_loader_list]) 105 | 106 | def __next__(self): 107 | balanced_batch_images = [] 108 | balanced_batch_score_maps = [] 109 | balanced_batch_training_masks = [] 110 | 111 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 112 | try: 113 | image, score_map, training_mask = next(data_loader_iter) 114 | balanced_batch_images.append(image) 115 | balanced_batch_score_maps.append(score_map) 116 | balanced_batch_training_masks.append(training_mask) 117 | except StopIteration: 118 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 119 | image, score_map, training_mask = next(self.dataloader_iter_list[i]) 120 | balanced_batch_images.append(image) 121 | balanced_batch_score_maps.append(score_map) 122 | balanced_batch_training_masks.append(training_mask) 123 | except ValueError: 124 | pass 125 | 126 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 127 | balanced_batch_score_maps = torch.cat(balanced_batch_score_maps, 0) 128 | balanced_batch_training_masks = torch.cat(balanced_batch_training_masks, 0) 129 | return balanced_batch_images, balanced_batch_score_maps, balanced_batch_training_masks 130 | 131 | 132 | if __name__ == '__main__': 133 | import torch 134 | from utils.util import show_img 135 | from tqdm import tqdm 136 | import matplotlib.pyplot as plt 137 | from torchvision import transforms 138 | 139 | train_data = ImageDataset( 140 | data_list=[ 141 | (r'/data1/zj/ocr/icdar2015/train/img/img_713.jpg', '/data1/zj/ocr/icdar2015/train/gt/gt_img_713.txt')], 142 | input_size=640, 143 | img_channel=3, 144 | shrink_ratio=0.5, 145 | transform=transforms.ToTensor() 146 | ) 147 | train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, num_workers=0) 148 | 149 | pbar = tqdm(total=len(train_loader)) 150 | for i, (img, label, mask) in enumerate(train_loader): 151 | print(label.shape, label[0][0].max()) 152 | print(img.shape) 153 | print(label[0][-1].sum()) 154 | print(mask[0].shape) 155 | # pbar.update(1) 156 | show_img((img[0] * mask[0].to(torch.float)).numpy().transpose(1, 2, 0), color=True) 157 | show_img(label[0]) 158 | show_img(mask[0]) 159 | plt.show() 160 | 161 | pbar.close() 162 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/6/11 15:54 3 | # @Author : zhoujun 4 | import os 5 | import cv2 6 | import torch 7 | import shutil 8 | import numpy as np 9 | from tqdm.auto import tqdm 10 | from predict import Pytorch_model 11 | from utils import cal_recall_precison_f1, draw_bbox 12 | 13 | torch.backends.cudnn.benchmark = True 14 | 15 | 16 | def main(model_path, img_folder, save_path, gpu_id): 17 | if os.path.exists(save_path): 18 | shutil.rmtree(save_path, ignore_errors=True) 19 | if not os.path.exists(save_path): 20 | os.makedirs(save_path) 21 | save_img_folder = os.path.join(save_path, 'img') 22 | if not os.path.exists(save_img_folder): 23 | os.makedirs(save_img_folder) 24 | save_txt_folder = os.path.join(save_path, 'result') 25 | if not 
os.path.exists(save_txt_folder): 26 | os.makedirs(save_txt_folder) 27 | img_paths = [os.path.join(img_folder, x) for x in os.listdir(img_folder)] 28 | model = Pytorch_model(model_path, gpu_id=gpu_id) 29 | total_frame = 0.0 30 | total_time = 0.0 31 | for img_path in tqdm(img_paths): 32 | img_name = os.path.basename(img_path).split('.')[0] 33 | save_name = os.path.join(save_txt_folder, 'res_' + img_name + '.txt') 34 | _, boxes_list, t = model.predict(img_path) 35 | total_frame += 1 36 | total_time += t 37 | img = draw_bbox(img_path, boxes_list, color=(0, 0, 255)) 38 | cv2.imwrite(os.path.join(save_img_folder, '{}.jpg'.format(img_name)), img) 39 | np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d') 40 | print('fps:{}'.format(total_frame / total_time)) 41 | return save_txt_folder 42 | 43 | 44 | if __name__ == '__main__': 45 | os.environ['CUDA_VISIBLE_DEVICES'] = str('0') 46 | model_path = r'output/PAN_shufflenetv2_FPEM_FFM.pth' 47 | img_path = r'/mnt/e/zj/dataset/icdar2015/test/img' 48 | gt_path = r'/mnt/e/zj/dataset/icdar2015/test/gt' 49 | save_path = './output/result'#model_path.replace('checkpoint/best_model.pth', 'result/') 50 | gpu_id = 0 51 | 52 | save_path = main(model_path, img_path, save_path, gpu_id=gpu_id) 53 | result = cal_recall_precison_f1(gt_path=gt_path, result_path=save_path) 54 | print(result) 55 | -------------------------------------------------------------------------------- /imgs/example/img_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_10.jpg -------------------------------------------------------------------------------- /imgs/example/img_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_2.jpg -------------------------------------------------------------------------------- /imgs/example/img_29.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_29.jpg -------------------------------------------------------------------------------- /imgs/example/img_75.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_75.jpg -------------------------------------------------------------------------------- /imgs/example/img_91.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_91.jpg -------------------------------------------------------------------------------- /imgs/paper/PAN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/paper/PAN.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:55 3 | # @Author : zhoujun 4 | from .model import Model 5 | from .loss import PANLoss 6 | 7 | 8 | def 
8 | def get_model(config): 9 | model_config = config['arch']['args'] 10 | return Model(model_config) 11 | 12 | def get_loss(config): 13 | alpha = config['loss']['args']['alpha'] 14 | beta = config['loss']['args']['beta'] 15 | delta_agg = config['loss']['args']['delta_agg'] 16 | delta_dis = config['loss']['args']['delta_dis'] 17 | ohem_ratio = config['loss']['args']['ohem_ratio'] 18 | return PANLoss(alpha=alpha, beta=beta, delta_agg=delta_agg, delta_dis=delta_dis, ohem_ratio=ohem_ratio) 19 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:56 3 | # @Author : zhoujun 4 | import itertools 5 | import torch 6 | from torch import nn 7 | import numpy as np 8 | 9 | 10 | class PANLoss(nn.Module): 11 | def __init__(self, alpha=0.5, beta=0.25, delta_agg=0.5, delta_dis=3, ohem_ratio=3, reduction='mean'): 12 | """ 13 | Implements the PAN loss. 14 | :param alpha: coefficient of the kernel loss 15 | :param beta: coefficient of the aggregation and discrimination losses 16 | :param delta_agg: margin constant used in the aggregation loss 17 | :param delta_dis: margin constant used in the discrimination loss 18 | :param ohem_ratio: ratio of negatives to positives kept by OHEM 19 | :param reduction: 'mean' or 'sum', how the per-sample losses are reduced over the batch 20 | """ 21 | super().__init__() 22 | assert reduction in ['mean', 'sum'], "reduction must be in ['mean', 'sum']" 23 | self.alpha = alpha 24 | self.beta = beta 25 | self.delta_agg = delta_agg 26 | self.delta_dis = delta_dis 27 | self.ohem_ratio = ohem_ratio 28 | self.reduction = reduction 29 | 30 | def forward(self, outputs, labels, training_masks): 31 | texts = outputs[:, 0, :, :] 32 | kernels = outputs[:, 1, :, :] 33 | gt_texts = labels[:, 0, :, :] 34 | gt_kernels = labels[:, 1, :, :] 35 | 36 | 37 | # compute the aggregation and discrimination losses 38 | similarity_vectors = outputs[:, 2:, :, :] 39 | loss_aggs, loss_diss = self.agg_dis_loss(texts, kernels, gt_texts, gt_kernels, similarity_vectors) 40 | 41 | # compute the text loss on OHEM-selected pixels 42 | selected_masks = self.ohem_batch(texts, gt_texts, training_masks) 43 | selected_masks = selected_masks.to(outputs.device) 44 | 45 | loss_texts = self.dice_loss(texts, gt_texts, selected_masks) 46 | 47 | # compute the kernel loss on pixels predicted as text 48 | # selected_masks = ((gt_texts > 0.5) & (training_masks > 0.5)).float() 49 | mask0 = torch.sigmoid(texts).detach().cpu().numpy() 50 | mask1 = training_masks.data.cpu().numpy() 51 | selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32') 52 | selected_masks = torch.from_numpy(selected_masks).float().to(texts.device) 53 | loss_kernels = self.dice_loss(kernels, gt_kernels, selected_masks) 54 | 55 | # mean or sum 56 | if self.reduction == 'mean': 57 | loss_text = loss_texts.mean() 58 | loss_kernel = loss_kernels.mean() 59 | loss_agg = loss_aggs.mean() 60 | loss_dis = loss_diss.mean() 61 | elif self.reduction == 'sum': 62 | loss_text = loss_texts.sum() 63 | loss_kernel = loss_kernels.sum() 64 | loss_agg = loss_aggs.sum() 65 | loss_dis = loss_diss.sum() 66 | else: 67 | raise NotImplementedError 68 | 69 | loss_all = loss_text + self.alpha * loss_kernel + self.beta * (loss_agg + loss_dis) 70 | return loss_all, loss_text, loss_kernel, loss_agg, loss_dis 71 | 
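    # Loss composition, per the PAN formulation as mirrored by forward above:
    #   L = L_text + alpha * L_kernel + beta * (L_agg + L_dis)
    # L_agg pulls each text pixel's similarity vector F(p) towards the mean
    # vector G(K_i) of its own kernel; L_dis pushes the mean vectors of
    # different kernels at least delta_dis apart:
    #   L_agg = mean_i mean_{p in T_i} ln(1 + max(||F(p) - G(K_i)|| - delta_agg, 0)^2)
    #   L_dis = mean_{i != j} ln(1 + max(delta_dis - ||G(K_i) - G(K_j)||, 0)^2)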
72 | def agg_dis_loss(self, texts, kernels, gt_texts, gt_kernels, similarity_vectors): 73 | """ 74 | Compute the aggregation and discrimination losses. 75 | :param texts: predicted text regions, batch_size * (w*h) 76 | :param kernels: predicted shrunk kernels, batch_size * (w*h) 77 | :param gt_texts: ground truth of the text instances, batch_size * (w*h) 78 | :param gt_kernels: ground truth of the shrunk kernels, batch_size * (w*h) 79 | :param similarity_vectors: predicted similarity vectors, batch_size * 4 * (w*h) 80 | :return: 81 | """ 82 | batch_size = texts.size()[0] 83 | texts = texts.contiguous().reshape(batch_size, -1) 84 | kernels = kernels.contiguous().reshape(batch_size, -1) 85 | gt_texts = gt_texts.contiguous().reshape(batch_size, -1) 86 | gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1) 87 | similarity_vectors = similarity_vectors.contiguous().view(batch_size, 4, -1) 88 | loss_aggs = [] 89 | loss_diss = [] 90 | for text_i, kernel_i, gt_text_i, gt_kernel_i, similarity_vector in zip(texts, kernels, gt_texts, gt_kernels, 91 | similarity_vectors): 92 | text_num = gt_text_i.max().item() + 1 93 | loss_agg_single_sample = [] 94 | G_kernel_list = [] # store each computed G(K_i) for the discrimination loss 95 | # aggregation loss of every text instance in this sample 96 | for text_idx in range(1, int(text_num)): 97 | # compute D(p, K_i) 98 | single_kernel_mask = gt_kernel_i == text_idx 99 | if single_kernel_mask.sum() == 0 or (gt_text_i == text_idx).sum() == 0: 100 | # this text instance was cropped out of the image 101 | continue 102 | # G_Ki, shape: 4 103 | G_kernel = similarity_vector[:, single_kernel_mask].mean(1) # 4 104 | G_kernel_list.append(G_kernel) 105 | # similarity vectors F(p) of the text pixels, shape: 4 * nums (num of text pixels) 106 | text_similarity_vector = similarity_vector[:, gt_text_i == text_idx] 107 | # ||F(p) - G(K_i)|| - delta_agg, shape: nums 108 | text_G_ki = (text_similarity_vector - G_kernel.reshape(4, 1)).norm(2, dim=0) - self.delta_agg 109 | # D(p,K_i), shape: nums 110 | D_text_kernel = torch.max(text_G_ki, torch.tensor(0, device=text_G_ki.device, dtype=torch.float)).pow(2) 111 | # loss of a single text instance 112 | loss_agg_single_text = torch.log(D_text_kernel + 1).mean() 113 | loss_agg_single_sample.append(loss_agg_single_text) 114 | if len(loss_agg_single_sample) > 0: 115 | loss_agg_single_sample = torch.stack(loss_agg_single_sample).mean() 116 | else: 117 | loss_agg_single_sample = torch.tensor(0, device=texts.device, dtype=torch.float) 118 | loss_aggs.append(loss_agg_single_sample) 119 | 120 | # discrimination loss between every pair of kernels in this sample 121 | loss_dis_single_sample = 0 122 | for G_kernel_i, G_kernel_j in itertools.combinations(G_kernel_list, 2): 123 | # delta_dis - ||G(K_i) - G(K_j)|| 124 | kernel_ij = self.delta_dis - (G_kernel_i - G_kernel_j).norm(2) 125 | # D(K_i,K_j) 126 | D_kernel_ij = torch.max(kernel_ij, torch.tensor(0, device=kernel_ij.device, dtype=torch.float)).pow(2) 127 | loss_dis_single_sample += torch.log(D_kernel_ij + 1) 128 | if len(G_kernel_list) > 1: 129 | loss_dis_single_sample /= (len(G_kernel_list) * (len(G_kernel_list) - 1)) 130 | else: 131 | loss_dis_single_sample = torch.tensor(0, device=texts.device, dtype=torch.float) 132 | loss_diss.append(loss_dis_single_sample) 133 | return torch.stack(loss_aggs), torch.stack(loss_diss) 134 | 135 | def dice_loss(self, input, target, mask): 136 | input = torch.sigmoid(input) 137 | target[target <= 0.5] = 0 138 | target[target > 0.5] = 1 139 | input = input.contiguous().view(input.size()[0], -1) 140 | target = target.contiguous().view(target.size()[0], -1) 141 | mask = mask.contiguous().view(mask.size()[0], -1) 142 | 143 | input = input * mask 144 | target = target * mask 145 | 146 | a = torch.sum(input * target, 1) 147 | b = torch.sum(input * input, 1) + 0.001 148 | c = torch.sum(target * target, 1) + 0.001 149 | d = (2 * a) / (b + c) 150 | return 1 - d 151 | 
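    # OHEM (online hard example mining), used above to pick pixels for the
    # text loss: keep every positive pixel, then keep only the hardest
    # negatives (those with the highest scores), at most ohem_ratio times as
    # many as the positives; all other pixels are masked out of the dice loss.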
152 | def ohem_single(self, score, gt_text, training_mask): 153 | pos_num = int(np.sum(gt_text > 0.5)) - int(np.sum((gt_text > 0.5) & (training_mask <= 0.5))) 154 | 155 | if pos_num == 0: 156 | # selected_mask = gt_text.copy() * 0 # may be not good 157 | selected_mask = training_mask 158 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 159 | return selected_mask 160 | 161 | neg_num = int(np.sum(gt_text <= 0.5)) 162 | neg_num = int(min(pos_num * self.ohem_ratio, neg_num)) 163 | 164 | if neg_num == 0: 165 | selected_mask = training_mask 166 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 167 | return selected_mask 168 | 169 | neg_score = score[gt_text <= 0.5] 170 | neg_score_sorted = np.sort(-neg_score) 171 | threshold = -neg_score_sorted[neg_num - 1] 172 | selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5) 173 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 174 | return selected_mask 175 | 176 | def ohem_batch(self, scores, gt_texts, training_masks): 177 | scores = scores.data.cpu().numpy() 178 | gt_texts = gt_texts.data.cpu().numpy() 179 | training_masks = training_masks.data.cpu().numpy() 180 | 181 | selected_masks = [] 182 | for i in range(scores.shape[0]): 183 | selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])) 184 | 185 | selected_masks = np.concatenate(selected_masks, 0) 186 | selected_masks = torch.from_numpy(selected_masks).float() 187 | 188 | return selected_masks 189 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:57 3 | # @Author : zhoujun 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from models.modules import * 9 | 10 | backbone_dict = {'resnet18': {'models': resnet18, 'out': [64, 128, 256, 512]}, 11 | 'resnet34': {'models': resnet34, 'out': [64, 128, 256, 512]}, 12 | 'resnet50': {'models': resnet50, 'out': [256, 512, 1024, 2048]}, 13 | 'resnet101': {'models': resnet101, 'out': [256, 512, 1024, 2048]}, 14 | 'resnet152': {'models': resnet152, 'out': [256, 512, 1024, 2048]}, 15 | 'resnext50_32x4d': {'models': resnext50_32x4d, 'out': [256, 512, 1024, 2048]}, 16 | 'resnext101_32x8d': {'models': resnext101_32x8d, 'out': [256, 512, 1024, 2048]}, 17 | 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]} 18 | } 19 | 20 | segmentation_head_dict = {'FPN': FPN, 'FPEM_FFM': FPEM_FFM} 21 | 22 | 23 | # 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]}, 24 | # 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]}, 25 | # 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}} 26 | 27 | 28 | class Model(nn.Module): 29 | def __init__(self, model_config: dict): 30 | """ 31 | PANnet 32 | :param model_config: model configuration dict 33 | """ 34 | super().__init__() 35 | backbone = model_config['backbone'] 36 | pretrained = model_config['pretrained'] 37 | segmentation_head = model_config['segmentation_head'] 38 | 39 | assert backbone in backbone_dict, 'backbone must be in: {}'.format(backbone_dict) 40 | assert segmentation_head in segmentation_head_dict, 'segmentation_head must be in: {}'.format( 41 | segmentation_head_dict) 42 | 43 | backbone_model, backbone_out = backbone_dict[backbone]['models'], backbone_dict[backbone]['out'] 44 | self.backbone = backbone_model(pretrained=pretrained) 45 | self.segmentation_head = segmentation_head_dict[segmentation_head](backbone_out, **model_config) 46 | self.name = '{}_{}'.format(backbone, 
segmentation_head) 47 | 48 | def forward(self, x): 49 | _, _, H, W = x.size() 50 | backbone_out = self.backbone(x) 51 | segmentation_head_out = self.segmentation_head(backbone_out) 52 | y = F.interpolate(segmentation_head_out, size=(H, W), mode='bilinear', align_corners=True) 53 | return y 54 | 55 | 56 | if __name__ == '__main__': 57 | device = torch.device('cpu') 58 | x = torch.zeros(1, 3, 640, 640).to(device) 59 | 60 | model_config = { 61 | 'backbone': 'shufflenetv2', 62 | 'fpem_repeat': 4, # number of stacked FPEM modules 63 | 'pretrained': True, # whether the backbone loads ImageNet-pretrained weights 64 | 'result_num': 7, 65 | 'segmentation_head': 'FPEM_FFM' # segmentation head, FPN or FPEM_FFM 66 | } 67 | model = Model(model_config=model_config).to(device) 68 | y = model(x) 69 | print(y.shape) 70 | # print(model) 71 | # torch.save(model.state_dict(), 'PAN.pth') 72 | -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:54 3 | # @Author : zhoujun 4 | from .resnet import * 5 | from .shufflenetv2 import * 6 | from .segmentation_head import FPEM_FFM, FPN -------------------------------------------------------------------------------- /models/modules/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/8/23 21:55 3 | # @Author : zhoujun 4 | import torch.nn as nn 5 | from torchvision.models.utils import load_state_dict_from_url 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 17 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=dilation, groups=groups, bias=False, dilation=dilation) 25 | 26 | 27 | def conv1x1(in_planes, out_planes, stride=1): 28 | """1x1 convolution""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 36 | base_width=64, dilation=1, norm_layer=None): 37 | super(BasicBlock, self).__init__() 38 | if norm_layer is None: 39 | norm_layer = nn.BatchNorm2d 40 | if groups != 1 or base_width != 64: 41 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 42 | if dilation > 1: 43 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 44 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 45 | self.conv1 = conv3x3(inplanes, planes, stride) 46 | self.bn1 = norm_layer(planes) 47 | self.relu = nn.ReLU(inplace=True) 48 | self.conv2 = conv3x3(planes, planes) 49 | self.bn2 = 
norm_layer(planes) 50 | self.downsample = downsample 51 | self.stride = stride 52 | 53 | def forward(self, x): 54 | identity = x 55 | 56 | out = self.conv1(x) 57 | out = self.bn1(out) 58 | out = self.relu(out) 59 | 60 | out = self.conv2(out) 61 | out = self.bn2(out) 62 | 63 | if self.downsample is not None: 64 | identity = self.downsample(x) 65 | 66 | out += identity 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 76 | base_width=64, dilation=1, norm_layer=None): 77 | super(Bottleneck, self).__init__() 78 | if norm_layer is None: 79 | norm_layer = nn.BatchNorm2d 80 | width = int(planes * (base_width / 64.)) * groups 81 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 82 | self.conv1 = conv1x1(inplanes, width) 83 | self.bn1 = norm_layer(width) 84 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 85 | self.bn2 = norm_layer(width) 86 | self.conv3 = conv1x1(width, planes * self.expansion) 87 | self.bn3 = norm_layer(planes * self.expansion) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | 92 | def forward(self, x): 93 | identity = x 94 | 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv3(out) 104 | out = self.bn3(out) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | 109 | out += identity 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class ResNet(nn.Module): 116 | 117 | def __init__(self, block, layers, zero_init_residual=False, 118 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 119 | norm_layer=None): 120 | super(ResNet, self).__init__() 121 | if norm_layer is None: 122 | norm_layer = nn.BatchNorm2d 123 | self._norm_layer = norm_layer 124 | 125 | self.inplanes = 64 126 | self.dilation = 1 127 | if replace_stride_with_dilation is None: 128 | # each element in the tuple indicates if we should replace 129 | # the 2x2 stride with a dilated convolution instead 130 | replace_stride_with_dilation = [False, False, False] 131 | if len(replace_stride_with_dilation) != 3: 132 | raise ValueError("replace_stride_with_dilation should be None " 133 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 134 | self.groups = groups 135 | self.base_width = width_per_group 136 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 137 | bias=False) 138 | self.bn1 = norm_layer(self.inplanes) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 141 | self.layer1 = self._make_layer(block, 64, layers[0]) 142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 143 | dilate=replace_stride_with_dilation[0]) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 145 | dilate=replace_stride_with_dilation[1]) 146 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 147 | dilate=replace_stride_with_dilation[2]) 148 | 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 152 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 153 | nn.init.constant_(m.weight, 1) 154 | nn.init.constant_(m.bias, 0) 155 | 156 | # Zero-initialize the last BN in 
each residual branch, 157 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 158 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 159 | if zero_init_residual: 160 | for m in self.modules(): 161 | if isinstance(m, Bottleneck): 162 | nn.init.constant_(m.bn3.weight, 0) 163 | elif isinstance(m, BasicBlock): 164 | nn.init.constant_(m.bn2.weight, 0) 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 167 | norm_layer = self._norm_layer 168 | downsample = None 169 | previous_dilation = self.dilation 170 | if dilate: 171 | self.dilation *= stride 172 | stride = 1 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | conv1x1(self.inplanes, planes * block.expansion, stride), 176 | norm_layer(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 181 | self.base_width, previous_dilation, norm_layer)) 182 | self.inplanes = planes * block.expansion 183 | for _ in range(1, blocks): 184 | layers.append(block(self.inplanes, planes, groups=self.groups, 185 | base_width=self.base_width, dilation=self.dilation, 186 | norm_layer=norm_layer)) 187 | 188 | return nn.Sequential(*layers) 189 | 190 | def forward(self, x): 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | x = self.maxpool(x) 195 | 196 | c2 = self.layer1(x) 197 | c3 = self.layer2(c2) 198 | c4 = self.layer3(c3) 199 | c5 = self.layer4(c4) 200 | 201 | return c2, c3, c4, c5 202 | 203 | 204 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 205 | model = ResNet(block, layers, **kwargs) 206 | if pretrained: 207 | state_dict = load_state_dict_from_url(model_urls[arch], 208 | progress=progress) 209 | model.load_state_dict(state_dict, strict=False) 210 | print('load pretrained models from imagenet') 211 | return model 212 | 213 | 214 | def resnet18(pretrained=False, progress=True, **kwargs): 215 | """Constructs a ResNet-18 model. 216 | 217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | progress (bool): If True, displays a progress bar of the download to stderr 220 | """ 221 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 222 | **kwargs) 223 | 224 | 225 | def resnet34(pretrained=False, progress=True, **kwargs): 226 | """Constructs a ResNet-34 model. 227 | 228 | Args: 229 | pretrained (bool): If True, returns a model pre-trained on ImageNet 230 | progress (bool): If True, displays a progress bar of the download to stderr 231 | """ 232 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 233 | **kwargs) 234 | 235 | 236 | def resnet50(pretrained=False, progress=True, **kwargs): 237 | """Constructs a ResNet-50 model. 238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 244 | **kwargs) 245 | 246 | 247 | def resnet101(pretrained=False, progress=True, **kwargs): 248 | """Constructs a ResNet-101 model. 
249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet152(pretrained=False, progress=True, **kwargs): 259 | """Constructs a ResNet-152 model. 260 | 261 | Args: 262 | pretrained (bool): If True, returns a model pre-trained on ImageNet 263 | progress (bool): If True, displays a progress bar of the download to stderr 264 | """ 265 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 266 | **kwargs) 267 | 268 | 269 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 270 | """Constructs a ResNeXt-50 32x4d model. 271 | 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | kwargs['groups'] = 32 277 | kwargs['width_per_group'] = 4 278 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 279 | pretrained, progress, **kwargs) 280 | 281 | 282 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 283 | """Constructs a ResNeXt-101 32x8d model. 284 | 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | progress (bool): If True, displays a progress bar of the download to stderr 288 | """ 289 | kwargs['groups'] = 32 290 | kwargs['width_per_group'] = 8 291 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 292 | pretrained, progress, **kwargs) 293 | 294 | if __name__ == '__main__': 295 | import torch 296 | x = torch.zeros(1, 3, 640, 640) 297 | net = resnext101_32x8d(pretrained=False) 298 | y = net(x) 299 | for u in y: 300 | print(u.shape) -------------------------------------------------------------------------------- /models/modules/segmentation_head.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/9/13 10:29 3 | # @Author : zhoujun 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class FPN(nn.Module): 10 | def __init__(self, backbone_out_channels, **kwargs): 11 | """ 12 | :param backbone_out_channels: output channel counts of the backbone stages 13 | :param kwargs: 14 | """ 15 | super().__init__() 16 | result_num = kwargs.get('result_num', 6) 17 | inplace = True 18 | conv_out = 256 19 | # reduce layers 20 | self.reduce_conv_c2 = nn.Sequential( 21 | nn.Conv2d(backbone_out_channels[0], conv_out, kernel_size=1, stride=1, padding=0), 22 | nn.BatchNorm2d(conv_out), 23 | nn.ReLU(inplace=inplace) 24 | ) 25 | self.reduce_conv_c3 = nn.Sequential( 26 | nn.Conv2d(backbone_out_channels[1], conv_out, kernel_size=1, stride=1, padding=0), 27 | nn.BatchNorm2d(conv_out), 28 | nn.ReLU(inplace=inplace) 29 | ) 30 | self.reduce_conv_c4 = nn.Sequential( 31 | nn.Conv2d(backbone_out_channels[2], conv_out, kernel_size=1, stride=1, padding=0), 32 | nn.BatchNorm2d(conv_out), 33 | nn.ReLU(inplace=inplace) 34 | ) 35 | 36 | self.reduce_conv_c5 = nn.Sequential( 37 | nn.Conv2d(backbone_out_channels[3], conv_out, kernel_size=1, stride=1, padding=0), 38 | nn.BatchNorm2d(conv_out), 39 | nn.ReLU(inplace=inplace) 40 | ) 41 | # Smooth layers 42 | self.smooth_p4 = nn.Sequential( 43 | nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1), 44 | nn.BatchNorm2d(conv_out), 45 | nn.ReLU(inplace=inplace) 46 | ) 47 | self.smooth_p3 = nn.Sequential( 48 | 
nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1), 49 | nn.BatchNorm2d(conv_out), 50 | nn.ReLU(inplace=inplace) 51 | ) 52 | self.smooth_p2 = nn.Sequential( 53 | nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1), 54 | nn.BatchNorm2d(conv_out), 55 | nn.ReLU(inplace=inplace) 56 | ) 57 | 58 | self.conv = nn.Sequential( 59 | nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1), 60 | nn.BatchNorm2d(conv_out), 61 | nn.ReLU(inplace=inplace) 62 | ) 63 | self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1) 64 | 65 | def forward(self, x): 66 | c2, c3, c4, c5 = x 67 | # Top-down 68 | p5 = self.reduce_conv_c5(c5) 69 | p4 = self._upsample_add(p5, self.reduce_conv_c4(c4)) 70 | p4 = self.smooth_p4(p4) 71 | p3 = self._upsample_add(p4, self.reduce_conv_c3(c3)) 72 | p3 = self.smooth_p3(p3) 73 | p2 = self._upsample_add(p3, self.reduce_conv_c2(c2)) 74 | p2 = self.smooth_p2(p2) 75 | 76 | x = self._upsample_cat(p2, p3, p4, p5) 77 | x = self.conv(x) 78 | x = self.out_conv(x) 79 | return x 80 | 81 | def _upsample_add(self, x, y): 82 | return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y 83 | 84 | def _upsample_cat(self, p2, p3, p4, p5): 85 | h, w = p2.size()[2:] 86 | p3 = F.interpolate(p3, size=(h, w), mode='bilinear') 87 | p4 = F.interpolate(p4, size=(h, w), mode='bilinear') 88 | p5 = F.interpolate(p5, size=(h, w), mode='bilinear') 89 | return torch.cat([p2, p3, p4, p5], dim=1) 90 | 91 | 92 | class FPEM_FFM(nn.Module): 93 | def __init__(self, backbone_out_channels, **kwargs): 94 | """ 95 | PANnet 96 | :param backbone_out_channels: output channel counts of the backbone stages 97 | """ 98 | super().__init__() 99 | fpem_repeat = kwargs.get('fpem_repeat', 2) 100 | conv_out = 128 101 | # reduce layers 102 | self.reduce_conv_c2 = nn.Sequential( 103 | nn.Conv2d(in_channels=backbone_out_channels[0], out_channels=conv_out, kernel_size=1), 104 | nn.BatchNorm2d(conv_out), 105 | nn.ReLU() 106 | ) 107 | self.reduce_conv_c3 = nn.Sequential( 108 | nn.Conv2d(in_channels=backbone_out_channels[1], out_channels=conv_out, kernel_size=1), 109 | nn.BatchNorm2d(conv_out), 110 | nn.ReLU() 111 | ) 112 | self.reduce_conv_c4 = nn.Sequential( 113 | nn.Conv2d(in_channels=backbone_out_channels[2], out_channels=conv_out, kernel_size=1), 114 | nn.BatchNorm2d(conv_out), 115 | nn.ReLU() 116 | ) 117 | self.reduce_conv_c5 = nn.Sequential( 118 | nn.Conv2d(in_channels=backbone_out_channels[3], out_channels=conv_out, kernel_size=1), 119 | nn.BatchNorm2d(conv_out), 120 | nn.ReLU() 121 | ) 122 | self.fpems = nn.ModuleList() 123 | for i in range(fpem_repeat): 124 | self.fpems.append(FPEM(conv_out)) 125 | self.out_conv = nn.Conv2d(in_channels=conv_out * 4, out_channels=6, kernel_size=1) 126 | 127 | def forward(self, x): 128 | c2, c3, c4, c5 = x 129 | # reduce channel 130 | c2 = self.reduce_conv_c2(c2) 131 | c3 = self.reduce_conv_c3(c3) 132 | c4 = self.reduce_conv_c4(c4) 133 | c5 = self.reduce_conv_c5(c5) 134 | 135 | # FPEM 136 | for i, fpem in enumerate(self.fpems): 137 | c2, c3, c4, c5 = fpem(c2, c3, c4, c5) 138 | if i == 0: 139 | c2_ffm = c2 140 | c3_ffm = c3 141 | c4_ffm = c4 142 | c5_ffm = c5 143 | else: 144 | c2_ffm += c2 145 | c3_ffm += c3 146 | c4_ffm += c4 147 | c5_ffm += c5 148 | 149 | # FFM 150 | c5 = F.interpolate(c5_ffm, c2_ffm.size()[-2:], mode='bilinear') 151 | c4 = F.interpolate(c4_ffm, c2_ffm.size()[-2:], mode='bilinear') 152 | c3 = F.interpolate(c3_ffm, c2_ffm.size()[-2:], mode='bilinear') 153 | Fy = torch.cat([c2_ffm, c3, c4, c5], dim=1) 154 | y = self.out_conv(Fy) 155 | return y 156 | 157 | 
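# FPEM (Feature Pyramid Enhancement Module): one enhancement pass over the four
# backbone feature maps. The up-scale phase refines c4, c3, c2 top-down via
# upsample-and-add followed by a stride-1 separable conv; the down-scale phase
# then refines c3, c4, c5 bottom-up with stride-2 separable convs. FPEM_FFM
# above stacks fpem_repeat of these passes, sums their outputs (the FFM step),
# upsamples everything to c2 resolution, and concatenates before the final
# 1x1 conv.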
158 | class FPEM(nn.Module): 159 | def __init__(self, in_channels=128): 160 | super().__init__() 161 | self.up_add1 = SeparableConv2d(in_channels, in_channels, 1) 162 | self.up_add2 = SeparableConv2d(in_channels, in_channels, 1) 163 | self.up_add3 = SeparableConv2d(in_channels, in_channels, 1) 164 | self.down_add1 = SeparableConv2d(in_channels, in_channels, 2) 165 | self.down_add2 = SeparableConv2d(in_channels, in_channels, 2) 166 | self.down_add3 = SeparableConv2d(in_channels, in_channels, 2) 167 | 168 | def forward(self, c2, c3, c4, c5): 169 | # up-scale enhancement (top-down) 170 | c4 = self.up_add1(self._upsample_add(c5, c4)) 171 | c3 = self.up_add2(self._upsample_add(c4, c3)) 172 | c2 = self.up_add3(self._upsample_add(c3, c2)) 173 | 174 | # down-scale enhancement (bottom-up) 175 | c3 = self.down_add1(self._upsample_add(c3, c2)) 176 | c4 = self.down_add2(self._upsample_add(c4, c3)) 177 | c5 = self.down_add3(self._upsample_add(c5, c4)) 178 | return c2, c3, c4, c5 179 | 180 | def _upsample_add(self, x, y): 181 | return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y 182 | 183 | 184 | class SeparableConv2d(nn.Module): 185 | def __init__(self, in_channels, out_channels, stride=1): 186 | super(SeparableConv2d, self).__init__() 187 | 188 | self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1, 189 | stride=stride, groups=in_channels) 190 | self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) 191 | self.bn = nn.BatchNorm2d(out_channels) 192 | self.relu = nn.ReLU() 193 | 194 | def forward(self, x): 195 | x = self.depthwise_conv(x) 196 | x = self.pointwise_conv(x) 197 | x = self.bn(x) 198 | x = self.relu(x) 199 | return x 200 | -------------------------------------------------------------------------------- /models/modules/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/11/1 15:31 3 | # @Author : zhoujun 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.models.utils import load_state_dict_from_url 8 | 9 | __all__ = [ 10 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 11 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' 12 | ] 13 | 14 | model_urls = { 15 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 16 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 17 | 'shufflenetv2_x1.5': None, 18 | 'shufflenetv2_x2.0': None, 19 | } 20 | 21 | 22 | def channel_shuffle(x, groups): 23 | batchsize, num_channels, height, width = x.data.size() 24 | channels_per_group = num_channels // groups 25 | 26 | # reshape 27 | x = x.view(batchsize, groups, 28 | channels_per_group, height, width) 29 | 30 | x = torch.transpose(x, 1, 2).contiguous() 31 | 32 | # flatten 33 | x = x.view(batchsize, -1, height, width) 34 | 35 | return x 36 | 37 | 
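# channel_shuffle above interleaves channels across groups: reshape to
# (batch, groups, channels_per_group, h, w), swap the group and channel axes,
# and flatten back. With 4 channels and groups=2, [a1, a2, b1, b2] becomes
# [a1, b1, a2, b2], so information can cross the two branches of each block.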
38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride): 40 | super(InvertedResidual, self).__init__() 41 | 42 | if not (1 <= stride <= 3): 43 | raise ValueError('illegal stride value') 44 | self.stride = stride 45 | 46 | branch_features = oup // 2 47 | assert (self.stride != 1) or (inp == branch_features << 1) 48 | 49 | if self.stride > 1: 50 | self.branch1 = nn.Sequential( 51 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 52 | nn.BatchNorm2d(inp), 53 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 54 | nn.BatchNorm2d(branch_features), 55 | nn.ReLU(inplace=True), 56 | ) 57 | 58 | self.branch2 = nn.Sequential( 59 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 60 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 61 | nn.BatchNorm2d(branch_features), 62 | nn.ReLU(inplace=True), 63 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 64 | nn.BatchNorm2d(branch_features), 65 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 66 | nn.BatchNorm2d(branch_features), 67 | nn.ReLU(inplace=True), 68 | ) 69 | 70 | @staticmethod 71 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 72 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 73 | 74 | def forward(self, x): 75 | if self.stride == 1: 76 | x1, x2 = x.chunk(2, dim=1) 77 | out = torch.cat((x1, self.branch2(x2)), dim=1) 78 | else: 79 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 80 | 81 | out = channel_shuffle(out, 2) 82 | 83 | return out 84 | 85 | 86 | class ShuffleNetV2(nn.Module): 87 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000): 88 | super(ShuffleNetV2, self).__init__() 89 | 90 | if len(stages_repeats) != 3: 91 | raise ValueError('expected stages_repeats as list of 3 positive ints') 92 | if len(stages_out_channels) != 5: 93 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 94 | self._stage_out_channels = stages_out_channels 95 | 96 | input_channels = 3 97 | output_channels = self._stage_out_channels[0] 98 | self.conv1 = nn.Sequential( 99 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 100 | nn.BatchNorm2d(output_channels), 101 | nn.ReLU(inplace=True), 102 | ) 103 | input_channels = output_channels 104 | 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | 107 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 108 | for name, repeats, output_channels in zip( 109 | stage_names, stages_repeats, self._stage_out_channels[1:]): 110 | seq = [InvertedResidual(input_channels, output_channels, 2)] 111 | for i in range(repeats - 1): 112 | seq.append(InvertedResidual(output_channels, output_channels, 1)) 113 | setattr(self, name, nn.Sequential(*seq)) 114 | input_channels = output_channels 115 | 116 | output_channels = self._stage_out_channels[-1] 117 | self.conv5 = nn.Sequential( 118 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 119 | nn.BatchNorm2d(output_channels), 120 | nn.ReLU(inplace=True), 121 | ) 122 | 123 | def forward(self, x): 124 | x = self.conv1(x) 125 | c2 = self.maxpool(x) 126 | c3 = self.stage2(c2) 127 | c4 = self.stage3(c3) 128 | c5 = self.stage4(c4) 129 | # c5 = self.conv5(c5) 130 | return c2, c3, c4, c5 131 | 132 | 133 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 134 | model = ShuffleNetV2(*args, **kwargs) 135 | 136 | if pretrained: 137 | model_url = model_urls[arch] 138 | if model_url is None: 139 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 140 | else: 141 | state_dict = load_state_dict_from_url(model_url, progress=progress) 142 | model.load_state_dict(state_dict, strict=False) 143 | 144 | return model 145 | 146 | 147 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 148 | """ 149 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 150 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 151 | <https://arxiv.org/abs/1807.11164>`_. 
152 | 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | progress (bool): If True, displays a progress bar of the download to stderr 156 | """ 157 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 158 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 159 | 160 | 161 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 162 | """ 163 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 164 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 165 | <https://arxiv.org/abs/1807.11164>`_. 166 | 167 | Args: 168 | pretrained (bool): If True, returns a model pre-trained on ImageNet 169 | progress (bool): If True, displays a progress bar of the download to stderr 170 | """ 171 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 172 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 173 | 174 | 175 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): 176 | """ 177 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 178 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 179 | <https://arxiv.org/abs/1807.11164>`_. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | progress (bool): If True, displays a progress bar of the download to stderr 184 | """ 185 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, 186 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 187 | 188 | 189 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): 190 | """ 191 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 192 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 193 | <https://arxiv.org/abs/1807.11164>`_. 194 | 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | progress (bool): If True, displays a progress bar of the download to stderr 198 | """ 199 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, 200 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 201 | -------------------------------------------------------------------------------- /post_processing/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags) 2 | LDFLAGS = $(shell python3-config --ldflags) 3 | 4 | DEPS = $(shell find include -xtype f) 5 | CXX_SOURCES = pse.cpp 6 | 7 | LIB_SO = pse.so 8 | 9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC 11 | 12 | clean: 13 | rm -rf $(LIB_SO) 14 | -------------------------------------------------------------------------------- /post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/9/8 14:18 3 | # @Author : zhoujun 4 | import os 5 | import cv2 6 | import torch 7 | import time 8 | import subprocess 9 | import numpy as np 10 | 11 | from .pypse import pse_py 12 | from .kmeans import km 13 | 14 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 15 | 16 | if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value 17 | raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR)) 18 | 19 | 20 | def decode(preds, scale=1, threshold=0.7311, min_area=5): 21 | """ 22 | Apply sigmoid to the text and kernel channels of the network output to get confidences, then threshold them to separate text from background. 23 | :param preds: network output 24 | :param scale: scale factor of the output relative to the input image 25 | :param threshold: confidence threshold applied after the sigmoid 26 | :return: the final label map and the text boxes 27 | """ 
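    # How decode proceeds, roughly: threshold the text and kernel maps (the
    # default 0.7311 equals sigmoid(1.0), i.e. a logit threshold of 1), take
    # connected components of the kernel map as seeds, grow each seed over
    # neighbouring text pixels whose similarity vectors lie close to the
    # kernel's mean vector (pse_cpp, with 0.8 as the distance threshold), then
    # fit a minAreaRect around every surviving region.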
28 | from .pse import pse_cpp, get_points, get_num 29 | preds[:2, :, :] = torch.sigmoid(preds[:2, :, :]) 30 | preds = preds.detach().cpu().numpy() 31 | score = preds[0].astype(np.float32) 32 | text = preds[0] > threshold # text 33 | kernel = (preds[1] > threshold) * text # kernel 34 | similarity_vectors = preds[2:].transpose((1, 2, 0)) 35 | 36 | label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4) 37 | label_values = [] 38 | label_sum = get_num(label, label_num) 39 | for label_idx in range(1, label_num): 40 | if label_sum[label_idx] < min_area: 41 | continue 42 | label_values.append(label_idx) 43 | 44 | pred = pse_cpp(text.astype(np.uint8), similarity_vectors, label, label_num, 0.8) 45 | pred = pred.reshape(text.shape) 46 | 47 | bbox_list = [] 48 | label_points = get_points(pred, score, label_num) 49 | for label_value, label_point in label_points.items(): 50 | if label_value not in label_values: 51 | continue 52 | score_i = label_point[0] 53 | label_point = label_point[2:] 54 | points = np.array(label_point, dtype=int).reshape(-1, 2) 55 | 56 | if points.shape[0] < 100 / (scale * scale): 57 | continue 58 | 59 | if score_i < 0.93: 60 | continue 61 | 62 | rect = cv2.minAreaRect(points) 63 | bbox = cv2.boxPoints(rect) 64 | bbox_list.append([bbox[1], bbox[2], bbox[3], bbox[0]]) 65 | return pred, np.array(bbox_list) 66 | 67 | 68 | def decode_dice(preds, scale=1, threshold=0.7311, min_area=5): 69 | import pyclipper 70 | preds[:2, :, :] = torch.sigmoid(preds[:2, :, :]) 71 | preds = preds.detach().cpu().numpy() 72 | text = preds[0] > threshold # text 73 | kernel = (preds[1] > threshold) * text # kernel 74 | 75 | label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4) 76 | bbox_list = [] 77 | for label_idx in range(1, label_num): 78 | points = np.array(np.where(label == label_idx)).transpose((1, 0))[:, ::-1] 79 | 80 | rect = cv2.minAreaRect(points) 81 | poly = cv2.boxPoints(rect).astype(int) 82 | 83 | d_i = cv2.contourArea(poly) * 1.5 / cv2.arcLength(poly, True) 84 | pco = pyclipper.PyclipperOffset() 85 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 86 | shrinked_poly = np.array(pco.Execute(-d_i)) 87 | 88 | if cv2.contourArea(shrinked_poly) < 800 / (scale * scale): 89 | continue 90 | 91 | bbox_list.append([shrinked_poly[1], shrinked_poly[2], shrinked_poly[3], shrinked_poly[0]]) 92 | return label, np.array(bbox_list) 93 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/buffer_info.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/buffer_info.h: Python buffer object interface 3 | 4 | Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of entries between adjacent entries (for each per dimension) 25 | 26 | buffer_info() { } 27 | 28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 29 | detail::any_container shape_in, detail::any_container strides_in) 30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 31 | shape(std::move(shape_in)), strides(std::move(strides_in)) { 32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 34 | for (size_t i = 0; i < (size_t) ndim; ++i) 35 | size *= shape[i]; 36 | } 37 | 38 | template 39 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) 40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } 41 | 42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) 43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } 44 | 45 | template 46 | buffer_info(T *ptr, ssize_t size) 47 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } 48 | 49 | explicit buffer_info(Py_buffer *view, bool ownview = true) 50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { 52 | this->view = view; 53 | this->ownview = ownview; 54 | } 55 | 56 | buffer_info(const buffer_info &) = delete; 57 | buffer_info& operator=(const buffer_info &) = delete; 58 | 59 | buffer_info(buffer_info &&other) { 60 | (*this) = std::move(other); 61 | } 62 | 63 | buffer_info& operator=(buffer_info &&rhs) { 64 | ptr = rhs.ptr; 65 | itemsize = rhs.itemsize; 66 | size = rhs.size; 67 | format = std::move(rhs.format); 68 | ndim = rhs.ndim; 69 | shape = std::move(rhs.shape); 70 | strides = std::move(rhs.strides); 71 | std::swap(view, rhs.view); 72 | std::swap(ownview, rhs.ownview); 73 | return *this; 74 | } 75 | 76 | ~buffer_info() { 77 | if (view && ownview) { PyBuffer_Release(view); delete view; } 78 | } 79 | 80 | private: 81 | struct private_ctr_tag { }; 82 | 83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 84 | detail::any_container &&shape_in, detail::any_container &&strides_in) 85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } 86 | 87 | Py_buffer *view = nullptr; 88 | bool ownview = false; 89 | }; 90 | 91 | NAMESPACE_BEGIN(detail) 92 | 93 | template struct compare_buffer_info { 94 | static bool compare(const buffer_info& b) { 95 | return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 96 | } 97 | }; 98 | 99 | template struct compare_buffer_info::value>> { 100 | static bool compare(const buffer_info& b) { 101 | return (size_t) 
b.itemsize == sizeof(T) && (b.format == format_descriptor::value || 102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); 104 | } 105 | }; 106 | 107 | NAMESPACE_END(detail) 108 | NAMESPACE_END(PYBIND11_NAMESPACE) 109 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/chrono.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime 3 | 4 | Copyright (c) 2016 Trent Houliston and 5 | Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "pybind11.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required 20 | #ifndef PyDateTime_DELTA_GET_DAYS 21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) 22 | #endif 23 | #ifndef PyDateTime_DELTA_GET_SECONDS 24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) 25 | #endif 26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS 27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) 28 | #endif 29 | 30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template class duration_caster { 34 | public: 35 | typedef typename type::rep rep; 36 | typedef typename type::period period; 37 | 38 | typedef std::chrono::duration> days; 39 | 40 | bool load(handle src, bool) { 41 | using namespace std::chrono; 42 | 43 | // Lazy initialise the PyDateTime import 44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 45 | 46 | if (!src) return false; 47 | // If invoked with datetime.delta object 48 | if (PyDelta_Check(src.ptr())) { 49 | value = type(duration_cast>( 50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr())) 51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) 52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); 53 | return true; 54 | } 55 | // If invoked with a float we assume it is seconds and convert 56 | else if (PyFloat_Check(src.ptr())) { 57 | value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); 58 | return true; 59 | } 60 | else return false; 61 | } 62 | 63 | // If this is a duration just return it back 64 | static const std::chrono::duration& get_duration(const std::chrono::duration &src) { 65 | return src; 66 | } 67 | 68 | // If this is a time_point get the time_since_epoch 69 | template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { 70 | return src.time_since_epoch(); 71 | } 72 | 73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { 74 | using namespace std::chrono; 75 | 76 | // Use overloaded function to get our duration from our source 77 | // Works out if it is a duration or time_point and get the duration 78 | auto d = get_duration(src); 79 | 80 | // Lazy initialise the PyDateTime import 81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 82 | 83 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 84 | using dd_t = duration>; 85 | using ss_t = duration>; 86 | using us_t = duration; 87 | 88 | auto dd = duration_cast(d); 89 | auto subd = d - dd; 90 | auto ss = duration_cast(subd); 
91 | auto us = duration_cast(subd - ss); 92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); 93 | } 94 | 95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); 96 | }; 97 | 98 | // This is for casting times on the system clock into datetime.datetime instances 99 | template class type_caster> { 100 | public: 101 | typedef std::chrono::time_point type; 102 | bool load(handle src, bool) { 103 | using namespace std::chrono; 104 | 105 | // Lazy initialise the PyDateTime import 106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 107 | 108 | if (!src) return false; 109 | if (PyDateTime_Check(src.ptr())) { 110 | std::tm cal; 111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); 112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); 113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); 114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); 115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; 116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 117 | cal.tm_isdst = -1; 118 | 119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); 120 | return true; 121 | } 122 | else return false; 123 | } 124 | 125 | static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { 126 | using namespace std::chrono; 127 | 128 | // Lazy initialise the PyDateTime import 129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 130 | 131 | std::time_t tt = system_clock::to_time_t(src); 132 | // this function uses static memory so it's best to copy it out asap just in case 133 | // otherwise other code that is using localtime may break this (not just python code) 134 | std::tm localtime = *std::localtime(&tt); 135 | 136 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 137 | using us_t = duration; 138 | 139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, 140 | localtime.tm_mon + 1, 141 | localtime.tm_mday, 142 | localtime.tm_hour, 143 | localtime.tm_min, 144 | localtime.tm_sec, 145 | (duration_cast(src.time_since_epoch() % seconds(1))).count()); 146 | } 147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); 148 | }; 149 | 150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects 151 | // since they are not measured on calendar time. So instead we just make them timedeltas 152 | // Or if they have passed us a time as a float we convert that 153 | template class type_caster> 154 | : public duration_caster> { 155 | }; 156 | 157 | template class type_caster> 158 | : public duration_caster> { 159 | }; 160 | 161 | NAMESPACE_END(detail) 162 | NAMESPACE_END(PYBIND11_NAMESPACE) 163 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 3 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/descr.h: Helper type for concatenating type signatures 3 | either at runtime (C++11) or compile time (C++14) 4 | 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 
9 | */ 10 | 11 | #pragma once 12 | 13 | #include "common.h" 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | /* Concatenate type signatures at compile time using C++14 */ 19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER) 20 | #define PYBIND11_CONSTEXPR_DESCR 21 | 22 | template class descr { 23 | template friend class descr; 24 | public: 25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) 26 | : descr(text, types, 27 | make_index_sequence(), 28 | make_index_sequence()) { } 29 | 30 | constexpr const char *text() const { return m_text; } 31 | constexpr const std::type_info * const * types() const { return m_types; } 32 | 33 | template 34 | constexpr descr operator+(const descr &other) const { 35 | return concat(other, 36 | make_index_sequence(), 37 | make_index_sequence(), 38 | make_index_sequence(), 39 | make_index_sequence()); 40 | } 41 | 42 | protected: 43 | template 44 | constexpr descr( 45 | char const (&text) [Size1+1], 46 | const std::type_info * const (&types) [Size2+1], 47 | index_sequence, index_sequence) 48 | : m_text{text[Indices1]..., '\0'}, 49 | m_types{types[Indices2]..., nullptr } {} 50 | 51 | template 53 | constexpr descr 54 | concat(const descr &other, 55 | index_sequence, index_sequence, 56 | index_sequence, index_sequence) const { 57 | return descr( 58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, 59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } 60 | ); 61 | } 62 | 63 | protected: 64 | char m_text[Size1 + 1]; 65 | const std::type_info * m_types[Size2 + 1]; 66 | }; 67 | 68 | template constexpr descr _(char const(&text)[Size]) { 69 | return descr(text, { nullptr }); 70 | } 71 | 72 | template struct int_to_str : int_to_str { }; 73 | template struct int_to_str<0, Digits...> { 74 | static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); 75 | }; 76 | 77 | // Ternary description (like std::conditional) 78 | template 79 | constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { 80 | return _(text1); 81 | } 82 | template 83 | constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { 84 | return _(text2); 85 | } 86 | template 87 | constexpr enable_if_t> _(descr d, descr) { return d; } 88 | template 89 | constexpr enable_if_t> _(descr, descr d) { return d; } 90 | 91 | template auto constexpr _() -> decltype(int_to_str::digits) { 92 | return int_to_str::digits; 93 | } 94 | 95 | template constexpr descr<1, 1> _() { 96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); 97 | } 98 | 99 | inline constexpr descr<0, 0> concat() { return _(""); } 100 | template auto constexpr concat(descr descr) { return descr; } 101 | template auto constexpr concat(descr descr, Args&&... 
args) { return descr + _(", ") + concat(args...); } 102 | template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } 103 | 104 | #define PYBIND11_DESCR constexpr auto 105 | 106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */ 107 | 108 | class descr { 109 | public: 110 | PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { 111 | size_t nChars = len(text), nTypes = len(types); 112 | m_text = new char[nChars]; 113 | m_types = new const std::type_info *[nTypes]; 114 | memcpy(m_text, text, nChars * sizeof(char)); 115 | memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); 116 | } 117 | 118 | PYBIND11_NOINLINE descr operator+(descr &&d2) && { 119 | descr r; 120 | 121 | size_t nChars1 = len(m_text), nTypes1 = len(m_types); 122 | size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); 123 | 124 | r.m_text = new char[nChars1 + nChars2 - 1]; 125 | r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1]; 126 | memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); 127 | memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); 128 | memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); 129 | memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); 130 | 131 | delete[] m_text; delete[] m_types; 132 | delete[] d2.m_text; delete[] d2.m_types; 133 | 134 | return r; 135 | } 136 | 137 | char *text() { return m_text; } 138 | const std::type_info * * types() { return m_types; } 139 | 140 | protected: 141 | PYBIND11_NOINLINE descr() { } 142 | 143 | template static size_t len(const T *ptr) { // return length including null termination 144 | const T *it = ptr; 145 | while (*it++ != (T) 0) 146 | ; 147 | return static_cast(it - ptr); 148 | } 149 | 150 | const std::type_info **m_types = nullptr; 151 | char *m_text = nullptr; 152 | }; 153 | 154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ 155 | 156 | PYBIND11_NOINLINE inline descr _(const char *text) { 157 | const std::type_info *types[1] = { nullptr }; 158 | return descr(text, types); 159 | } 160 | 161 | template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } 162 | template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } 163 | template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } 164 | template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } 165 | 166 | template PYBIND11_NOINLINE descr _() { 167 | const std::type_info *types[2] = { &typeid(Type), nullptr }; 168 | return descr("%", types); 169 | } 170 | 171 | template PYBIND11_NOINLINE descr _() { 172 | const std::type_info *types[1] = { nullptr }; 173 | return descr(std::to_string(Size).c_str(), types); 174 | } 175 | 176 | PYBIND11_NOINLINE inline descr concat() { return _(""); } 177 | PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; } 178 | template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... 
args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } 179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } 180 | 181 | #define PYBIND11_DESCR ::pybind11::detail::descr 182 | #endif 183 | 184 | NAMESPACE_END(detail) 185 | NAMESPACE_END(pybind11) 186 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/detail/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | #if !defined(_MSC_VER) 18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr 19 | #else 20 | # define PYBIND11_DESCR_CONSTEXPR const 21 | #endif 22 | 23 | /* Concatenate type signatures at compile time */ 24 | template 25 | struct descr { 26 | char text[N + 1]; 27 | 28 | constexpr descr() : text{'\0'} { } 29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } 30 | 31 | template 32 | constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } 33 | 34 | template 35 | constexpr descr(char c, Chars... cs) : text{c, static_cast(cs)..., '\0'} { } 36 | 37 | static constexpr std::array types() { 38 | return {{&typeid(Ts)..., nullptr}}; 39 | } 40 | }; 41 | 42 | template 43 | constexpr descr plus_impl(const descr &a, const descr &b, 44 | index_sequence, index_sequence) { 45 | return {a.text[Is1]..., b.text[Is2]...}; 46 | } 47 | 48 | template 49 | constexpr descr operator+(const descr &a, const descr &b) { 50 | return plus_impl(a, b, make_index_sequence(), make_index_sequence()); 51 | } 52 | 53 | template 54 | constexpr descr _(char const(&text)[N]) { return descr(text); } 55 | constexpr descr<0> _(char const(&)[1]) { return {}; } 56 | 57 | template struct int_to_str : int_to_str { }; 58 | template struct int_to_str<0, Digits...> { 59 | static constexpr auto digits = descr(('0' + Digits)...); 60 | }; 61 | 62 | // Ternary description (like std::conditional) 63 | template 64 | constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { 65 | return _(text1); 66 | } 67 | template 68 | constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { 69 | return _(text2); 70 | } 71 | 72 | template 73 | constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } 74 | template 75 | constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } 76 | 77 | template auto constexpr _() -> decltype(int_to_str::digits) { 78 | return int_to_str::digits; 79 | } 80 | 81 | template constexpr descr<1, Type> _() { return {'%'}; } 82 | 83 | constexpr descr<0> concat() { return {}; } 84 | 85 | template 86 | constexpr descr concat(const descr &descr) { return descr; } 87 | 88 | template 89 | constexpr auto concat(const descr &d, const Args &...args) 90 | -> decltype(std::declval>() + concat(args...)) { 91 | return d + _(", ") + concat(args...); 92 | } 93 | 94 | template 95 | constexpr descr type_descr(const descr &descr) { 96 | return _("{") + descr + _("}"); 97 | } 98 | 99 | NAMESPACE_END(detail) 100 | NAMESPACE_END(PYBIND11_NAMESPACE) 101 | -------------------------------------------------------------------------------- 
/post_processing/include/pybind11/detail/internals.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/internals.h: Internal data structure and related functions 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "../pytypes.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | // Forward declarations 17 | inline PyTypeObject *make_static_property_type(); 18 | inline PyTypeObject *make_default_metaclass(); 19 | inline PyObject *make_object_base_type(PyTypeObject *metaclass); 20 | 21 | // The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new 22 | // Thread Specific Storage (TSS) API. 23 | #if PY_VERSION_HEX >= 0x03070000 24 | # define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr 25 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key)) 26 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (tstate)) 27 | # define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr) 28 | #else 29 | // Usually an int but a long on Cygwin64 with Python 3.x 30 | # define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0 31 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key)) 32 | # if PY_MAJOR_VERSION < 3 33 | # define PYBIND11_TLS_DELETE_VALUE(key) \ 34 | PyThread_delete_key_value(key) 35 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \ 36 | do { \ 37 | PyThread_delete_key_value((key)); \ 38 | PyThread_set_key_value((key), (value)); \ 39 | } while (false) 40 | # else 41 | # define PYBIND11_TLS_DELETE_VALUE(key) \ 42 | PyThread_set_key_value((key), nullptr) 43 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \ 44 | PyThread_set_key_value((key), (value)) 45 | # endif 46 | #endif 47 | 48 | // Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly 49 | // other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module 50 | // even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under 51 | // libstdc++, this doesn't happen: equality and the type_index hash are based on the type name, 52 | // which works. If not under a known-good stl, provide our own name-based hash and equality 53 | // functions that use the type name. 
54 | #if defined(__GLIBCXX__) 55 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; } 56 | using type_hash = std::hash; 57 | using type_equal_to = std::equal_to; 58 | #else 59 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { 60 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 61 | } 62 | 63 | struct type_hash { 64 | size_t operator()(const std::type_index &t) const { 65 | size_t hash = 5381; 66 | const char *ptr = t.name(); 67 | while (auto c = static_cast(*ptr++)) 68 | hash = (hash * 33) ^ c; 69 | return hash; 70 | } 71 | }; 72 | 73 | struct type_equal_to { 74 | bool operator()(const std::type_index &lhs, const std::type_index &rhs) const { 75 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 76 | } 77 | }; 78 | #endif 79 | 80 | template 81 | using type_map = std::unordered_map; 82 | 83 | struct overload_hash { 84 | inline size_t operator()(const std::pair& v) const { 85 | size_t value = std::hash()(v.first); 86 | value ^= std::hash()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2); 87 | return value; 88 | } 89 | }; 90 | 91 | /// Internal data structure used to track registered instances and types. 92 | /// Whenever binary incompatible changes are made to this structure, 93 | /// `PYBIND11_INTERNALS_VERSION` must be incremented. 94 | struct internals { 95 | type_map registered_types_cpp; // std::type_index -> pybind11's type information 96 | std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) 97 | std::unordered_multimap registered_instances; // void * -> instance* 98 | std::unordered_set, overload_hash> inactive_overload_cache; 99 | type_map> direct_conversions; 100 | std::unordered_map> patients; 101 | std::forward_list registered_exception_translators; 102 | std::unordered_map shared_data; // Custom data to be shared across extensions 103 | std::vector loader_patient_stack; // Used by `loader_life_support` 104 | std::forward_list static_strings; // Stores the std::strings backing detail::c_str() 105 | PyTypeObject *static_property_type; 106 | PyTypeObject *default_metaclass; 107 | PyObject *instance_base; 108 | #if defined(WITH_THREAD) 109 | PYBIND11_TLS_KEY_INIT(tstate); 110 | PyInterpreterState *istate = nullptr; 111 | #endif 112 | }; 113 | 114 | /// Additional type information which does not fit into the PyTypeObject. 115 | /// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`. 
116 | struct type_info { 117 | PyTypeObject *type; 118 | const std::type_info *cpptype; 119 | size_t type_size, type_align, holder_size_in_ptrs; 120 | void *(*operator_new)(size_t); 121 | void (*init_instance)(instance *, const void *); 122 | void (*dealloc)(value_and_holder &v_h); 123 | std::vector implicit_conversions; 124 | std::vector> implicit_casts; 125 | std::vector *direct_conversions; 126 | buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; 127 | void *get_buffer_data = nullptr; 128 | void *(*module_local_load)(PyObject *, const type_info *) = nullptr; 129 | /* A simple type never occurs as a (direct or indirect) parent 130 | * of a class that makes use of multiple inheritance */ 131 | bool simple_type : 1; 132 | /* True if there is no multiple inheritance in this type's inheritance tree */ 133 | bool simple_ancestors : 1; 134 | /* for base vs derived holder_type checks */ 135 | bool default_holder : 1; 136 | /* true if this is a type registered with py::module_local */ 137 | bool module_local : 1; 138 | }; 139 | 140 | /// Tracks the `internals` and `type_info` ABI version independent of the main library version 141 | #define PYBIND11_INTERNALS_VERSION 3 142 | 143 | #if defined(_DEBUG) 144 | # define PYBIND11_BUILD_TYPE "_debug" 145 | #else 146 | # define PYBIND11_BUILD_TYPE "" 147 | #endif 148 | 149 | #if defined(WITH_THREAD) 150 | # define PYBIND11_INTERNALS_KIND "" 151 | #else 152 | # define PYBIND11_INTERNALS_KIND "_without_thread" 153 | #endif 154 | 155 | #define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ 156 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 157 | 158 | #define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \ 159 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 160 | 161 | /// Each module locally stores a pointer to the `internals` data. The data 162 | /// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`. 163 | inline internals **&get_internals_pp() { 164 | static internals **internals_pp = nullptr; 165 | return internals_pp; 166 | } 167 | 168 | /// Return a reference to the current `internals` data 169 | PYBIND11_NOINLINE inline internals &get_internals() { 170 | auto **&internals_pp = get_internals_pp(); 171 | if (internals_pp && *internals_pp) 172 | return **internals_pp; 173 | 174 | constexpr auto *id = PYBIND11_INTERNALS_ID; 175 | auto builtins = handle(PyEval_GetBuiltins()); 176 | if (builtins.contains(id) && isinstance(builtins[id])) { 177 | internals_pp = static_cast(capsule(builtins[id])); 178 | 179 | // We loaded builtins through python's builtins, which means that our `error_already_set` 180 | // and `builtin_exception` may be different local classes than the ones set up in the 181 | // initial exception translator, below, so add another for our local exception classes. 
182 | // 183 | // libstdc++ doesn't require this (types there are identified only by name) 184 | #if !defined(__GLIBCXX__) 185 | (*internals_pp)->registered_exception_translators.push_front( 186 | [](std::exception_ptr p) -> void { 187 | try { 188 | if (p) std::rethrow_exception(p); 189 | } catch (error_already_set &e) { e.restore(); return; 190 | } catch (const builtin_exception &e) { e.set_error(); return; 191 | } 192 | } 193 | ); 194 | #endif 195 | } else { 196 | if (!internals_pp) internals_pp = new internals*(); 197 | auto *&internals_ptr = *internals_pp; 198 | internals_ptr = new internals(); 199 | #if defined(WITH_THREAD) 200 | PyEval_InitThreads(); 201 | PyThreadState *tstate = PyThreadState_Get(); 202 | #if PY_VERSION_HEX >= 0x03070000 203 | internals_ptr->tstate = PyThread_tss_alloc(); 204 | if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate)) 205 | pybind11_fail("get_internals: could not successfully initialize the TSS key!"); 206 | PyThread_tss_set(internals_ptr->tstate, tstate); 207 | #else 208 | internals_ptr->tstate = PyThread_create_key(); 209 | if (internals_ptr->tstate == -1) 210 | pybind11_fail("get_internals: could not successfully initialize the TLS key!"); 211 | PyThread_set_key_value(internals_ptr->tstate, tstate); 212 | #endif 213 | internals_ptr->istate = tstate->interp; 214 | #endif 215 | builtins[id] = capsule(internals_pp); 216 | internals_ptr->registered_exception_translators.push_front( 217 | [](std::exception_ptr p) -> void { 218 | try { 219 | if (p) std::rethrow_exception(p); 220 | } catch (error_already_set &e) { e.restore(); return; 221 | } catch (const builtin_exception &e) { e.set_error(); return; 222 | } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; 223 | } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 224 | } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 225 | } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 226 | } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; 227 | } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 228 | } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; 229 | } catch (...) { 230 | PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); 231 | return; 232 | } 233 | } 234 | ); 235 | internals_ptr->static_property_type = make_static_property_type(); 236 | internals_ptr->default_metaclass = make_default_metaclass(); 237 | internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass); 238 | } 239 | return **internals_pp; 240 | } 241 | 242 | /// Works like `internals.registered_types_cpp`, but for module-local registered types: 243 | inline type_map ®istered_local_types_cpp() { 244 | static type_map locals{}; 245 | return locals; 246 | } 247 | 248 | /// Constructs a std::string with the given arguments, stores it in `internals`, and returns its 249 | /// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only 250 | /// cleared when the program exits or after interpreter shutdown (when embedding), and so are 251 | /// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name). 
252 | template <typename... Args>
253 | const char *c_str(Args &&...args) {
254 |     auto &strings = get_internals().static_strings;
255 |     strings.emplace_front(std::forward<Args>(args)...);
256 |     return strings.front().c_str();
257 | }
258 |
259 | NAMESPACE_END(detail)
260 |
261 | /// Returns a named pointer that is shared among all extension modules (using the same
262 | /// pybind11 version) running in the current interpreter. Names starting with underscores
263 | /// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
264 | inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
265 |     auto &internals = detail::get_internals();
266 |     auto it = internals.shared_data.find(name);
267 |     return it != internals.shared_data.end() ? it->second : nullptr;
268 | }
269 |
270 | /// Set the shared data that can be later recovered by `get_shared_data()`.
271 | inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
272 |     detail::get_internals().shared_data[name] = data;
273 |     return data;
274 | }
275 |
276 | /// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
277 | /// such entry exists. Otherwise, a new object of default-constructible type `T` is
278 | /// added to the shared data under the given name and a reference to it is returned.
279 | template<typename T>
280 | T &get_or_create_shared_data(const std::string &name) {
281 |     auto &internals = detail::get_internals();
282 |     auto it = internals.shared_data.find(name);
283 |     T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
284 |     if (!ptr) {
285 |         ptr = new T();
286 |         internals.shared_data[name] = ptr;
287 |     }
288 |     return *ptr;
289 | }
290 |
291 | NAMESPACE_END(PYBIND11_NAMESPACE)
292 |
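Editorial note, a hedged illustration (not part of this repository): the `internals` struct above is shared across extension modules through a capsule stashed in the builtins under the versioned id, and this is observable from plain Python. Importing this repo's compiled `pse` module is an assumption; any module built against these headers would do.

    import builtins

    import pse  # any extension module compiled against these headers

    # the capsule registered by get_internals(); "v3" matches PYBIND11_INTERNALS_VERSION above
    print([k for k in vars(builtins) if k.startswith('__pybind11_internals')])
    # e.g. ['__pybind11_internals_v3__']
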
--------------------------------------------------------------------------------
/post_processing/include/pybind11/detail/typeid.h:
--------------------------------------------------------------------------------
1 | /*
2 |     pybind11/detail/typeid.h: Compiler-independent access to type identifiers
3 |
4 |     Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
5 |
6 |     All rights reserved. Use of this source code is governed by a
7 |     BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include <cstdio>
13 | #include <cstdlib>
14 |
15 | #if defined(__GNUG__)
16 | #include <cxxabi.h>
17 | #endif
18 |
19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
20 | NAMESPACE_BEGIN(detail)
21 | /// Erase all occurrences of a substring
22 | inline void erase_all(std::string &string, const std::string &search) {
23 |     for (size_t pos = 0;;) {
24 |         pos = string.find(search, pos);
25 |         if (pos == std::string::npos) break;
26 |         string.erase(pos, search.length());
27 |     }
28 | }
29 |
30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
31 | #if defined(__GNUG__)
32 |     int status = 0;
33 |     std::unique_ptr<char, void (*)(void *)> res {
34 |         abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
35 |     if (status == 0)
36 |         name = res.get();
37 | #else
38 |     detail::erase_all(name, "class ");
39 |     detail::erase_all(name, "struct ");
40 |     detail::erase_all(name, "enum ");
41 | #endif
42 |     detail::erase_all(name, "pybind11::");
43 | }
44 | NAMESPACE_END(detail)
45 |
46 | /// Return a string representation of a C++ type
47 | template <typename T> static std::string type_id() {
48 |     std::string name(typeid(T).name());
49 |     detail::clean_type_id(name);
50 |     return name;
51 | }
52 |
53 | NAMESPACE_END(PYBIND11_NAMESPACE)
54 |
--------------------------------------------------------------------------------
/post_processing/include/pybind11/embed.h:
--------------------------------------------------------------------------------
1 | /*
2 |     pybind11/embed.h: Support for embedding the interpreter
3 |
4 |     Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
5 |
6 |     All rights reserved. Use of this source code is governed by a
7 |     BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "pybind11.h"
13 | #include "eval.h"
14 |
15 | #if defined(PYPY_VERSION)
16 | #  error Embedding the interpreter is not supported with PyPy
17 | #endif
18 |
19 | #if PY_MAJOR_VERSION >= 3
20 | #  define PYBIND11_EMBEDDED_MODULE_IMPL(name)            \
21 |       extern "C" PyObject *pybind11_init_impl_##name() { \
22 |           return pybind11_init_wrapper_##name();         \
23 |       }
24 | #else
25 | #  define PYBIND11_EMBEDDED_MODULE_IMPL(name)            \
26 |       extern "C" void pybind11_init_impl_##name() {      \
27 |           pybind11_init_wrapper_##name();                \
28 |       }
29 | #endif
30 |
31 | /** \rst
32 |     Add a new module to the table of builtins for the interpreter. Must be
33 |     defined in global scope. The first macro parameter is the name of the
34 |     module (without quotes). The second parameter is the variable which will
35 |     be used as the interface to add functions and classes to the module.
36 |
37 |     .. code-block:: cpp
38 |
39 |         PYBIND11_EMBEDDED_MODULE(example, m) {
40 |             // ...
initialize functions and classes here 41 | m.def("foo", []() { 42 | return "Hello, World!"; 43 | }); 44 | } 45 | \endrst */ 46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \ 47 | static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ 48 | static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ 49 | auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ 50 | try { \ 51 | PYBIND11_CONCAT(pybind11_init_, name)(m); \ 52 | return m.ptr(); \ 53 | } catch (pybind11::error_already_set &e) { \ 54 | PyErr_SetString(PyExc_ImportError, e.what()); \ 55 | return nullptr; \ 56 | } catch (const std::exception &e) { \ 57 | PyErr_SetString(PyExc_ImportError, e.what()); \ 58 | return nullptr; \ 59 | } \ 60 | } \ 61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 62 | pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \ 63 | PYBIND11_CONCAT(pybind11_init_impl_, name)); \ 64 | void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) 65 | 66 | 67 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 68 | NAMESPACE_BEGIN(detail) 69 | 70 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. 71 | struct embedded_module { 72 | #if PY_MAJOR_VERSION >= 3 73 | using init_t = PyObject *(*)(); 74 | #else 75 | using init_t = void (*)(); 76 | #endif 77 | embedded_module(const char *name, init_t init) { 78 | if (Py_IsInitialized()) 79 | pybind11_fail("Can't add new modules after the interpreter has been initialized"); 80 | 81 | auto result = PyImport_AppendInittab(name, init); 82 | if (result == -1) 83 | pybind11_fail("Insufficient memory to add a new module"); 84 | } 85 | }; 86 | 87 | NAMESPACE_END(detail) 88 | 89 | /** \rst 90 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be 91 | called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The 92 | optional parameter can be used to skip the registration of signal handlers (see the 93 | `Python documentation`_ for details). Calling this function again after the interpreter 94 | has already been initialized is a fatal error. 95 | 96 | If initializing the Python interpreter fails, then the program is terminated. (This 97 | is controlled by the CPython runtime and is an exception to pybind11's normal behavior 98 | of throwing exceptions on errors.) 99 | 100 | .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx 101 | \endrst */ 102 | inline void initialize_interpreter(bool init_signal_handlers = true) { 103 | if (Py_IsInitialized()) 104 | pybind11_fail("The interpreter is already running"); 105 | 106 | Py_InitializeEx(init_signal_handlers ? 1 : 0); 107 | 108 | // Make .py files in the working directory available by default 109 | module::import("sys").attr("path").cast().append("."); 110 | } 111 | 112 | /** \rst 113 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called 114 | after this. In addition, pybind11 objects must not outlive the interpreter: 115 | 116 | .. 
code-block:: cpp 117 | 118 | { // BAD 119 | py::initialize_interpreter(); 120 | auto hello = py::str("Hello, World!"); 121 | py::finalize_interpreter(); 122 | } // <-- BOOM, hello's destructor is called after interpreter shutdown 123 | 124 | { // GOOD 125 | py::initialize_interpreter(); 126 | { // scoped 127 | auto hello = py::str("Hello, World!"); 128 | } // <-- OK, hello is cleaned up properly 129 | py::finalize_interpreter(); 130 | } 131 | 132 | { // BETTER 133 | py::scoped_interpreter guard{}; 134 | auto hello = py::str("Hello, World!"); 135 | } 136 | 137 | .. warning:: 138 | 139 | The interpreter can be restarted by calling `initialize_interpreter` again. 140 | Modules created using pybind11 can be safely re-initialized. However, Python 141 | itself cannot completely unload binary extension modules and there are several 142 | caveats with regard to interpreter restarting. All the details can be found 143 | in the CPython documentation. In short, not all interpreter memory may be 144 | freed, either due to reference cycles or user-created global data. 145 | 146 | \endrst */ 147 | inline void finalize_interpreter() { 148 | handle builtins(PyEval_GetBuiltins()); 149 | const char *id = PYBIND11_INTERNALS_ID; 150 | 151 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the 152 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` 153 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). 154 | detail::internals **internals_ptr_ptr = detail::get_internals_pp(); 155 | // It could also be stashed in builtins, so look there too: 156 | if (builtins.contains(id) && isinstance(builtins[id])) 157 | internals_ptr_ptr = capsule(builtins[id]); 158 | 159 | Py_Finalize(); 160 | 161 | if (internals_ptr_ptr) { 162 | delete *internals_ptr_ptr; 163 | *internals_ptr_ptr = nullptr; 164 | } 165 | } 166 | 167 | /** \rst 168 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`. 169 | This a move-only guard and only a single instance can exist. 170 | 171 | .. code-block:: cpp 172 | 173 | #include 174 | 175 | int main() { 176 | py::scoped_interpreter guard{}; 177 | py::print(Hello, World!); 178 | } // <-- interpreter shutdown 179 | \endrst */ 180 | class scoped_interpreter { 181 | public: 182 | scoped_interpreter(bool init_signal_handlers = true) { 183 | initialize_interpreter(init_signal_handlers); 184 | } 185 | 186 | scoped_interpreter(const scoped_interpreter &) = delete; 187 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } 188 | scoped_interpreter &operator=(const scoped_interpreter &) = delete; 189 | scoped_interpreter &operator=(scoped_interpreter &&) = delete; 190 | 191 | ~scoped_interpreter() { 192 | if (is_valid) 193 | finalize_interpreter(); 194 | } 195 | 196 | private: 197 | bool is_valid = true; 198 | }; 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern and 6 | Wenzel Jakob 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 
10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal(result); 50 | } 51 | 52 | template 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval(expr, global, local); 62 | } 63 | 64 | template 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval(s, global, local); 67 | } 68 | 69 | template 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal(PyFile_FromString( 91 | const_cast(fname_str.c_str()), 92 | const_cast("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal(result); 115 | } 116 | 117 | NAMESPACE_END(PYBIND11_NAMESPACE) 118 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/functional.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") 79 | + make_caster::name + _("]")); 80 | }; 81 | 82 | NAMESPACE_END(detail) 83 | NAMESPACE_END(PYBIND11_NAMESPACE) 84 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/iostream.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python 3 | 4 | Copyright (c) 2017 Henry F. Schreiner 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | NAMESPACE_BEGIN(detail) 22 | 23 | // Buffer that writes to Python instead of C++ 24 | class pythonbuf : public std::streambuf { 25 | private: 26 | using traits_type = std::streambuf::traits_type; 27 | 28 | char d_buffer[1024]; 29 | object pywrite; 30 | object pyflush; 31 | 32 | int overflow(int c) { 33 | if (!traits_type::eq_int_type(c, traits_type::eof())) { 34 | *pptr() = traits_type::to_char_type(c); 35 | pbump(1); 36 | } 37 | return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof(); 38 | } 39 | 40 | int sync() { 41 | if (pbase() != pptr()) { 42 | // This subtraction cannot be negative, so dropping the sign 43 | str line(pbase(), static_cast(pptr() - pbase())); 44 | 45 | pywrite(line); 46 | pyflush(); 47 | 48 | setp(pbase(), epptr()); 49 | } 50 | return 0; 51 | } 52 | 53 | public: 54 | pythonbuf(object pyostream) 55 | : pywrite(pyostream.attr("write")), 56 | pyflush(pyostream.attr("flush")) { 57 | setp(d_buffer, d_buffer + sizeof(d_buffer) - 1); 58 | } 59 | 60 | /// Sync before destroy 61 | ~pythonbuf() { 62 | sync(); 63 | } 64 | }; 65 | 66 | NAMESPACE_END(detail) 67 | 68 | 69 | /** \rst 70 | This a move-only guard that redirects output. 71 | 72 | .. code-block:: cpp 73 | 74 | #include 75 | 76 | ... 77 | 78 | { 79 | py::scoped_ostream_redirect output; 80 | std::cout << "Hello, World!"; // Python stdout 81 | } // <-- return std::cout to normal 82 | 83 | You can explicitly pass the c++ stream and the python object, 84 | for example to guard stderr instead. 85 | 86 | .. code-block:: cpp 87 | 88 | { 89 | py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; 90 | std::cerr << "Hello, World!"; 91 | } 92 | \endrst */ 93 | class scoped_ostream_redirect { 94 | protected: 95 | std::streambuf *old; 96 | std::ostream &costream; 97 | detail::pythonbuf buffer; 98 | 99 | public: 100 | scoped_ostream_redirect( 101 | std::ostream &costream = std::cout, 102 | object pyostream = module::import("sys").attr("stdout")) 103 | : costream(costream), buffer(pyostream) { 104 | old = costream.rdbuf(&buffer); 105 | } 106 | 107 | ~scoped_ostream_redirect() { 108 | costream.rdbuf(old); 109 | } 110 | 111 | scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; 112 | scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; 113 | scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; 114 | scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; 115 | }; 116 | 117 | 118 | /** \rst 119 | Like `scoped_ostream_redirect`, but redirects cerr by default. This class 120 | is provided primary to make ``py::call_guard`` easier to make. 121 | 122 | .. code-block:: cpp 123 | 124 | m.def("noisy_func", &noisy_func, 125 | py::call_guard()); 127 | 128 | \endrst */ 129 | class scoped_estream_redirect : public scoped_ostream_redirect { 130 | public: 131 | scoped_estream_redirect( 132 | std::ostream &costream = std::cerr, 133 | object pyostream = module::import("sys").attr("stderr")) 134 | : scoped_ostream_redirect(costream,pyostream) {} 135 | }; 136 | 137 | 138 | NAMESPACE_BEGIN(detail) 139 | 140 | // Class to redirect output as a context manager. C++ backend. 
141 | class OstreamRedirect { 142 | bool do_stdout_; 143 | bool do_stderr_; 144 | std::unique_ptr redirect_stdout; 145 | std::unique_ptr redirect_stderr; 146 | 147 | public: 148 | OstreamRedirect(bool do_stdout = true, bool do_stderr = true) 149 | : do_stdout_(do_stdout), do_stderr_(do_stderr) {} 150 | 151 | void enter() { 152 | if (do_stdout_) 153 | redirect_stdout.reset(new scoped_ostream_redirect()); 154 | if (do_stderr_) 155 | redirect_stderr.reset(new scoped_estream_redirect()); 156 | } 157 | 158 | void exit() { 159 | redirect_stdout.reset(); 160 | redirect_stderr.reset(); 161 | } 162 | }; 163 | 164 | NAMESPACE_END(detail) 165 | 166 | /** \rst 167 | This is a helper function to add a C++ redirect context manager to Python 168 | instead of using a C++ guard. To use it, add the following to your binding code: 169 | 170 | .. code-block:: cpp 171 | 172 | #include 173 | 174 | ... 175 | 176 | py::add_ostream_redirect(m, "ostream_redirect"); 177 | 178 | You now have a Python context manager that redirects your output: 179 | 180 | .. code-block:: python 181 | 182 | with m.ostream_redirect(): 183 | m.print_to_cout_function() 184 | 185 | This manager can optionally be told which streams to operate on: 186 | 187 | .. code-block:: python 188 | 189 | with m.ostream_redirect(stdout=true, stderr=true): 190 | m.noisy_function_with_error_printing() 191 | 192 | \endrst */ 193 | inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { 194 | return class_(m, name.c_str(), module_local()) 195 | .def(init(), arg("stdout")=true, arg("stderr")=true) 196 | .def("__enter__", &detail::OstreamRedirect::enter) 197 | .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); 198 | } 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/operators.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/operator.h: Metatemplates for operator overloading 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #if defined(__clang__) && !defined(__INTEL_COMPILER) 15 | # pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) 16 | #elif defined(_MSC_VER) 17 | # pragma warning(push) 18 | # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 19 | #endif 20 | 21 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 22 | NAMESPACE_BEGIN(detail) 23 | 24 | /// Enumeration with all supported operator types 25 | enum op_id : int { 26 | op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, 27 | op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, 28 | op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, 29 | op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, 30 | op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, 31 | op_repr, op_truediv, op_itruediv, op_hash 32 | }; 33 | 34 | enum op_type : int { 35 | op_l, /* base type on left */ 36 | op_r, /* base type on right */ 37 | op_u /* unary operator */ 38 | }; 39 | 40 | struct self_t { }; 41 | static const self_t self = self_t(); 42 | 43 | /// Type for an unused type slot 44 | struct undefined_t { }; 45 | 46 | /// Don't warn about an unused variable 47 | inline self_t __self() { return self; } 48 | 49 | /// base template of operator implementations 50 | template struct op_impl { }; 51 | 52 | /// Operator implementation generator 53 | template struct op_ { 54 | template void execute(Class &cl, const Extra&... extra) const { 55 | using Base = typename Class::type; 56 | using L_type = conditional_t::value, Base, L>; 57 | using R_type = conditional_t::value, Base, R>; 58 | using op = op_impl; 59 | cl.def(op::name(), &op::execute, is_operator(), extra...); 60 | #if PY_MAJOR_VERSION < 3 61 | if (id == op_truediv || id == op_itruediv) 62 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", 63 | &op::execute, is_operator(), extra...); 64 | #endif 65 | } 66 | template void execute_cast(Class &cl, const Extra&... extra) const { 67 | using Base = typename Class::type; 68 | using L_type = conditional_t::value, Base, L>; 69 | using R_type = conditional_t::value, Base, R>; 70 | using op = op_impl; 71 | cl.def(op::name(), &op::execute_cast, is_operator(), extra...); 72 | #if PY_MAJOR_VERSION < 3 73 | if (id == op_truediv || id == op_itruediv) 74 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? 
"__div__" : "__rdiv__", 75 | &op::execute, is_operator(), extra...); 76 | #endif 77 | } 78 | }; 79 | 80 | #define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ 81 | template struct op_impl { \ 82 | static char const* name() { return "__" #id "__"; } \ 83 | static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ 84 | static B execute_cast(const L &l, const R &r) { return B(expr); } \ 85 | }; \ 86 | template struct op_impl { \ 87 | static char const* name() { return "__" #rid "__"; } \ 88 | static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ 89 | static B execute_cast(const R &r, const L &l) { return B(expr); } \ 90 | }; \ 91 | inline op_ op(const self_t &, const self_t &) { \ 92 | return op_(); \ 93 | } \ 94 | template op_ op(const self_t &, const T &) { \ 95 | return op_(); \ 96 | } \ 97 | template op_ op(const T &, const self_t &) { \ 98 | return op_(); \ 99 | } 100 | 101 | #define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ 102 | template struct op_impl { \ 103 | static char const* name() { return "__" #id "__"; } \ 104 | static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ 105 | static B execute_cast(L &l, const R &r) { return B(expr); } \ 106 | }; \ 107 | template op_ op(const self_t &, const T &) { \ 108 | return op_(); \ 109 | } 110 | 111 | #define PYBIND11_UNARY_OPERATOR(id, op, expr) \ 112 | template struct op_impl { \ 113 | static char const* name() { return "__" #id "__"; } \ 114 | static auto execute(const L &l) -> decltype(expr) { return expr; } \ 115 | static B execute_cast(const L &l) { return B(expr); } \ 116 | }; \ 117 | inline op_ op(const self_t &) { \ 118 | return op_(); \ 119 | } 120 | 121 | PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) 122 | PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) 123 | PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) 124 | PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) 125 | PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r) 126 | PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) 127 | PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) 128 | PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) 129 | PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) 130 | PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) 131 | PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) 132 | PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) 133 | PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) 134 | PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) 135 | PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) 136 | PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) 137 | //PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) 138 | PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) 139 | PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) 140 | PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) 141 | PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) 142 | PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) 143 | PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) 144 | PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) 145 | PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) 146 | PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) 147 | PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) 148 | PYBIND11_UNARY_OPERATOR(neg, operator-, -l) 149 | PYBIND11_UNARY_OPERATOR(pos, operator+, +l) 150 | PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) 151 | PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) 
152 | PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) 153 | PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) 154 | PYBIND11_UNARY_OPERATOR(int, int_, (int) l) 155 | PYBIND11_UNARY_OPERATOR(float, float_, (double) l) 156 | 157 | #undef PYBIND11_BINARY_OPERATOR 158 | #undef PYBIND11_INPLACE_OPERATOR 159 | #undef PYBIND11_UNARY_OPERATOR 160 | NAMESPACE_END(detail) 161 | 162 | using detail::self; 163 | 164 | NAMESPACE_END(PYBIND11_NAMESPACE) 165 | 166 | #if defined(_MSC_VER) 167 | # pragma warning(pop) 168 | #endif 169 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /post_processing/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */
9 |
10 | #pragma once
11 |
12 | #include <cstdio>
13 | #include <cstdlib>
14 |
15 | #if defined(__GNUG__)
16 | #include <cxxabi.h>
17 | #endif
18 |
19 | NAMESPACE_BEGIN(pybind11)
20 | NAMESPACE_BEGIN(detail)
21 | /// Erase all occurrences of a substring
22 | inline void erase_all(std::string &string, const std::string &search) {
23 |     for (size_t pos = 0;;) {
24 |         pos = string.find(search, pos);
25 |         if (pos == std::string::npos) break;
26 |         string.erase(pos, search.length());
27 |     }
28 | }
29 |
30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
31 | #if defined(__GNUG__)
32 |     int status = 0;
33 |     std::unique_ptr<char, void (*)(void *)> res {
34 |         abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
35 |     if (status == 0)
36 |         name = res.get();
37 | #else
38 |     detail::erase_all(name, "class ");
39 |     detail::erase_all(name, "struct ");
40 |     detail::erase_all(name, "enum ");
41 | #endif
42 |     detail::erase_all(name, "pybind11::");
43 | }
44 | NAMESPACE_END(detail)
45 |
46 | /// Return a string representation of a C++ type
47 | template <typename T> static std::string type_id() {
48 |     std::string name(typeid(T).name());
49 |     detail::clean_type_id(name);
50 |     return name;
51 | }
52 |
53 | NAMESPACE_END(pybind11)
54 |
--------------------------------------------------------------------------------
/post_processing/kmeans.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/9/11 16:24
3 | # @Author : zhoujun
4 |
5 | import numpy as np
6 | from sklearn.cluster import KMeans
7 |
8 | def km(text, similarity_vectors, label, label_values, dis_threshold=0.8):
9 |     similarity_vectors = similarity_vectors * np.expand_dims(text, 2)
10 |     # compute the cluster centers: the mean embedding of each kernel, plus a zero background center
11 |     cluster_centers = [[0, 0, 0, 0]]
12 |     for i in label_values:
13 |         kernel_idx = label == i
14 |         kernel_similarity_vector = similarity_vectors[kernel_idx].mean(0)  # 4
15 |         cluster_centers.append(kernel_similarity_vector)
16 |     n = len(label_values) + 1
17 |     similarity_vectors = similarity_vectors.reshape(-1, 4)
18 |     y_pred = KMeans(n, init=np.array(cluster_centers), n_init=1).fit_predict(similarity_vectors)
19 |     y_pred = y_pred.reshape(text.shape)
20 |     return y_pred
21 |
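Editorial note, a hedged usage sketch (not part of the repository; the import path and toy shapes are assumptions): km() clusters every pixel's 4-d embedding around the per-kernel mean vectors, with cluster 0 reserved for the background.

    import numpy as np
    from post_processing.kmeans import km  # assumed import path

    text = (np.random.rand(32, 32) > 0.5).astype(np.float32)   # binary text mask, H x W
    similarity_vectors = np.random.rand(32, 32, 4)             # per-pixel 4-d embeddings
    label = text.astype(np.int32)                              # pretend there is one kernel, id 1
    y_pred = km(text, similarity_vectors, label, label_values=[1])
    print(y_pred.shape)  # (32, 32): cluster ids, 0 = background center
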
5 | // 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "include/pybind11/pybind11.h" 12 | #include "include/pybind11/numpy.h" 13 | #include "include/pybind11/stl.h" 14 | #include "include/pybind11/stl_bind.h" 15 | 16 | namespace py = pybind11; 17 | 18 | 19 | namespace pan{ 20 | py::array_t pse( 21 | py::array_t text, 22 | py::array_t similarity_vectors, 23 | py::array_t label_map, 24 | int label_num, 25 | float dis_threshold = 0.8) 26 | { 27 | auto pbuf_text = text.request(); 28 | auto pbuf_similarity_vectors = similarity_vectors.request(); 29 | auto pbuf_label_map = label_map.request(); 30 | if (pbuf_label_map.ndim != 2 || pbuf_label_map.shape[0]==0 || pbuf_label_map.shape[1]==0) 31 | throw std::runtime_error("label map must have a shape of (h>0, w>0)"); 32 | int h = pbuf_label_map.shape[0]; 33 | int w = pbuf_label_map.shape[1]; 34 | if (pbuf_similarity_vectors.ndim != 3 || pbuf_similarity_vectors.shape[0]!=h || pbuf_similarity_vectors.shape[1]!=w || pbuf_similarity_vectors.shape[2]!=4 || 35 | pbuf_text.shape[0]!=h || pbuf_text.shape[1]!=w) 36 | throw std::runtime_error("similarity_vectors must have a shape of (h,w,4) and text must have a shape of (h,w,4)"); 37 | //初始化结果 38 | auto res = py::array_t(pbuf_text.size); 39 | auto pbuf_res = res.request(); 40 | // 获取 text similarity_vectors 和 label_map的指针 41 | auto ptr_label_map = static_cast(pbuf_label_map.ptr); 42 | auto ptr_text = static_cast(pbuf_text.ptr); 43 | auto ptr_similarity_vectors = static_cast(pbuf_similarity_vectors.ptr); 44 | auto ptr_res = static_cast(pbuf_res.ptr); 45 | 46 | std::queue> q; 47 | // 计算各个kernel的similarity_vectors 48 | float kernel_vector[label_num][5] = {0}; 49 | 50 | // 文本像素入队列 51 | for (int i = 0; i0) 60 | { 61 | kernel_vector[label][0] += p_similarity_vectors[k]; 62 | kernel_vector[label][1] += p_similarity_vectors[k+1]; 63 | kernel_vector[label][2] += p_similarity_vectors[k+2]; 64 | kernel_vector[label][3] += p_similarity_vectors[k+3]; 65 | kernel_vector[label][4] += 1; 66 | q.push(std::make_tuple(i, j, label)); 67 | } 68 | p_res[j] = label; 69 | } 70 | } 71 | 72 | for(int i=0;i(q_n); 87 | int x = std::get<1>(q_n); 88 | int32_t l = std::get<2>(q_n); 89 | //store the edge pixel after one expansion 90 | auto kernel_cv = kernel_vector[l]; 91 | for (int idx=0; idx<4; idx++) 92 | { 93 | int tmpy = y + dy[idx]; 94 | int tmpx = x + dx[idx]; 95 | auto p_res = ptr_res + tmpy*w; 96 | if (tmpy<0 || tmpy>=h || tmpx<0 || tmpx>=w) 97 | continue; 98 | if (!ptr_text[tmpy*w+tmpx] || p_res[tmpx]>0) 99 | continue; 100 | // 计算距离 101 | float dis = 0; 102 | auto p_similarity_vectors = ptr_similarity_vectors + tmpy * w*4; 103 | for(size_t i=0;i<4;i++) 104 | { 105 | dis += pow(kernel_cv[i] - p_similarity_vectors[tmpx*4 + i],2); 106 | } 107 | dis = sqrt(dis); 108 | if(dis >= dis_threshold) 109 | continue; 110 | q.push(std::make_tuple(tmpy, tmpx, l)); 111 | p_res[tmpx]=l; 112 | } 113 | } 114 | return res; 115 | } 116 | 117 | std::map> get_points( 118 | py::array_t label_map, 119 | py::array_t score_map, 120 | int label_num) 121 | { 122 | auto pbuf_label_map = label_map.request(); 123 | auto pbuf_score_map = score_map.request(); 124 | auto ptr_label_map = static_cast(pbuf_label_map.ptr); 125 | auto ptr_score_map = static_cast(pbuf_score_map.ptr); 126 | int h = pbuf_label_map.shape[0]; 127 | int w = pbuf_label_map.shape[1]; 128 | 129 | std::map> point_dict; 130 | std::vector> point_vector; 131 | for(int i=0;i point; 134 | point.push_back(0); 135 | point.push_back(0); 136 | point_vector.push_back(point); 
--------------------------------------------------------------------------------
/post_processing/pypse.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/9/8 11:03
# @Author  : zhoujun
import numpy as np
from queue import Queue


def pse_py(text, similarity_vectors, label, label_values, dis_threshold=0.8):
    pred = np.zeros(text.shape)
    queue = Queue(maxsize=0)
    # seed the queue with all labelled kernel pixels
    points = np.array(np.where(label > 0)).transpose((1, 0))

    for point_idx in range(points.shape[0]):
        y, x = points[point_idx, 0], points[point_idx, 1]
        label_value = label[y, x]
        queue.put((y, x, label_value))
        pred[y, x] = label_value
    # compute the mean similarity vector of each kernel
    d = {}
    for i in label_values:
        kernel_idx = label == i
        kernel_similarity_vector = similarity_vectors[kernel_idx].mean(0)  # 4
        d[i] = kernel_similarity_vector

    dx = [-1, 1, 0, 0]
    dy = [0, 0, -1, 1]
    kernal = text.copy()
    while not queue.empty():
        (y, x, label_value) = queue.get()
        cur_kernel_sv = d[label_value]
        for j in range(4):
            tmpx = x + dx[j]
            tmpy = y + dy[j]
            if tmpx < 0 or tmpy >= kernal.shape[0] or tmpy < 0 or tmpx >= kernal.shape[1]:
                continue
            if kernal[tmpy, tmpx] == 0 or pred[tmpy, tmpx] > 0:
                continue
            if np.linalg.norm(similarity_vectors[tmpy, tmpx] - cur_kernel_sv) >= dis_threshold:
                continue
            queue.put((tmpy, tmpx, label_value))
            pred[tmpy, tmpx] = label_value
    return pred
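
The pure-Python version is handy for sanity checks without compiling the extension. A toy run, assuming the function name pse_py as reconstructed above; all inputs are fabricated:

import numpy as np
import cv2
from post_processing.pypse import pse_py

text = np.zeros((5, 8), dtype=np.uint8)
text[1:4, 1:7] = 1                                    # predicted text mask
kernel = np.zeros_like(text)
kernel[2, 2:6] = 1                                    # shrunk kernel inside the text
label_num, label = cv2.connectedComponents(kernel, connectivity=4)
similarity_vectors = np.ones((5, 8, 4), dtype=np.float32)  # identical vectors, so everything merges

pred = pse_py(text, similarity_vectors, label, label_values=list(range(1, label_num)))
print(np.unique(pred))                                # [0. 1.]: the text region inherits label 1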
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/8/24 12:06
# @Author  : zhoujun

import torch
from torchvision import transforms
import os
import cv2
import time

from models import get_model

from post_processing import decode


def decode_clip(preds, scale=1, threshold=0.7311, min_area=5):
    import pyclipper
    import numpy as np
    preds[:2, :, :] = torch.sigmoid(preds[:2, :, :])
    preds = preds.detach().cpu().numpy()
    text = preds[0] > threshold  # text
    kernel = (preds[1] > threshold) * text  # kernel

    label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4)
    bbox_list = []
    for label_idx in range(1, label_num):
        points = np.array(np.where(label == label_idx)).transpose((1, 0))[:, ::-1]
        if points.shape[0] < min_area:
            continue
        rect = cv2.minAreaRect(points)
        poly = cv2.boxPoints(rect).astype(int)

        # expansion distance: 1.5 * area / perimeter
        d_i = cv2.contourArea(poly) * 1.5 / cv2.arcLength(poly, True)
        pco = pyclipper.PyclipperOffset()
        pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        shrinked_poly = np.array(pco.Execute(d_i))
        if shrinked_poly.size == 0:
            continue
        rect = cv2.minAreaRect(shrinked_poly)
        shrinked_poly = cv2.boxPoints(rect).astype(int)
        if cv2.contourArea(shrinked_poly) < 800 / (scale * scale):
            continue

        bbox_list.append([shrinked_poly[1], shrinked_poly[2], shrinked_poly[3], shrinked_poly[0]])
    return label, np.array(bbox_list)


class Pytorch_model:
    def __init__(self, model_path, gpu_id=None):
        '''
        Initialise the PyTorch model.
        :param model_path: path to the checkpoint (weights alone, or weights saved together with the computation graph)
        :param gpu_id: id of the GPU to run on
        '''
        self.gpu_id = gpu_id

        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)
        checkpoint = torch.load(model_path, map_location=self.device)

        config = checkpoint['config']
        config['arch']['args']['pretrained'] = False
        self.net = get_model(config)

        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
        self.net.load_state_dict(checkpoint['state_dict'])
        self.net.to(self.device)
        self.net.eval()

    def predict(self, img: str, short_size: int = 736):
        '''
        Run prediction on a single image given by its path (read with OpenCV, somewhat slow).
        :param img: path to the image
        :return: predictions, box list and inference time
        '''
        assert os.path.exists(img), 'file does not exist'
        img = cv2.imread(img)
        if self.img_channel == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        scale = short_size / min(h, w)
        img = cv2.resize(img, None, fx=scale, fy=scale)
        # turn the (w,h) image into a (1, img_channel, h, w) tensor
        tensor = transforms.ToTensor()(img)
        tensor = tensor.unsqueeze_(0)

        tensor = tensor.to(self.device)
        with torch.no_grad():
            if str(self.device).__contains__('cuda'):
                torch.cuda.synchronize(self.device)
            start = time.time()
            preds = self.net(tensor)[0]
            if str(self.device).__contains__('cuda'):
                torch.cuda.synchronize(self.device)
            preds, boxes_list = decode(preds)
            scale = (preds.shape[1] / w, preds.shape[0] / h)
            if len(boxes_list):
                boxes_list = boxes_list / scale
            t = time.time() - start
        return preds, boxes_list, t


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from utils.util import show_img, draw_bbox

    os.environ['CUDA_VISIBLE_DEVICES'] = str('0')

    model_path = 'output/PAN_shufflenetv2_FPEM_FFM.pth'

    img_id = 10
    img_path = 'E:/zj/dataset/icdar2015/test/img/img_{}.jpg'.format(img_id)

    # initialise the network
    model = Pytorch_model(model_path, gpu_id=0)
    preds, boxes_list, t = model.predict(img_path)
    show_img(preds)
    img = draw_bbox(cv2.imread(img_path)[:, :, ::-1], boxes_list)
    show_img(img, color=True)
    plt.show()
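
decode_clip grows each minimum-area rectangle before re-fitting it, using the offset hard-coded above, d = 1.5 * Area / Perimeter, so larger boxes are pushed out further. A quick worked example with hypothetical dimensions:

# expansion distance for a hypothetical 100 x 40 box
area, perimeter = 100 * 40, 2 * (100 + 40)
d_i = area * 1.5 / perimeter
print(round(d_i, 2))  # 21.43: pyclipper offsets the polygon outwards by about 21 px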
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 22:00
# @Author  : zhoujun

from __future__ import print_function
import os
from utils import load_json

config = load_json('config.json')
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in config['trainer']['gpus']])

from models import get_model, get_loss
from data_loader import get_dataloader
from trainer import Trainer


def main(config):
    train_loader = get_dataloader(config['data_loader']['type'], config['data_loader']['args'])

    criterion = get_loss(config).cuda()

    model = get_model(config)

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader)
    trainer.train()


if __name__ == '__main__':
    main(config)
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:58
# @Author  : zhoujun
from .trainer import Trainer
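
train.py reads config.json and sets CUDA_VISIBLE_DEVICES before importing the model code, so the GPU mask takes effect. The actual config.json is not part of this excerpt; the skeleton below lists only the keys the surrounding code is seen to read, and every value is an illustrative placeholder:

# config.json skeleton inferred from the keys read in train.py, trainer.py and predict.py
config = {
    'arch': {'args': {'pretrained': True}},
    'data_loader': {
        'type': 'ImageDataset',  # hypothetical loader name
        'args': {'dataset': {'img_channel': 3,
                             'val_data_path': '/path/to/icdar2015/test'}},
    },
    'lr_scheduler': {'type': 'PolynomialLR'},
    'trainer': {'gpus': [0], 'metrics': 'hmean', 'show_images_interval': 50},
}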
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:58
# @Author  : zhoujun
import os
import cv2
import shutil
import numpy as np
import traceback
import time
from tqdm import tqdm
import torch
import torchvision.utils as vutils
from torchvision import transforms
from post_processing import decode
from utils import PolynomialLR, runningScore, cal_text_score, cal_kernel_score, cal_recall_precison_f1

from base import BaseTrainer


class Trainer(BaseTrainer):
    def __init__(self, config, model, criterion, train_loader, weights_init=None):
        super(Trainer, self).__init__(config, model, criterion, weights_init)
        self.show_images_interval = self.config['trainer']['show_images_interval']
        self.test_path = self.config['data_loader']['args']['dataset']['val_data_path']
        self.train_loader = train_loader
        self.train_loader_len = len(train_loader)
        if self.config['lr_scheduler']['type'] == 'PolynomialLR':
            self.scheduler = PolynomialLR(self.optimizer, self.epochs * self.train_loader_len)

        self.logger.info('train dataset has {} samples, {} in dataloader'.format(self.train_loader.dataset_len,
                                                                                 self.train_loader_len))

    def _train_epoch(self, epoch):
        self.model.train()
        epoch_start = time.time()
        batch_start = time.time()
        train_loss = 0.
        running_metric_text = runningScore(2)
        running_metric_kernel = runningScore(2)
        lr = self.optimizer.param_groups[0]['lr']
        for i, (images, labels, training_masks) in enumerate(self.train_loader):
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']

            # move the batch to the gpu
            cur_batch_size = images.size()[0]
            images, labels, training_masks = images.to(self.device), labels.to(self.device), training_masks.to(
                self.device)

            preds = self.model(images)
            loss_all, loss_tex, loss_ker, loss_agg, loss_dis = self.criterion(preds, labels, training_masks)
            # backward
            self.optimizer.zero_grad()
            loss_all.backward()
            self.optimizer.step()
            if self.config['lr_scheduler']['type'] == 'PolynomialLR':
                self.scheduler.step()
            # acc iou
            score_text = cal_text_score(preds[:, 0, :, :], labels[:, 0, :, :], training_masks, running_metric_text)
            score_kernel = cal_kernel_score(preds[:, 1, :, :], labels[:, 1, :, :], labels[:, 0, :, :], training_masks,
                                            running_metric_kernel)

            # log loss and acc
            loss_all = loss_all.item()
            loss_tex = loss_tex.item()
            loss_ker = loss_ker.item()
            loss_agg = loss_agg.item()
            loss_dis = loss_dis.item()
            train_loss += loss_all
            acc = score_text['Mean Acc']
            iou_text = score_text['Mean IoU']
            iou_kernel = score_kernel['Mean IoU']

            if (i + 1) % self.display_interval == 0:
                batch_time = time.time() - batch_start
                self.logger.info(
                    '[{}/{}], [{}/{}], global_step: {}, Speed: {:.1f} samples/sec, acc: {:.4f}, iou_text: {:.4f}, iou_kernel: {:.4f}, loss_all: {:.4f}, loss_tex: {:.4f}, loss_ker: {:.4f}, loss_agg: {:.4f}, loss_dis: {:.4f}, lr: {:.6}, time: {:.2f}'.format(
                        epoch, self.epochs, i + 1, self.train_loader_len, self.global_step,
                        self.display_interval * cur_batch_size / batch_time, acc, iou_text,
                        iou_kernel, loss_all, loss_tex, loss_ker, loss_agg, loss_dis, lr, batch_time))
                batch_start = time.time()

            if self.tensorboard_enable:
                # write tensorboard
                self.writer.add_scalar('TRAIN/LOSS/loss_all', loss_all, self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_tex', loss_tex, self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_ker', loss_ker, self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_agg', loss_agg, self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_dis', loss_dis, self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc, self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/iou_text', iou_text, self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/iou_kernel', iou_kernel, self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
                if i % self.show_images_interval == 0:
                    # show images on tensorboard
                    self.writer.add_images('TRAIN/imgs', images, self.global_step)
                    # text kernel and training_masks
                    gt_texts, gt_kernels = labels[:, 0, :, :], labels[:, 1, :, :]
                    gt_texts[gt_texts <= 0.5] = 0
                    gt_texts[gt_texts > 0.5] = 1
                    gt_kernels[gt_kernels <= 0.5] = 0
                    gt_kernels[gt_kernels > 0.5] = 1
                    show_label = torch.cat([gt_texts, gt_kernels, training_masks.float()])
                    show_label = vutils.make_grid(show_label.unsqueeze(1), nrow=cur_batch_size, normalize=False,
                                                  padding=20,
                                                  pad_value=1)
                    self.writer.add_image('TRAIN/gt', show_label, self.global_step)
                    # model output
                    preds[:, :2, :, :] = torch.sigmoid(preds[:, :2, :, :])
                    show_pred = torch.cat([preds[:, 0, :, :], preds[:, 1, :, :]])
                    show_pred = vutils.make_grid(show_pred.unsqueeze(1), nrow=cur_batch_size, normalize=False,
                                                 padding=20,
                                                 pad_value=1)
                    self.writer.add_image('TRAIN/preds', show_pred, self.global_step)

        return {'train_loss': train_loss / self.train_loader_len, 'lr': lr, 'time': time.time() - epoch_start,
                'epoch': epoch}

    def _eval(self):
        self.model.eval()
        # torch.cuda.empty_cache()  # speeds up evaluation after training has finished
        img_path = os.path.join(self.test_path, 'img')
        gt_path = os.path.join(self.test_path, 'gt')
        result_save_path = os.path.join(self.save_dir, 'result')
        if os.path.exists(result_save_path):
            shutil.rmtree(result_save_path, ignore_errors=True)
        if not os.path.exists(result_save_path):
            os.makedirs(result_save_path)
        short_size = 736
        # run prediction on every test image
        img_paths = [os.path.join(img_path, x) for x in os.listdir(img_path)]
        for img_path in tqdm(img_paths, desc='test models'):
            img_name = os.path.basename(img_path).split('.')[0]
            save_name = os.path.join(result_save_path, 'res_' + img_name + '.txt')

            assert os.path.exists(img_path), 'file does not exist'
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]
            scale = short_size / min(h, w)
            img = cv2.resize(img, None, fx=scale, fy=scale)
            # turn the (w,h) image into a (1, img_channel, h, w) tensor
            tensor = transforms.ToTensor()(img)
            tensor = tensor.unsqueeze_(0)

            tensor = tensor.to(self.device)
            with torch.no_grad():
                torch.cuda.synchronize(self.device)
                preds = self.model(tensor)[0]
                torch.cuda.synchronize(self.device)
                preds, boxes_list = decode(preds)
                scale = (preds.shape[1] / w, preds.shape[0] / h)
                if len(boxes_list):
                    boxes_list = boxes_list / scale
            np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d')
        # compute recall, precision and f1
        result_dict = cal_recall_precison_f1(gt_path=gt_path, result_path=result_save_path)
        return result_dict['recall'], result_dict['precision'], result_dict['hmean']

    def _on_epoch_finish(self):
        self.logger.info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
            self.epoch_result['epoch'], self.epochs, self.epoch_result['train_loss'], self.epoch_result['time'],
            self.epoch_result['lr']))
        net_save_path = '{}/PANNet_latest.pth'.format(self.checkpoint_dir)

        save_best = False
        if self.config['trainer']['metrics'] == 'hmean':  # use f1 as the model-selection metric
            recall, precision, hmean = self._eval()

            if self.tensorboard_enable:
                self.writer.add_scalar('EVAL/recall', recall, self.global_step)
                self.writer.add_scalar('EVAL/precision', precision, self.global_step)
                self.writer.add_scalar('EVAL/hmean', hmean, self.global_step)
            self.logger.info('test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.format(recall, precision, hmean))

            if hmean > self.metrics['hmean']:
                save_best = True
                self.metrics['train_loss'] = self.epoch_result['train_loss']
                self.metrics['hmean'] = hmean
                self.metrics['precision'] = precision
                self.metrics['recall'] = recall
                self.metrics['best_model'] = net_save_path
        else:
            if self.epoch_result['train_loss'] < self.metrics['train_loss']:
                save_best = True
                self.metrics['train_loss'] = self.epoch_result['train_loss']
                self.metrics['best_model'] = net_save_path
        self._save_checkpoint(self.epoch_result['epoch'], net_save_path, save_best)

    def _on_train_finish(self):
        for k, v in self.metrics.items():
            self.logger.info('{}:{}'.format(k, v))
        self.logger.info('finish train')
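
_eval above writes one res_<image>.txt per test image before handing the folder to cal_recall_precison_f1. Given the reshape(-1, 8) and fmt='%d', each line is one detected box as eight comma-separated integer coordinates; an illustrative excerpt (the numbers are made up):

res_img_10.txt:
377,117,463,117,465,130,378,130
493,115,519,115,519,131,493,131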
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:58
# @Author  : zhoujun
from .util import *
from .metrics import *
from .schedulers import *
from .cal_recall.script import cal_recall_precison_f1
--------------------------------------------------------------------------------
/utils/cal_recall/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 1/16/19 6:40 AM
# @Author  : zhoujun
from .script import cal_recall_precison_f1
__all__ = ['cal_recall_precison_f1']
--------------------------------------------------------------------------------
/utils/make_trainfile.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/8/24 12:06
# @Author  : zhoujun
import os
import glob
import pathlib

data_path = r'E:\zj\dataset\icdar2015\test'
# data_path/img holds the images
# data_path/gt holds the label files

f_w = open(os.path.join(data_path, 'test.txt'), 'w', encoding='utf8')
for img_path in glob.glob(data_path + '/img/*.jpg', recursive=True):
    d = pathlib.Path(img_path)
    label_path = os.path.join(data_path, 'gt', ('gt_' + str(d.stem) + '.txt'))
    if os.path.exists(img_path) and os.path.exists(label_path):
        print(img_path, label_path)
        # only write pairs whose image and label file both exist
        f_w.write('{}\t{}\n'.format(img_path, label_path))
    else:
        print('missing:', img_path, label_path)
f_w.close()
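
Each line of the generated test.txt pairs an image with its ground-truth file, tab-separated; an illustrative line (paths are placeholders following the pattern above):

E:\zj\dataset\icdar2015\test\img\img_1.jpg	E:\zj\dataset\icdar2015\test\gt\gt_img_1.txt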
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
# Adapted from score written by wkentaro
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py

import numpy as np


class runningScore(object):

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        mask = (label_true >= 0) & (label_true < n_class)

        if np.sum((label_pred[mask] < 0)) > 0:
            print(label_pred[label_pred < 0])
        hist = np.bincount(n_class * label_true[mask].astype(int) +
                           label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
        return hist

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            try:
                self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
            except:
                # skip a sample whose labels cannot be histogrammed instead of aborting the epoch
                pass

    def get_scores(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IU
            - fwavacc
        """
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
        acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
        acc_cls = np.nanmean(acc_cls)
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001)
        mean_iu = np.nanmean(iu)
        freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {'Overall Acc': acc,
                'Mean Acc': acc_cls,
                'FreqW Acc': fwavacc,
                'Mean IoU': mean_iu, }, cls_iu

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
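
runningScore accumulates a confusion matrix across batches and derives accuracy and IoU from it. A small self-contained check (inputs are fabricated):

import numpy as np
from utils.metrics import runningScore

score = runningScore(2)                  # two classes: non-text / text
gt = np.array([[0, 1], [1, 1]])
pred = np.array([[0, 1], [0, 1]])
score.update(gt[None], pred[None])       # update expects a batch of label maps
metrics, per_class_iou = score.get_scores()
print(metrics['Mean IoU'], per_class_iou)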
--------------------------------------------------------------------------------
/utils/schedulers.py:
--------------------------------------------------------------------------------
from torch.optim.lr_scheduler import _LRScheduler


class ConstantLR(_LRScheduler):
    def __init__(self, optimizer, last_epoch=-1):
        super(ConstantLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr for base_lr in self.base_lrs]


class PolynomialLR(_LRScheduler):
    def __init__(self, optimizer, max_iter, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolynomialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]


class WarmUpLR(_LRScheduler):
    def __init__(
            self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1
    ):
        self.mode = mode
        self.scheduler = scheduler
        self.warmup_iters = warmup_iters
        self.gamma = gamma
        super(WarmUpLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        cold_lrs = self.scheduler.get_lr()

        if self.last_epoch < self.warmup_iters:
            if self.mode == "linear":
                alpha = self.last_epoch / float(self.warmup_iters)
                factor = self.gamma * (1 - alpha) + alpha
            elif self.mode == "constant":
                factor = self.gamma
            else:
                raise KeyError("WarmUp type {} not implemented".format(self.mode))

            return [factor * base_lr for base_lr in cold_lrs]

        return cold_lrs


if __name__ == '__main__':
    import torch
    from torchvision.models import resnet18

    max_iter = 600 * 125
    model = resnet18()
    op = torch.optim.SGD(model.parameters(), 0.001)
    sc = PolynomialLR(op, max_iter)
    lr = []
    for i in range(max_iter):
        sc.step()
        print(i, sc.last_epoch, sc.get_lr()[0])
        lr.append(sc.get_lr()[0])
    from matplotlib import pyplot as plt

    plt.plot(list(range(max_iter)), lr)
    plt.show()
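
WarmUpLR scales whatever learning rate the wrapped scheduler reports during the first warmup_iters steps, then returns the wrapped scheduler's rate unchanged. A usage sketch in which only the wrapper's step() is called; the iteration counts are arbitrary:

import torch
from torchvision.models import resnet18
from utils.schedulers import PolynomialLR, WarmUpLR

model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
poly = PolynomialLR(optimizer, max_iter=1000)
scheduler = WarmUpLR(optimizer, scheduler=poly, mode='linear', warmup_iters=100)
for step in range(1000):
    # ramps linearly from gamma * lr towards the wrapped scheduler's lr,
    # then hands back the wrapped scheduler's lr as-is
    scheduler.step()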
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:59
# @Author  : zhoujun
import time
import json
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt


def setup_logger(log_file_path: str = None):
    import logging
    from colorlog import ColoredFormatter
    # format of the file log output
    logging.basicConfig(filename=log_file_path,
                        format='%(asctime)s %(levelname)-8s %(filename)s[line:%(lineno)d]: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S', )
    """Return a logger with a default ColoredFormatter."""
    formatter = ColoredFormatter(
        "%(asctime)s %(log_color)s%(levelname)-8s %(reset)s %(filename)s[line:%(lineno)d]: %(message)s",
        datefmt='%Y-%m-%d %H:%M:%S',
        reset=True,
        log_colors={
            'DEBUG': 'blue',
            'INFO': 'green',
            'WARNING': 'yellow',
            'ERROR': 'red',
            'CRITICAL': 'red',
        })

    logger = logging.getLogger('PAN')
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    logger.info('logger init finished')
    return logger


# --exeTime
def exe_time(func):
    def newFunc(*args, **args2):
        t0 = time.time()
        back = func(*args, **args2)
        print("{} cost {:.3f}s".format(func.__name__, time.time() - t0))
        return back

    return newFunc


def save_json(data, json_path):
    with open(json_path, mode='w', encoding='utf8') as f:
        json.dump(data, f, indent=4)


def load_json(json_path):
    with open(json_path, mode='r', encoding='utf8') as f:
        data = json.load(f)
    return data


def show_img(imgs: np.ndarray, color=False):
    if (len(imgs.shape) == 3 and color) or (len(imgs.shape) == 2 and not color):
        imgs = np.expand_dims(imgs, axis=0)
    for img in imgs:
        plt.figure()
        plt.imshow(img, cmap=None if color else 'gray')


def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
    if isinstance(img_path, str):
        img_path = cv2.imread(img_path)
        # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
    img_path = img_path.copy()
    for point in result:
        point = point.astype(int)
        cv2.line(img_path, tuple(point[0]), tuple(point[1]), color, thickness)
        cv2.line(img_path, tuple(point[1]), tuple(point[2]), color, thickness)
        cv2.line(img_path, tuple(point[2]), tuple(point[3]), color, thickness)
        cv2.line(img_path, tuple(point[3]), tuple(point[0]), color, thickness)
    return img_path


def cal_text_score(texts, gt_texts, training_masks, running_metric_text):
    training_masks = training_masks.data.cpu().numpy()
    pred_text = torch.sigmoid(texts).data.cpu().numpy() * training_masks
    pred_text[pred_text <= 0.5] = 0
    pred_text[pred_text > 0.5] = 1
    pred_text = pred_text.astype(np.int32)
    gt_text = gt_texts.data.cpu().numpy() * training_masks
    gt_text = gt_text.astype(np.int32)
    running_metric_text.update(gt_text, pred_text)
    score_text, _ = running_metric_text.get_scores()
    return score_text


def cal_kernel_score(kernel, gt_kernel, gt_texts, training_masks, running_metric_kernel):
    mask = (gt_texts * training_masks.float()).data.cpu().numpy()
    pred_kernel = torch.sigmoid(kernel).data.cpu().numpy()
    pred_kernel[pred_kernel <= 0.5] = 0
    pred_kernel[pred_kernel > 0.5] = 1
    pred_kernel = (pred_kernel * mask).astype(np.int32)
    gt_kernel = gt_kernel.data.cpu().numpy()
    gt_kernel = (gt_kernel * mask).astype(np.int32)
    running_metric_kernel.update(gt_kernel, pred_kernel)
    score_kernel, _ = running_metric_kernel.get_scores()
    return score_kernel


def order_points_clockwise(pts):
    rect = np.zeros((4, 2), dtype="float32")
    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]
    rect[2] = pts[np.argmax(s)]
    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]
    rect[3] = pts[np.argmax(diff)]
    return rect


def order_points_clockwise_list(pts):
    pts = pts.tolist()
    pts.sort(key=lambda x: (x[1], x[0]))
    pts[:2] = sorted(pts[:2], key=lambda x: x[0])
    pts[2:] = sorted(pts[2:], key=lambda x: -x[0])
    pts = np.array(pts)
    return pts


if __name__ == '__main__':
    box = np.array([382, 1080, 443, 999, 423, 1014, 362, 1095]).reshape(-1, 2)
    # box = np.array([0, 4, 2, 2, 0, 8, 4, 4]).reshape(-1, 2)
    # box = np.array([0, 0, 2, 2, 0, 4, 4, 4]).reshape(-1, 2)
    from scipy.spatial import ConvexHull

    # print(order_points_clockwise(box))
    print(order_points_clockwise_list(box))
--------------------------------------------------------------------------------