├── .gitignore
├── LICENSE.md
├── README.MD
├── __init__.py
├── base
├── __init__.py
└── base_trainer.py
├── config.json
├── config
├── __init__.py
└── default.py
├── data_loader
├── __init__.py
├── augment.py
├── data_utils.py
└── dataset.py
├── eval.py
├── imgs
├── example
│ ├── img_10.jpg
│ ├── img_2.jpg
│ ├── img_29.jpg
│ ├── img_75.jpg
│ └── img_91.jpg
└── paper
│ └── PAN.jpg
├── models
├── __init__.py
├── loss.py
├── model.py
└── modules
│ ├── __init__.py
│ ├── resnet.py
│ ├── segmentation_head.py
│ └── shufflenetv2.py
├── post_processing
├── Makefile
├── __init__.py
├── include
│ └── pybind11
│ │ ├── attr.h
│ │ ├── buffer_info.h
│ │ ├── cast.h
│ │ ├── chrono.h
│ │ ├── class_support.h
│ │ ├── common.h
│ │ ├── complex.h
│ │ ├── descr.h
│ │ ├── detail
│ │ ├── class.h
│ │ ├── common.h
│ │ ├── descr.h
│ │ ├── init.h
│ │ ├── internals.h
│ │ └── typeid.h
│ │ ├── eigen.h
│ │ ├── embed.h
│ │ ├── eval.h
│ │ ├── functional.h
│ │ ├── iostream.h
│ │ ├── numpy.h
│ │ ├── operators.h
│ │ ├── options.h
│ │ ├── pybind11.h
│ │ ├── pytypes.h
│ │ ├── stl.h
│ │ ├── stl_bind.h
│ │ └── typeid.h
├── kmeans.py
├── pse.cpp
├── pse.so
└── pypse.py
├── predict.py
├── train.py
├── trainer
├── __init__.py
└── trainer.py
└── utils
├── __init__.py
├── cal_recall
├── __init__.py
├── rrc_evaluation_funcs.py
└── script.py
├── make_trainfile.py
├── metrics.py
├── schedulers.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pth
3 | *.pyc
4 | *.pyo
5 | *.log
6 | *.tmp
7 | *.pkl
8 | __pycache__/
9 | .idea/
10 | output/
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network
2 |
3 | 
4 |
5 | ## Requirements
6 | * pytorch 1.1+
7 | * torchvision 0.3+
8 | * pyclipper
9 | * opencv3
10 | * gcc 4.9+
11 |
12 | ## Download
13 |
14 | `PAN_resnet18_FPEM_FFM` and `PAN_shufflenetv2_FPEM_FFM` on ICDAR 2015:
15 |
16 | the updated models (resnet18: 78.8, shufflenetv2: 72.4, lr: 1e-3) are not the best models
17 |
18 | [google drive](https://drive.google.com/drive/folders/1bKPQEEOJ5kgSSRMpnDB8HIRecnD_s4bR?usp=sharing)
19 |
20 | ## Data Preparation
21 |
22 | train: prepare a text file in the following format, using '\t' as the separator
23 | ```bash
24 | /path/to/img.jpg path/to/label.txt
25 | ...
26 | ```
27 | val:
28 | use a folder with the following layout:
29 | ```bash
30 | img/  stores the images
31 | gt/   stores the ground-truth files
32 | ```
33 |
34 | ## Train
35 | 1. configure `train_data_path` and `val_data_path` in [config.json](config.json)
36 | 2. run the following script
37 | ```sh
38 | python3 train.py
39 | ```
40 |
41 | ## Test
42 |
43 | [eval.py](eval.py) is used to test the model on a test dataset
44 |
45 | 1. configure `model_path`, `img_path`, `gt_path`, `save_path` in [eval.py](eval.py)
46 | 2. run the following script to test
47 | ```sh
48 | python3 eval.py
49 | ```
50 |
51 | ## Predict
52 | [predict.py](predict.py) is used to run inference on a single image
53 |
54 | 1. configure `model_path` and `img_path` in [predict.py](predict.py)
55 | 2. run the following script to predict
56 | ```sh
57 | python3 predict.py
58 | ```
59 |
60 | The project is still under development.
61 |
62 |
63 |
64 | ### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4)
65 | trained only on the ICDAR 2015 dataset
66 |
67 | | Method | image size (short size) |learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS |
68 | |:--------------------------:|:-------:|:--------:|:--------:|:------------:|:---------------:|:-----:|
69 | | paper(resnet18) | 736 |x | x | x | 80.4 | 26.1 |
70 | | my (ShuffleNetV2+FPEM_FFM+pse expansion) |736 |1e-3| 81.72 | 66.73 | 73.47 | 24.71 (P100)|
71 | | my (resnet18+FPEM_FFM+pse expansion) |736 |1e-3| 84.93 | 74.09 | 79.14 | 21.31 (P100)|
72 | | my (resnet50+FPEM_FFM+pse expansion) |736 |1e-3| 84.23 | 76.12 | 79.96 | 14.22 (P100)|
73 | | my (ShuffleNetV2+FPEM_FFM+pse expansion) |736 |1e-4| 75.14 | 57.34 | 65.04 | 24.71 (P100)|
74 | | my (resnet18+FPEM_FFM+pse expansion) |736 |1e-4| 83.89 | 69.23 | 75.86 | 21.31 (P100)|
75 | | my (resnet50+FPEM_FFM+pse expansion) |736 |1e-4| 85.29 | 75.1 | 79.87 | 14.22 (P100)|
76 | | my (resnet18+FPN+pse expansion) | 736 |1e-3| 76.50 | 74.70 | 75.59 | 14.47 (P100)|
77 | | my (resnet50+FPN+pse expansion) | 736 |1e-3| 71.82 | 75.73 | 73.72 | 10.67 (P100)|
78 | | my (resnet18+FPN+pse expansion) | 736 |1e-4| 74.19 | 72.34 | 73.25 | 14.47 (P100)|
79 | | my (resnet50+FPN+pse expansion) | 736 |1e-4| 78.96 | 76.27 | 77.59 | 10.67 (P100)|
80 |
81 | ### examples
82 | 
83 |
84 | 
85 |
86 | 
87 |
88 | 
89 |
90 | 
91 |
92 | ### todo
93 | - [ ] MobileNet backbone
94 |
95 | - [x] ShuffleNet backbone
96 | ### reference
97 | 1. https://arxiv.org/pdf/1908.05900.pdf
98 | 2. https://github.com/WenmuZhou/PSENet.pytorch
99 |
100 | **If this repository helps you, please star it. Thanks.**
--------------------------------------------------------------------------------
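The repo ships `utils/make_trainfile.py` for building the train list described in the README. A minimal sketch of the same idea (the folder layout, file naming, and the `generate_list` helper below are illustrative assumptions, not the repo's exact script):

```python
# Build the '\t'-separated train list described in the README.
import pathlib

def generate_list(img_dir: str, gt_dir: str, out_file: str):
    with open(out_file, 'w', encoding='utf-8') as f:
        for img_path in sorted(pathlib.Path(img_dir).glob('*.jpg')):
            # ICDAR-2015-style naming assumed: img_1.jpg -> gt_img_1.txt
            gt_path = pathlib.Path(gt_dir) / ('gt_' + img_path.stem + '.txt')
            if gt_path.exists():
                f.write('{}\t{}\n'.format(img_path, gt_path))

generate_list('icdar2015/train/img', 'icdar2015/train/gt', 'train.txt')
```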
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/__init__.py
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_trainer import BaseTrainer
--------------------------------------------------------------------------------
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:50
3 | # @Author : zhoujun
4 |
5 | import os
6 | import shutil
7 | import pathlib
8 | from pprint import pformat
9 | import torch
10 | from torch import nn
11 |
12 | from utils import setup_logger
13 |
14 |
15 | class BaseTrainer:
16 | def __init__(self, config, model, criterion, weights_init):
17 | config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent),
18 | config['trainer']['output_dir'])
19 | config['name'] = config['name'] + '_' + model.name
20 | self.save_dir = os.path.join(config['trainer']['output_dir'], config['name'])
21 | self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
22 |
23 | if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '':
24 | shutil.rmtree(self.save_dir, ignore_errors=True)
25 | if not os.path.exists(self.checkpoint_dir):
26 | os.makedirs(self.checkpoint_dir)
27 |
28 | self.global_step = 0
29 | self.start_epoch = 1
30 | self.config = config
31 |
32 | self.model = model
33 | self.criterion = criterion
34 | # logger and tensorboard
35 | self.tensorboard_enable = self.config['trainer']['tensorboard']
36 | self.epochs = self.config['trainer']['epochs']
37 | self.display_interval = self.config['trainer']['display_interval']
38 | if self.tensorboard_enable:
39 | from torch.utils.tensorboard import SummaryWriter
40 | self.writer = SummaryWriter(self.save_dir)
41 |
42 | self.logger = setup_logger(os.path.join(self.save_dir, 'train_log'))
43 | self.logger.info(pformat(self.config))
44 |
45 | # device
46 | torch.manual_seed(self.config['trainer']['seed']) # set random seed for the CPU
47 | if len(self.config['trainer']['gpus']) > 0 and torch.cuda.is_available():
48 | self.with_cuda = True
49 | torch.backends.cudnn.benchmark = True
50 | self.logger.info(
51 | 'train with gpu {} and pytorch {}'.format(self.config['trainer']['gpus'], torch.__version__))
52 | self.gpus = {i: item for i, item in enumerate(self.config['trainer']['gpus'])}
53 | self.device = torch.device("cuda:0")
54 | torch.cuda.manual_seed(self.config['trainer']['seed']) # set random seed for the current GPU
55 | torch.cuda.manual_seed_all(self.config['trainer']['seed']) # set random seed for all GPUs
56 | else:
57 | self.with_cuda = False
58 | self.logger.info('train with cpu and pytorch {}'.format(torch.__version__))
59 | self.device = torch.device("cpu")
60 | self.logger.info('device {}'.format(self.device))
61 | self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'), 'best_model': ''}
62 |
63 | self.optimizer = self._initialize('optimizer', torch.optim, model.parameters())
64 |
65 | if self.config['trainer']['resume_checkpoint'] != '':
66 | self._load_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True)
67 | elif self.config['trainer']['finetune_checkpoint'] != '':
68 | self._load_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False)
69 | else:
70 | if weights_init is not None:
71 | model.apply(weights_init)
72 | if self.config['lr_scheduler']['type'] != 'PolynomialLR':
73 | self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer)
74 |
75 | # single machine, multiple GPUs
76 | num_gpus = torch.cuda.device_count()
77 | if num_gpus > 1:
78 | self.model = nn.DataParallel(self.model)
79 |
80 | self.model.to(self.device)
81 |
82 | if self.tensorboard_enable:
83 | try:
84 | # add graph
85 | dummy_input = torch.zeros(1, self.config['data_loader']['args']['dataset']['img_channel'],
86 | self.config['data_loader']['args']['dataset']['input_size'],
87 | self.config['data_loader']['args']['dataset']['input_size']).to(self.device)
88 | self.writer.add_graph(model, dummy_input)
89 | except Exception:
90 | import traceback
91 | # self.logger.error(traceback.format_exc())
92 | self.logger.warning('add graph to tensorboard failed')
93 |
94 | def train(self):
95 | """
96 | Full training logic
97 | """
98 | for epoch in range(self.start_epoch, self.epochs + 1):
99 | try:
100 | self.epoch_result = self._train_epoch(epoch)
101 | if self.config['lr_scheduler']['type'] != 'PolynomialLR':
102 | self.scheduler.step()
103 | self._on_epoch_finish()
104 | except torch.cuda.CudaError:
105 | self._log_memory_usage()
106 | if self.tensorboard_enable:
107 | self.writer.close()
108 | self._on_train_finish()
109 |
110 | def _train_epoch(self, epoch):
111 | """
112 | Training logic for an epoch
113 |
114 | :param epoch: Current epoch number
115 | """
116 | raise NotImplementedError
117 |
118 | def _eval(self):
119 | """
120 | eval logic for an epoch
121 |
122 | :param epoch: Current epoch number
123 | """
124 | raise NotImplementedError
125 |
126 | def _on_epoch_finish(self):
127 | raise NotImplementedError
128 |
129 | def _on_train_finish(self):
130 | raise NotImplementedError
131 |
132 | def _log_memory_usage(self):
133 | if not self.with_cuda:
134 | return
135 |
136 | template = """Memory Usage: \n{}"""
137 | usage = []
138 | for deviceID, device in self.gpus.items():
139 | deviceID = int(deviceID)
140 | allocated = torch.cuda.memory_allocated(deviceID) / (1024 * 1024)
141 | cached = torch.cuda.memory_cached(deviceID) / (1024 * 1024)
142 |
143 | usage.append(' CUDA: {} Allocated: {} MB Cached: {} MB \n'.format(device, allocated, cached))
144 |
145 | content = ''.join(usage)
146 | content = template.format(content)
147 |
148 | self.logger.debug(content)
149 |
150 | def _save_checkpoint(self, epoch, file_name, save_best=False):
151 | """
152 | Saving checkpoints
153 |
154 | :param epoch: current epoch number
155 | :param log: logging information of the epoch
156 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
157 | """
158 | state = {
159 | 'epoch': epoch,
160 | 'global_step': self.global_step,
161 | 'state_dict': self.model.state_dict(),
162 | 'optimizer': self.optimizer.state_dict(),
163 | 'scheduler': self.scheduler.state_dict(),
164 | 'config': self.config,
165 | 'metrics': self.metrics
166 | }
167 | filename = os.path.join(self.checkpoint_dir, file_name)
168 | torch.save(state, filename)
169 | if save_best:
170 | shutil.copy(filename, os.path.join(self.checkpoint_dir, 'model_best.pth'))
171 | self.logger.info("Saving current best: {}".format(file_name))
172 | else:
173 | self.logger.info("Saving checkpoint: {}".format(filename))
174 |
175 | def _load_checkpoint(self, checkpoint_path, resume):
176 | """
177 | Resume from saved checkpoints
178 | :param checkpoint_path: Checkpoint path to be resumed
179 | """
180 | self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
181 | checkpoint = torch.load(checkpoint_path)
182 | self.model.load_state_dict(checkpoint['state_dict'])
183 | if resume:
184 | self.global_step = checkpoint['global_step']
185 | self.start_epoch = checkpoint['epoch'] + 1
186 | self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
187 | # self.scheduler.load_state_dict(checkpoint['scheduler'])
188 | self.optimizer.load_state_dict(checkpoint['optimizer'])
189 | if 'metrics' in checkpoint:
190 | self.metrics = checkpoint['metrics']
191 | if self.with_cuda:
192 | for state in self.optimizer.state.values():
193 | for k, v in state.items():
194 | if isinstance(v, torch.Tensor):
195 | state[k] = v.to(self.device)
196 | self.logger.info("resume from checkpoint {} (epoch {})".format(checkpoint_path, self.start_epoch))
197 | else:
198 | self.logger.info("finetune from checkpoint {}".format(checkpoint_path))
199 |
200 | def _initialize(self, name, module, *args, **kwargs):
201 | module_name = self.config[name]['type']
202 | module_args = self.config[name]['args']
203 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
204 | module_args.update(kwargs)
205 | return getattr(module, module_name)(*args, **module_args)
206 |
--------------------------------------------------------------------------------
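The `_initialize` helper at the end of `BaseTrainer` builds the optimizer (and scheduler) reflectively from the config. A minimal sketch of what it resolves to for the default `config.json` (the `nn.Linear` stand-in for the PAN model is an assumption):

```python
import torch
from torch import nn

config = {'optimizer': {'type': 'Adam',
                        'args': {'lr': 0.001, 'weight_decay': 0, 'amsgrad': True}}}
model = nn.Linear(2, 2)  # hypothetical stand-in for the PAN model

# Equivalent of _initialize('optimizer', torch.optim, model.parameters()):
module_name = config['optimizer']['type']   # 'Adam'
module_args = config['optimizer']['args']
optimizer = getattr(torch.optim, module_name)(model.parameters(), **module_args)
print(optimizer)  # Adam with lr=0.001, amsgrad=True
```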
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "PAN",
3 | "data_loader": {
4 | "type": "ImageDataset",
5 | "args": {
6 | "dataset": {
7 | "train_data_path": [
8 | [
9 | "dataset1.txt1",
10 | "dataset1.txt2"
11 | ],
12 | [
13 | "dataset2.txt1",
14 | "dataset2.txt2"
15 | ]
16 | ],
17 | "train_data_ratio": [
18 | 0.5,
19 | 0.5
20 | ],
21 | "val_data_path": "path/to/test/",
22 | "input_size": 640,
23 | "img_channel": 3,
24 | "shrink_ratio": 0.5
25 | },
26 | "loader": {
27 | "validation_split": 0.1,
28 | "train_batch_size": 16,
29 | "shuffle": true,
30 | "pin_memory": false,
31 | "num_workers": 6
32 | }
33 | }
34 | },
35 | "arch": {
36 | "type": "PANModel",
37 | "args": {
38 | "backbone": "resnet18",
39 | "fpem_repeat": 2,
40 | "pretrained": true,
41 | "segmentation_head": "FPEM_FFM"
42 | }
43 | },
44 | "loss": {
45 | "type": "PANLoss",
46 | "args": {
47 | "alpha": 0.5,
48 | "beta": 0.25,
49 | "delta_agg": 0.5,
50 | "delta_dis": 3,
51 | "ohem_ratio": 3
52 | }
53 | },
54 | "optimizer": {
55 | "type": "Adam",
56 | "args": {
57 | "lr": 0.001,
58 | "weight_decay": 0,
59 | "amsgrad": true
60 | }
61 | },
62 | "lr_scheduler": {
63 | "type": "StepLR",
64 | "args": {
65 | "step_size": 200,
66 | "gamma": 0.1
67 | }
68 | },
69 | "trainer": {
70 | "seed": 2,
71 | "gpus": [
72 | 0
73 | ],
74 | "epochs": 600,
75 | "display_interval": 10,
76 | "show_images_interval": 50,
77 | "resume_checkpoint": "",
78 | "finetune_checkpoint": "",
79 | "output_dir": "output",
80 | "tensorboard": true,
81 | "metrics": "hmean"
82 | }
83 | }
--------------------------------------------------------------------------------
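`train.py` consumes this file as a plain dict; loading it needs only the standard library (a minimal sketch, nothing repo-specific assumed):

```python
import json

with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

print(config['arch']['type'])             # PANModel
print(config['optimizer']['args']['lr'])  # 0.001
```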
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:50
3 | # @Author : zhoujun
--------------------------------------------------------------------------------
/config/default.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:51
3 | # @Author : zhoujun
4 |
5 | name = 'PAN'
6 | arch = {
7 | "type": "PANModel", # name of model architecture to train
8 | "args": {
9 | 'backbone': 'resnet18',
10 | 'fpem_repeat': 2, # number of times the FPEM module is repeated
11 | 'pretrained': True, # whether the backbone uses ImageNet pretrained weights
12 | 'segmentation_head': 'FPN' # segmentation head, FPN or FPEM_FFM
13 | }
14 | }
15 |
16 |
17 | data_loader = {
18 | "type": "ImageDataset", # selecting data loader
19 | "args": {
20 | 'dataset': {
21 | 'train_data_path': [['dataset1.txt1', 'dataset1.txt2'], ['dataset2.txt1', 'dataset2.txt2']],
22 | 'train_data_ratio': [0.5, 0.5],
23 | 'val_data_path': ['path/to/test/'],
24 | 'input_size': 640,
25 | 'img_channel': 3,
26 | 'shrink_ratio': 0.5 # shrink ratio of the gt kernel
27 | },
28 | 'loader': {
29 | 'validation_split': 0.1,
30 | 'train_batch_size': 16,
31 | 'val_batch_size': 4,
32 | 'shuffle': True,
33 | 'pin_memory': False,
34 | 'num_workers': 6
35 | }
36 | }
37 | }
38 | loss = {
39 | "type": "PANLoss", # name of model architecture to train
40 | "args": {
41 | 'alpha': 0.5,
42 | 'beta': 0.25,
43 | 'delta_agg': 0.5,
44 | 'delta_dis': 3,
45 | 'ohem_ratio': 3
46 | }
47 | }
48 |
49 | optimizer = {
50 | "type": "Adam",
51 | "args": {
52 | "lr": 0.001,
53 | "weight_decay": 0,
54 | "amsgrad": True
55 | }
56 | }
57 |
58 | lr_scheduler = {
59 | "type": "StepLR",
60 | "args": {
61 | "step_size": 200,
62 | "gamma": 0.1
63 | }
64 | }
65 |
66 | resume = {
67 | 'restart_training': True,
68 | 'checkpoint': ''
69 | }
70 |
71 | trainer = {
72 | # random seed
73 | 'seed': 2,
74 | 'gpus': [0],
75 | 'epochs': 600,
76 | 'display_interval': 10,
77 | 'show_images_interval': 50,
78 | 'resume': resume,
79 | 'output_dir': 'output',
80 | 'tensorboard': True
81 | }
82 |
83 | config_dict = {}
84 | config_dict['name'] = name
85 | config_dict['data_loader'] = data_loader
86 | config_dict['arch'] = arch
87 | config_dict['loss'] = loss
88 | config_dict['optimizer'] = optimizer
89 | config_dict['lr_scheduler'] = lr_scheduler
90 | config_dict['trainer'] = trainer
91 |
92 | from utils import save_json
93 |
94 | save_json(config_dict, '../config.json')
95 |
--------------------------------------------------------------------------------
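`save_json` is imported from `utils`; a plausible minimal implementation is sketched below (an assumption for illustration; the repo's own version lives under `utils/`):

```python
import json

def save_json(data: dict, json_path: str):
    # Serialize the config dict with readable indentation.
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
```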
/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:52
3 | # @Author : zhoujun
4 |
5 | from torch.utils.data import DataLoader
6 | from torchvision import transforms
7 | import copy
8 | import pathlib
9 | from . import dataset
10 |
11 |
12 | def get_datalist(train_data_path, validation_split=0.1):
13 | """
14 | Get the training data lists
15 | :param train_data_path: list of dataset file groups for training; each file stores lines in the format 'path/to/img\tlabel'
16 | :param validation_split: fraction used for validation when val_data_path is empty
17 | :return:
18 | """
19 | train_data_list = []
20 | for train_path in train_data_path:
21 | train_data = []
22 | for p in train_path:
23 | with open(p, 'r', encoding='utf-8') as f:
24 | for line in f.readlines():
25 | line = line.strip('\n').replace('.jpg ', '.jpg\t').split('\t')
26 | if len(line) > 1:
27 | img_path = pathlib.Path(line[0].strip(' '))
28 | label_path = pathlib.Path(line[1].strip(' '))
29 | if img_path.exists() and img_path.stat().st_size > 0 and label_path.exists() and label_path.stat().st_size > 0:
30 | train_data.append((str(img_path), str(label_path)))
31 | train_data_list.append(train_data)
32 | return train_data_list
33 |
34 |
35 | def get_dataset(data_list, module_name, transform, dataset_args):
36 | """
37 | Get the training dataset
38 | :param data_list: list of (img_path, label_path) pairs
39 | :param module_name: name of the custom dataset class to use; currently only data_loader.ImageDataset is supported
40 | :param transform: transforms applied to this dataset
41 | :param dataset_args: arguments for module_name
42 | :return: the constructed dataset object
43 | """
44 | s_dataset = getattr(dataset, module_name)(transform=transform, data_list=data_list,
45 | **dataset_args)
46 | return s_dataset
47 |
48 |
49 | def get_dataloader(module_name, module_args):
50 | train_transforms = transforms.Compose([
51 | transforms.ColorJitter(brightness=0.5),
52 | transforms.ToTensor()
53 | ])
54 |
55 | # build the datasets
56 | dataset_args = copy.deepcopy(module_args['dataset'])
57 | train_data_path = dataset_args.pop('train_data_path')
58 | train_data_ratio = dataset_args.pop('train_data_ratio')
59 | dataset_args.pop('val_data_path')
60 | train_data_list = get_datalist(train_data_path, module_args['loader']['validation_split'])
61 | train_dataset_list = []
62 | for train_data in train_data_list:
63 | train_dataset_list.append(get_dataset(data_list=train_data,
64 | module_name=module_name,
65 | transform=train_transforms,
66 | dataset_args=dataset_args))
67 |
68 | if len(train_dataset_list) > 1:
69 | train_loader = dataset.Batch_Balanced_Dataset(dataset_list=train_dataset_list,
70 | ratio_list=train_data_ratio,
71 | module_args=module_args,
72 | phase='train')
73 | elif len(train_dataset_list) == 1:
74 | train_loader = DataLoader(dataset=train_dataset_list[0],
75 | batch_size=module_args['loader']['train_batch_size'],
76 | shuffle=module_args['loader']['shuffle'],
77 | num_workers=module_args['loader']['num_workers'])
78 | train_loader.dataset_len = len(train_dataset_list[0])
79 | else:
80 | raise Exception('no images found')
81 | return train_loader
82 |
--------------------------------------------------------------------------------
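Wiring `get_dataloader` to the `data_loader` section of `config.json` looks like this (a sketch; the placeholder list files must exist and follow the '\t'-separated format described in the README):

```python
from data_loader import get_dataloader

module_args = {
    'dataset': {
        'train_data_path': [['dataset1.txt1', 'dataset1.txt2']],
        'train_data_ratio': [1.0],
        'val_data_path': 'path/to/test/',
        'input_size': 640,
        'img_channel': 3,
        'shrink_ratio': 0.5,
    },
    'loader': {'validation_split': 0.1, 'train_batch_size': 16,
               'shuffle': True, 'pin_memory': False, 'num_workers': 6},
}
# With a single dataset group this returns a plain DataLoader;
# with several groups it returns a Batch_Balanced_Dataset.
train_loader = get_dataloader('ImageDataset', module_args)
```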
/data_loader/augment.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:52
3 | # @Author : zhoujun
4 |
5 | import cv2
6 | import numbers
7 | import math
8 | import random
9 | import numpy as np
10 | from skimage.util import random_noise
11 |
12 |
13 | def show_pic(img, bboxes=None, name='pic'):
14 | '''
15 | 输入:
16 | img:图像array
17 | bboxes:图像的所有boudning box list, 格式为[[x_min, y_min, x_max, y_max]....]
18 | names:每个box对应的名称
19 | '''
20 | show_img = img.copy()
21 | if not isinstance(bboxes, np.ndarray):
22 | bboxes = np.array(bboxes)
23 | for point in bboxes.astype(np.int32):
24 | cv2.line(show_img, tuple(point[0]), tuple(point[1]), (255, 0, 0), 2)
25 | cv2.line(show_img, tuple(point[1]), tuple(point[2]), (255, 0, 0), 2)
26 | cv2.line(show_img, tuple(point[2]), tuple(point[3]), (255, 0, 0), 2)
27 | cv2.line(show_img, tuple(point[3]), tuple(point[0]), (255, 0, 0), 2)
28 | # cv2.namedWindow(name, 0) # 1 keeps the original size
29 | # cv2.moveWindow(name, 0, 0)
30 | # cv2.resizeWindow(name, 1200, 800) # display size of the window
31 | cv2.imshow(name, show_img)
32 |
33 |
34 | # all images are read with cv2
35 | class DataAugment():
36 | def __init__(self):
37 | pass
38 |
39 | def add_noise(self, im: np.ndarray):
40 | """
41 | Add Gaussian noise to the image
42 | :param im: image array
43 | :return: noisy image array; random_noise outputs pixels in [0, 1], so the result is scaled back by 255
44 | """
45 | return (random_noise(im, mode='gaussian', clip=True) * 255).astype(im.dtype)
46 |
47 | def random_scale(self, im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray or list) -> tuple:
48 | """
49 | Randomly pick a scale from scales and resize the image and text boxes
50 | :param im: original image
51 | :param text_polys: text boxes
52 | :param scales: candidate scales
53 | :return: scaled image and text boxes
54 | """
55 | tmp_text_polys = text_polys.copy()
56 | rd_scale = float(np.random.choice(scales))
57 | im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
58 | tmp_text_polys *= rd_scale
59 | return im, tmp_text_polys
60 |
61 | def random_rotate_img_bbox(self, img, text_polys, degrees: numbers.Number or list or tuple or np.ndarray,
62 | same_size=False):
63 | """
64 | Pick an angle from the given range and rotate the image and text boxes
65 | :param img: image
66 | :param text_polys: text boxes
67 | :param degrees: angle, a single number or a sequence
68 | :param same_size: whether to keep the original image size
69 | :return: rotated image and text boxes
70 | """
71 | if isinstance(degrees, numbers.Number):
72 | if degrees < 0:
73 | raise ValueError("If degrees is a single number, it must be positive.")
74 | degrees = (-degrees, degrees)
75 | elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
76 | if len(degrees) != 2:
77 | raise ValueError("If degrees is a sequence, it must be of len 2.")
78 | degrees = degrees
79 | else:
80 | raise Exception('degrees must be a Number, list, tuple or np.ndarray')
81 | # ---------------------- rotate the image ----------------------
82 | w = img.shape[1]
83 | h = img.shape[0]
84 | angle = np.random.uniform(degrees[0], degrees[1])
85 |
86 | if same_size:
87 | nw = w
88 | nh = h
89 | else:
90 | # degrees to radians
91 | rangle = np.deg2rad(angle)
92 | # compute the w, h of the rotated image
93 | nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
94 | nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
95 | # build the affine matrix
96 | rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
97 | # offset from the original image center to the new image center
98 | rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
99 | # update the affine matrix
100 | rot_mat[0, 2] += rot_move[0]
101 | rot_mat[1, 2] += rot_move[1]
102 | # apply the affine transform
103 | rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
104 |
105 | # ---------------------- correct the bbox coordinates ----------------------
106 | # rot_mat is the final rotation matrix
107 | # map the four corner points of each original bbox into the rotated coordinate system
108 | rot_text_polys = list()
109 | for bbox in text_polys:
110 | point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
111 | point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
112 | point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
113 | point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
114 | rot_text_polys.append([point1, point2, point3, point4])
115 | return rot_img, np.array(rot_text_polys, dtype=np.float32)
116 |
117 | def random_crop(self, imgs, img_size):
118 | h, w = imgs[0].shape[0:2]
119 | th, tw = img_size
120 | if w == tw and h == th:
121 | return imgs
122 |
123 | # if the label contains text instances, crop around them with some probability
124 | if np.max(imgs[1][:, :, 0]) > 0 and random.random() > 3.0 / 8.0:
125 | # top-left corner of the region containing text instances
126 | tl = np.min(np.where(imgs[1][:, :, 0] > 0), axis=1) - img_size
127 | tl[tl < 0] = 0
128 | # bottom-right corner of the region containing text instances
129 | br = np.max(np.where(imgs[1][:, :, 0] > 0), axis=1) - img_size
130 | br[br < 0] = 0
131 | # make sure there is enough room to crop when the bottom-right point is chosen
132 | br[0] = min(br[0], h - th)
133 | br[1] = min(br[1], w - tw)
134 | for _ in range(50000):
135 | i = random.randint(tl[0], br[0])
136 | j = random.randint(tl[1], br[1])
137 | # make sure the cropped patch contains text
138 | if imgs[1][:, :, -1][i:i + th, j:j + tw].sum() <= 0:
139 | continue
140 | else:
141 | break
142 | # the (i, j) found by the loop above is kept; re-sampling here
143 | # would discard the crop that was verified to contain text
144 | else:
145 | i = random.randint(0, h - th)
146 | j = random.randint(0, w - tw)
147 |
148 | # return i, j, th, tw
149 | for idx in range(len(imgs)):
150 | if len(imgs[idx].shape) == 3:
151 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
152 | else:
153 | imgs[idx] = imgs[idx][i:i + th, j:j + tw]
154 | return imgs
155 |
156 | def resize(self, im: np.ndarray, text_polys: np.ndarray,
157 | input_size: numbers.Number or list or tuple or np.ndarray, keep_ratio: bool = False) -> tuple:
158 | """
159 | Resize the image and text boxes
160 | :param im: image
161 | :param text_polys: text boxes
162 | :param input_size: target size, a number or a [w, h] list/tuple
163 | :param keep_ratio: whether to keep the aspect ratio
164 | :return: resized image and text boxes
165 | """
166 | if isinstance(input_size, numbers.Number):
167 | if input_size < 0:
168 | raise ValueError("If input_size is a single number, it must be positive.")
169 | input_size = (input_size, input_size)
170 | elif isinstance(input_size, list) or isinstance(input_size, tuple) or isinstance(input_size, np.ndarray):
171 | if len(input_size) != 2:
172 | raise ValueError("If input_size is a sequence, it must be of len 2.")
173 | input_size = (input_size[0], input_size[1])
174 | else:
175 | raise Exception('input_size must be a Number, list, tuple or np.ndarray')
176 | if keep_ratio:
177 | # pad the short side of the image to match the long side
178 | h, w, c = im.shape
179 | max_h = max(h, input_size[0])
180 | max_w = max(w, input_size[1])
181 | im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
182 | im_padded[:h, :w] = im.copy()
183 | im = im_padded
184 | text_polys = text_polys.astype(np.float32)
185 | h, w, _ = im.shape
186 | im = cv2.resize(im, input_size)
187 | w_scale = input_size[0] / float(w)
188 | h_scale = input_size[1] / float(h)
189 | text_polys[:, :, 0] *= w_scale
190 | text_polys[:, :, 1] *= h_scale
191 | return im, text_polys
192 |
193 | def horizontal_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
194 | """
195 | Horizontally flip the image and text boxes
196 | :param im: image
197 | :param text_polys: text boxes
198 | :return: horizontally flipped image and text boxes
199 | """
200 | flip_text_polys = text_polys.copy()
201 | flip_im = cv2.flip(im, 1)
202 | h, w, _ = flip_im.shape
203 | flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
204 | return flip_im, flip_text_polys
205 |
206 | def vertical_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
207 | """
208 | Vertically flip the image and text boxes
209 | :param im: image
210 | :param text_polys: text boxes
211 | :return: vertically flipped image and text boxes
212 | """
213 | flip_text_polys = text_polys.copy()
214 | flip_im = cv2.flip(im, 0)
215 | h, w, _ = flip_im.shape
216 | flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
217 | return flip_im, flip_text_polys
218 |
219 | def test(self, im: np.ndarray, text_polys: np.ndarray):
220 | print('random scale')
221 | t_im, t_text_polys = self.random_scale(im, text_polys, [0.5, 1, 2, 3])
222 | print(t_im.shape, t_text_polys.dtype)
223 | show_pic(t_im, t_text_polys, 'random_scale')
224 |
225 | print('random rotate')
226 | t_im, t_text_polys = self.random_rotate_img_bbox(im, text_polys, 10)
227 | print(t_im.shape, t_text_polys.dtype)
228 | show_pic(t_im, t_text_polys, 'random_rotate_img_bbox')
229 |
230 | # the old random_crop_img_bboxes method no longer exists;
231 | # random_crop now operates on [im, score_maps, training_mask]
232 | # lists, so this check is disabled here
233 | # show_pic(t_im, t_text_polys, 'random_crop_img_bboxes')
234 |
235 | print('horizontal flip')
236 | t_im, t_text_polys = self.horizontal_flip(im, text_polys)
237 | print(t_im.shape, t_text_polys.dtype)
238 | show_pic(t_im, t_text_polys, 'horizontal_flip')
239 |
240 | print('vertical flip')
241 | t_im, t_text_polys = self.vertical_flip(im, text_polys)
242 | print(t_im.shape, t_text_polys.dtype)
243 | show_pic(t_im, t_text_polys, 'vertical_flip')
244 | show_pic(im, text_polys, 'vertical_flip_ori')
245 |
246 | print('add noise')
247 | t_im = self.add_noise(im)
248 | print(t_im.shape)
249 | show_pic(t_im, text_polys, 'add_noise')
250 | show_pic(im, text_polys, 'add_noise_ori')
251 |
--------------------------------------------------------------------------------
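A quick sanity check of the flip geometry (a sketch with a synthetic image and box; `DataAugment` is used exactly as defined above):

```python
import numpy as np
from data_loader.augment import DataAugment

aug = DataAugment()
im = np.zeros((100, 200, 3), dtype=np.uint8)  # h=100, w=200
polys = np.array([[[10, 20], [60, 20], [60, 40], [10, 40]]], dtype=np.float32)

flip_im, flip_polys = aug.horizontal_flip(im, polys)
# x coordinates are mirrored around the image width: x -> w - x
print(flip_polys[0, :, 0])  # [190. 140. 140. 190.]
```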
/data_loader/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:53
3 | # @Author : zhoujun
4 | import math
5 | import random
6 | import pyclipper
7 | import numpy as np
8 | import cv2
9 | from data_loader.augment import DataAugment
10 |
11 | data_aug = DataAugment()
12 |
13 |
14 | def check_and_validate_polys(polys, xxx_todo_changeme):
15 | '''
16 | check so that the text poly is in the same direction,
17 | and also filter some invalid polygons
18 | :param polys:
19 | :param tags:
20 | :return:
21 | '''
22 | (h, w) = xxx_todo_changeme
23 | if polys.shape[0] == 0:
24 | return polys
25 | polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) # x coord not max w-1, and not min 0
26 | polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) # y coord not max h-1, and not min 0
27 |
28 | validated_polys = []
29 | for poly in polys:
30 | p_area = cv2.contourArea(poly)
31 | if abs(p_area) < 1:
32 | continue
33 | validated_polys.append(poly)
34 | return np.array(validated_polys)
35 |
36 | def unshrink_offset(poly, ratio):
37 | area = cv2.contourArea(poly)
38 | peri = cv2.arcLength(poly, True)
39 | a = 8
40 | b = peri - 4
41 | c = 1 - 0.5 * peri - area / ratio
42 | return quadratic(a, b, c)
43 |
44 | def quadratic(a, b, c):
45 | if (b * b - 4 * a * c) < 0:
46 | return None
47 | delta = math.sqrt(b * b - 4 * a * c)
48 | if delta > 0:
49 | x = (- b + delta) / (2 * a)
50 | y = (- b - delta) / (2 * a)
51 | return x, y
52 | else:
53 | x = (- b) / (2 * a)
54 | return x
55 |
56 | def generate_rbox(im_size, text_polys, text_tags,training_mask, shrink_ratio):
57 | """
58 | Generate the mask map: white regions are text, black is background
59 | :param im_size: h, w of the image
60 | :param text_polys: box coordinates
61 | :param text_tags: whether each annotated box participates in training
62 | :param training_mask: mask that ignores boxes annotated as DO NOT CARE
63 | :return: the generated mask map and the training mask
64 | """
65 | h, w = im_size
66 | score_map = np.zeros((h, w), dtype=np.uint8)
67 | for i, (poly, tag) in enumerate(zip(text_polys, text_tags)):
68 | try:
69 | poly = poly.astype(np.int32)
70 | # d_i = cv2.contourArea(poly) * (1 - shrink_ratio * shrink_ratio) / cv2.arcLength(poly, True)
71 | d_i = cv2.contourArea(poly) * (1 - shrink_ratio) / cv2.arcLength(poly, True) + 0.5
72 | pco = pyclipper.PyclipperOffset()
73 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
74 | shrinked_poly = np.array(pco.Execute(-d_i))
75 | cv2.fillPoly(score_map, shrinked_poly, i + 1)
76 | if not tag:
77 | cv2.fillPoly(training_mask, shrinked_poly, 0)
78 | except Exception:
79 | print(poly)
80 | return score_map, training_mask
81 |
82 |
83 | def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int) -> tuple:
84 | # the images are rescaled with ratio {0.5, 1.0, 2.0, 3.0} randomly
85 | im, text_polys = data_aug.random_scale(im, text_polys, scales)
86 | # the images are horizontally fliped and rotated in range [−10◦, 10◦] randomly
87 | if random.random() < 0.5:
88 | im, text_polys = data_aug.horizontal_flip(im, text_polys)
89 | if random.random() < 0.5:
90 | im, text_polys = data_aug.random_rotate_img_bbox(im, text_polys, degrees)
91 | return im, text_polys
92 |
93 |
94 | def image_label(im: np.ndarray, text_polys: np.ndarray, text_tags: list, input_size: int = 640,
95 | shrink_ratio: float = 0.5, degrees: int = 10,
96 | scales: np.ndarray = np.array([0.5, 1, 2.0, 3.0])) -> tuple:
97 | """
98 | Read the image and generate the label
99 | :param im: image
100 | :param text_polys: text annotation boxes
101 | :param text_tags: flag marking text to ignore: true = ignore, false = keep
102 | :param input_size: output image size
103 | :param shrink_ratio: shrink ratio of the gt
104 | :param degrees: random rotation range
105 | :param scales: random scaling factors
106 | :return:
107 | """
108 | h, w, _ = im.shape
109 | # check out-of-bounds boxes
110 | text_polys = check_and_validate_polys(text_polys, (h, w))
111 | im, text_polys = augmentation(im, text_polys, scales, degrees)
112 |
113 | h, w, _ = im.shape
114 | short_edge = min(h, w)
115 | if short_edge < input_size:
116 | # ensure the short edge >= input_size
117 | scale = input_size / short_edge
118 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale)
119 | text_polys *= scale
120 |
121 | h, w, _ = im.shape
122 | training_mask = np.ones((h, w), dtype=np.uint8)
123 | score_maps = []
124 | for i in (1, shrink_ratio):
125 | score_map, training_mask = generate_rbox((h, w), text_polys, text_tags,training_mask, i)
126 | score_maps.append(score_map)
127 | score_maps = np.array(score_maps, dtype=np.float32)
128 | imgs = data_aug.random_crop([im, score_maps.transpose((1, 2, 0)), training_mask], (input_size, input_size))
129 | return imgs[0], imgs[1].transpose((2, 0, 1)), imgs[2]  # im, score_maps, training_mask
130 |
131 | if __name__ == '__main__':
132 | poly = np.array([377,117,463,117,465,130,378,130]).reshape(-1,2)
133 | shrink_ratio = 0.5
134 | d_i = cv2.contourArea(poly) * (1 - shrink_ratio) / cv2.arcLength(poly, True) + 0.5
135 | pco = pyclipper.PyclipperOffset()
136 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
137 | shrinked_poly = np.array(pco.Execute(-d_i))
138 | print(d_i)
139 | print(cv2.contourArea(shrinked_poly.astype(int)) / cv2.contourArea(poly))
140 | print(unshrink_offset(shrinked_poly,shrink_ratio))
141 |
--------------------------------------------------------------------------------
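The inward offset `d_i` in `generate_rbox` is the Vatti-clipping shrink used by PSENet: `d_i = Area(p) * (1 - r) / Perimeter(p)` (the commented-out line keeps the paper's `1 - r^2` variant; the trailing `+ 0.5` appears to be rounding compensation). A quick numeric check on the quad from the `__main__` block above:

```python
import cv2
import numpy as np

poly = np.array([377, 117, 463, 117, 465, 130, 378, 130], dtype=np.float32).reshape(-1, 2)
shrink_ratio = 0.5
area = cv2.contourArea(poly)      # ~1124.5
peri = cv2.arcLength(poly, True)  # ~199.2
d_i = area * (1 - shrink_ratio) / peri + 0.5
print(d_i)  # ~3.3 -> the kernel is the quad offset ~3 px inwards
```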
/data_loader/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:54
3 | # @Author : zhoujun
4 | import cv2
5 | import torch  # used by Batch_Balanced_Dataset.__next__ (torch.cat)
6 | import numpy as np
7 | from PIL import Image
8 | from torch.utils.data import Dataset, DataLoader
9 | from data_loader.data_utils import image_label
10 | from utils import order_points_clockwise
11 |
12 | class ImageDataset(Dataset):
13 | def __init__(self, data_list: list, input_size: int, img_channel: int, shrink_ratio: float, transform=None,
14 | target_transform=None):
15 | self.data_list = self.load_data(data_list)
16 | self.input_size = input_size
17 | self.img_channel = img_channel
18 | self.transform = transform
19 | self.target_transform = target_transform
20 | self.shrink_ratio = shrink_ratio
21 |
22 | def __getitem__(self, index):
23 | img_path, text_polys, text_tags = self.data_list[index]
24 | im = cv2.imread(img_path, 1 if self.img_channel == 3 else 0)
25 | if self.img_channel == 3:
26 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
27 | img, score_map, training_mask = image_label(im, text_polys, text_tags, self.input_size,
28 | self.shrink_ratio)
29 | # img = draw_bbox(img,text_polys)
30 | img = Image.fromarray(img)
31 | if self.transform:
32 | img = self.transform(img)
33 | if self.target_transform:
34 | score_map = self.target_transform(score_map)
35 | training_mask = self.target_transform(training_mask)
36 | return img, score_map, training_mask
37 |
38 | def load_data(self, data_list: list) -> list:
39 | t_data_list = []
40 | for img_path, label_path in data_list:
41 | bboxs, text_tags = self._get_annotation(label_path)
42 | if len(bboxs) > 0:
43 | t_data_list.append((img_path, bboxs, text_tags))
44 | else:
45 | print('there is no suitable bbox in {}'.format(label_path))
46 | return t_data_list
47 |
48 | def _get_annotation(self, label_path: str) -> tuple:
49 | boxes = []
50 | text_tags = []
51 | with open(label_path, encoding='utf-8', mode='r') as f:
52 | for line in f.readlines():
53 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',')
54 | try:
55 | box = order_points_clockwise(np.array(list(map(float, params[:8]))).reshape(-1, 2))
56 | if cv2.arcLength(box, True) > 0:
57 | boxes.append(box)
58 | label = params[8]
59 | if label == '*' or label == '###':
60 | text_tags.append(False)
61 | else:
62 | text_tags.append(True)
63 | except Exception:
64 | print('load label failed on {}'.format(label_path))
65 | return np.array(boxes, dtype=np.float32), np.array(text_tags, dtype=bool)
66 |
67 | def __len__(self):
68 | return len(self.data_list)
69 |
70 |
71 | class Batch_Balanced_Dataset(object):
72 | def __init__(self, dataset_list: list, ratio_list: list, module_args: dict,
73 | phase: str = 'train'):
74 | """
75 | Combine the datasets in dataset_list according to the ratios in ratio_list, so that each batch is sampled proportionally
76 | :param dataset_list: list of datasets
77 | :param ratio_list: list of sampling ratios
78 | :param module_args: dataloader configuration
79 | :param phase: train or validation
80 | """
81 | assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list)
82 |
83 | self.dataset_len = 0
84 | self.data_loader_list = []
85 | self.dataloader_iter_list = []
86 | all_batch_size = module_args['loader']['train_batch_size'] if phase == 'train' else module_args['loader'][
87 | 'val_batch_size']
88 | for _dataset, batch_ratio_d in zip(dataset_list, ratio_list):
89 | _batch_size = max(round(all_batch_size * float(batch_ratio_d)), 1)
90 |
91 | _data_loader = DataLoader(dataset=_dataset,
92 | batch_size=_batch_size,
93 | shuffle=module_args['loader']['shuffle'],
94 | num_workers=module_args['loader']['num_workers'])
95 |
96 | self.data_loader_list.append(_data_loader)
97 | self.dataloader_iter_list.append(iter(_data_loader))
98 | self.dataset_len += len(_dataset)
99 |
100 | def __iter__(self):
101 | return self
102 |
103 | def __len__(self):
104 | return min([len(x) for x in self.data_loader_list])
105 |
106 | def __next__(self):
107 | balanced_batch_images = []
108 | balanced_batch_score_maps = []
109 | balanced_batch_training_masks = []
110 |
111 | for i, data_loader_iter in enumerate(self.dataloader_iter_list):
112 | try:
113 | image, score_map, training_mask = next(data_loader_iter)
114 | balanced_batch_images.append(image)
115 | balanced_batch_score_maps.append(score_map)
116 | balanced_batch_training_masks.append(training_mask)
117 | except StopIteration:
118 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
119 | image, score_map, training_mask = next(self.dataloader_iter_list[i])
120 | balanced_batch_images.append(image)
121 | balanced_batch_score_maps.append(score_map)
122 | balanced_batch_training_masks.append(training_mask)
123 | except ValueError:
124 | pass
125 |
126 | balanced_batch_images = torch.cat(balanced_batch_images, 0)
127 | balanced_batch_score_maps = torch.cat(balanced_batch_score_maps, 0)
128 | balanced_batch_training_masks = torch.cat(balanced_batch_training_masks, 0)
129 | return balanced_batch_images, balanced_batch_score_maps, balanced_batch_training_masks
130 |
131 |
132 | if __name__ == '__main__':
133 | import torch
134 | from utils.util import show_img
135 | from tqdm import tqdm
136 | import matplotlib.pyplot as plt
137 | from torchvision import transforms
138 |
139 | train_data = ImageDataset(
140 | data_list=[
141 | (r'/data1/zj/ocr/icdar2015/train/img/img_713.jpg', '/data1/zj/ocr/icdar2015/train/gt/gt_img_713.txt')],
142 | input_size=640,
143 | img_channel=3,
144 | shrink_ratio=0.5,
145 | transform=transforms.ToTensor()
146 | )
147 | train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, num_workers=0)
148 |
149 | pbar = tqdm(total=len(train_loader))
150 | for i, (img, label, mask) in enumerate(train_loader):
151 | print(label.shape, label[0][0].max())
152 | print(img.shape)
153 | print(label[0][-1].sum())
154 | print(mask[0].shape)
155 | # pbar.update(1)
156 | show_img((img[0] * mask[0].to(torch.float)).numpy().transpose(1, 2, 0), color=True)
157 | show_img(label[0])
158 | show_img(mask[0])
159 | plt.show()
160 |
161 | pbar.close()
162 |
--------------------------------------------------------------------------------
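`Batch_Balanced_Dataset` draws a fixed share of every batch from each dataset. A minimal sketch of the sampling behavior with dummy tensor datasets (the `TensorDataset` stand-ins are assumptions; in the repo the inputs are `ImageDataset` instances):

```python
import torch
from torch.utils.data import TensorDataset
from data_loader.dataset import Batch_Balanced_Dataset

def dummy(n):
    # Each item is (image, score_map, training_mask), matching what __next__ unpacks.
    return TensorDataset(torch.randn(n, 3, 640, 640),
                         torch.zeros(n, 2, 640, 640),
                         torch.ones(n, 640, 640))

module_args = {'loader': {'train_batch_size': 16, 'shuffle': True, 'num_workers': 0}}
loader = Batch_Balanced_Dataset(dataset_list=[dummy(100), dummy(100)],
                                ratio_list=[0.5, 0.5],
                                module_args=module_args,
                                phase='train')
images, score_maps, masks = next(iter(loader))
print(images.shape)  # torch.Size([16, 3, 640, 640]) -- 8 samples from each dataset
```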
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/6/11 15:54
3 | # @Author : zhoujun
4 | import os
5 | import cv2
6 | import torch
7 | import shutil
8 | import numpy as np
9 | from tqdm.auto import tqdm
10 | from predict import Pytorch_model
11 | from utils import cal_recall_precison_f1, draw_bbox
12 |
13 | torch.backends.cudnn.benchmark = True
14 |
15 |
16 | def main(model_path, img_folder, save_path, gpu_id):
17 | if os.path.exists(save_path):
18 | shutil.rmtree(save_path, ignore_errors=True)
19 | if not os.path.exists(save_path):
20 | os.makedirs(save_path)
21 | save_img_folder = os.path.join(save_path, 'img')
22 | if not os.path.exists(save_img_folder):
23 | os.makedirs(save_img_folder)
24 | save_txt_folder = os.path.join(save_path, 'result')
25 | if not os.path.exists(save_txt_folder):
26 | os.makedirs(save_txt_folder)
27 | img_paths = [os.path.join(img_folder, x) for x in os.listdir(img_folder)]
28 | model = Pytorch_model(model_path, gpu_id=gpu_id)
29 | total_frame = 0.0
30 | total_time = 0.0
31 | for img_path in tqdm(img_paths):
32 | img_name = os.path.basename(img_path).split('.')[0]
33 | save_name = os.path.join(save_txt_folder, 'res_' + img_name + '.txt')
34 | _, boxes_list, t = model.predict(img_path)
35 | total_frame += 1
36 | total_time += t
37 | img = draw_bbox(img_path, boxes_list, color=(0, 0, 255))
38 | cv2.imwrite(os.path.join(save_img_folder, '{}.jpg'.format(img_name)), img)
39 | np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d')
40 | print('fps:{}'.format(total_frame / total_time))
41 | return save_txt_folder
42 |
43 |
44 | if __name__ == '__main__':
45 | os.environ['CUDA_VISIBLE_DEVICES'] = str('0')
46 | model_path = r'output/PAN_shufflenetv2_FPEM_FFM.pth'
47 | img_path = r'/mnt/e/zj/dataset/icdar2015/test/img'
48 | gt_path = r'/mnt/e/zj/dataset/icdar2015/test/gt'
49 | save_path = './output/result'  # model_path.replace('checkpoint/best_model.pth', 'result/')
50 | gpu_id = 0
51 |
52 | save_path = main(model_path, img_path, save_path, gpu_id=gpu_id)
53 | result = cal_recall_precison_f1(gt_path=gt_path, result_path=save_path)
54 | print(result)
55 |
--------------------------------------------------------------------------------
/imgs/example/img_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_10.jpg
--------------------------------------------------------------------------------
/imgs/example/img_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_2.jpg
--------------------------------------------------------------------------------
/imgs/example/img_29.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_29.jpg
--------------------------------------------------------------------------------
/imgs/example/img_75.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_75.jpg
--------------------------------------------------------------------------------
/imgs/example/img_91.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/example/img_91.jpg
--------------------------------------------------------------------------------
/imgs/paper/PAN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PAN.pytorch/517e9eec3eeb629a9f346f2a80599b0e01e653ff/imgs/paper/PAN.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:55
3 | # @Author : zhoujun
4 | from .model import Model
5 | from .loss import PANLoss
6 |
7 |
8 | def get_model(config):
9 | model_config = config['arch']['args']
10 | return Model(model_config)
11 |
12 | def get_loss(config):
13 | alpha = config['loss']['args']['alpha']
14 | beta = config['loss']['args']['beta']
15 | delta_agg = config['loss']['args']['delta_agg']
16 | delta_dis = config['loss']['args']['delta_dis']
17 | ohem_ratio = config['loss']['args']['ohem_ratio']
18 | return PANLoss(alpha=alpha, beta=beta, delta_agg=delta_agg, delta_dis=delta_dis, ohem_ratio=ohem_ratio)
19 |
--------------------------------------------------------------------------------
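Both factories read from the parsed config.json. A minimal sketch of the nested dict they expect (values are illustrative, mirroring defaults found elsewhere in this repo):

    config = {
        'arch': {'args': {'backbone': 'shufflenetv2',
                          'fpem_repeat': 2,
                          'pretrained': True,
                          'segmentation_head': 'FPEM_FFM'}},
        'loss': {'args': {'alpha': 0.5, 'beta': 0.25, 'delta_agg': 0.5,
                          'delta_dis': 3, 'ohem_ratio': 3}},
    }
    model = get_model(config)     # Model(config['arch']['args'])
    criterion = get_loss(config)  # PANLoss(alpha=0.5, beta=0.25, ...)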
/models/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:56
3 | # @Author : zhoujun
4 | import itertools
5 | import torch
6 | from torch import nn
7 | import numpy as np
8 |
9 |
10 | class PANLoss(nn.Module):
11 | def __init__(self, alpha=0.5, beta=0.25, delta_agg=0.5, delta_dis=3, ohem_ratio=3, reduction='mean'):
12 | """
13 | Implement PSE Loss.
14 | :param alpha: loss kernel 前面的系数
15 | :param beta: loss agg 和 loss dis 前面的系数
16 | :param delta_agg: 计算loss agg时的常量
17 | :param delta_dis: 计算loss dis时的常量
18 | :param ohem_ratio: OHEM的比例
19 | :param reduction: 'mean' or 'sum'对 batch里的loss 算均值或求和
20 | """
21 | super().__init__()
22 | assert reduction in ['mean', 'sum'], "reduction must be 'mean' or 'sum'"
23 | self.alpha = alpha
24 | self.beta = beta
25 | self.delta_agg = delta_agg
26 | self.delta_dis = delta_dis
27 | self.ohem_ratio = ohem_ratio
28 | self.reduction = reduction
29 |
30 | def forward(self, outputs, labels, training_masks):
31 | texts = outputs[:, 0, :, :]
32 | kernels = outputs[:, 1, :, :]
33 | gt_texts = labels[:, 0, :, :]
34 | gt_kernels = labels[:, 1, :, :]
35 |
36 |
37 | # compute the aggregation (agg) and discrimination (dis) losses
38 | similarity_vectors = outputs[:, 2:, :, :]
39 | loss_aggs, loss_diss = self.agg_dis_loss(texts, kernels, gt_texts, gt_kernels, similarity_vectors)
40 |
41 | # compute the text region loss
42 | selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
43 | selected_masks = selected_masks.to(outputs.device)
44 |
45 | loss_texts = self.dice_loss(texts, gt_texts, selected_masks)
46 |
47 | # compute the kernel loss
48 | # selected_masks = ((gt_texts > 0.5) & (training_masks > 0.5)).float()
49 | mask0 = torch.sigmoid(texts).detach().cpu().numpy()
50 | mask1 = training_masks.data.cpu().numpy()
51 | selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
52 | selected_masks = torch.from_numpy(selected_masks).float().to(texts.device)
53 | loss_kernels = self.dice_loss(kernels, gt_kernels, selected_masks)
54 |
55 | # mean or sum
56 | if self.reduction == 'mean':
57 | loss_text = loss_texts.mean()
58 | loss_kernel = loss_kernels.mean()
59 | loss_agg = loss_aggs.mean()
60 | loss_dis = loss_diss.mean()
61 | elif self.reduction == 'sum':
62 | loss_text = loss_texts.sum()
63 | loss_kernel = loss_kernels.sum()
64 | loss_agg = loss_aggs.sum()
65 | loss_dis = loss_diss.sum()
66 | else:
67 | raise NotImplementedError
68 |
69 | loss_all = loss_text + self.alpha * loss_kernel + self.beta * (loss_agg + loss_dis)
70 | return loss_all, loss_text, loss_kernel, loss_agg, loss_dis
71 |
72 | def agg_dis_loss(self, texts, kernels, gt_texts, gt_kernels, similarity_vectors):
73 | """
74 | 计算 loss agg
75 | :param texts: 文本实例的分割结果 batch_size * (w*h)
76 | :param kernels: 缩小的文本实例的分割结果 batch_size * (w*h)
77 | :param gt_texts: 文本实例的gt batch_size * (w*h)
78 | :param gt_kernels: 缩小的文本实例的gt batch_size*(w*h)
79 | :param similarity_vectors: 相似度向量的分割结果 batch_size * 4 *(w*h)
80 | :return:
81 | """
82 | batch_size = texts.size()[0]
83 | texts = texts.contiguous().reshape(batch_size, -1)
84 | kernels = kernels.contiguous().reshape(batch_size, -1)
85 | gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
86 | gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)
87 | similarity_vectors = similarity_vectors.contiguous().view(batch_size, 4, -1)
88 | loss_aggs = []
89 | loss_diss = []
90 | for text_i, kernel_i, gt_text_i, gt_kernel_i, similarity_vector in zip(texts, kernels, gt_texts, gt_kernels,
91 | similarity_vectors):
92 | text_num = gt_text_i.max().item() + 1
93 | loss_agg_single_sample = []
94 | G_kernel_list = []  # cache each computed G_Ki for the discrimination loss
95 | # aggregation loss for every text instance
96 | for text_idx in range(1, int(text_num)):
97 | # compute D(p, K_i)
98 | single_kernel_mask = gt_kernel_i == text_idx
99 | if single_kernel_mask.sum() == 0 or (gt_text_i == text_idx).sum() == 0:
100 | # this text instance was cropped out of the image
101 | continue
102 | # G_Ki, shape: 4
103 | G_kernel = similarity_vector[:, single_kernel_mask].mean(1) # 4
104 | G_kernel_list.append(G_kernel)
105 | # matrix of the text pixels' similarity vectors F(p), shape: 4 * num (number of text pixels)
106 | text_similarity_vector = similarity_vector[:, gt_text_i == text_idx]
107 | # ||F(p) - G(K_i)|| - delta_agg, shape: nums
108 | text_G_ki = (text_similarity_vector - G_kernel.reshape(4, 1)).norm(2, dim=0) - self.delta_agg
109 | # D(p,K_i), shape: nums
110 | D_text_kernel = torch.max(text_G_ki, torch.tensor(0, device=text_G_ki.device, dtype=torch.float)).pow(2)
111 | # loss of a single text instance
112 | loss_agg_single_text = torch.log(D_text_kernel + 1).mean()
113 | loss_agg_single_sample.append(loss_agg_single_text)
114 | if len(loss_agg_single_sample) > 0:
115 | loss_agg_single_sample = torch.stack(loss_agg_single_sample).mean()
116 | else:
117 | loss_agg_single_sample = torch.tensor(0, device=texts.device, dtype=torch.float)
118 | loss_aggs.append(loss_agg_single_sample)
119 |
120 | # discrimination loss over all kernel pairs of this sample
121 | loss_dis_single_sample = 0
122 | for G_kernel_i, G_kernel_j in itertools.combinations(G_kernel_list, 2):
123 | # delta_dis - ||G(K_i) - G(K_j)||
124 | kernel_ij = self.delta_dis - (G_kernel_i - G_kernel_j).norm(2)
125 | # D(K_i,K_j)
126 | D_kernel_ij = torch.max(kernel_ij, torch.tensor(0, device=kernel_ij.device, dtype=torch.float)).pow(2)
127 | loss_dis_single_sample += torch.log(D_kernel_ij + 1)
128 | if len(G_kernel_list) > 1:
129 | loss_dis_single_sample /= (len(G_kernel_list) * (len(G_kernel_list) - 1))
130 | else:
131 | loss_dis_single_sample = torch.tensor(0, device=texts.device, dtype=torch.float)
132 | loss_diss.append(loss_dis_single_sample)
133 | return torch.stack(loss_aggs), torch.stack(loss_diss)
134 |
135 | def dice_loss(self, input, target, mask):
136 | input = torch.sigmoid(input)
137 | # binarize the target without mutating the caller's label tensor in place
138 | target = (target > 0.5).float()
139 | input = input.contiguous().view(input.size()[0], -1)
140 | target = target.contiguous().view(target.size()[0], -1)
141 | mask = mask.contiguous().view(mask.size()[0], -1)
142 |
143 | input = input * mask
144 | target = target * mask
145 |
146 | a = torch.sum(input * target, 1)
147 | b = torch.sum(input * input, 1) + 0.001
148 | c = torch.sum(target * target, 1) + 0.001
149 | d = (2 * a) / (b + c)
150 | return 1 - d
151 |
152 | def ohem_single(self, score, gt_text, training_mask):
153 | pos_num = int(np.sum(gt_text > 0.5)) - int(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
154 |
155 | if pos_num == 0:
156 | # selected_mask = gt_text.copy() * 0  # may not be a good choice
157 | selected_mask = training_mask
158 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
159 | return selected_mask
160 |
161 | neg_num = int(np.sum(gt_text <= 0.5))
162 | neg_num = int(min(pos_num * self.ohem_ratio, neg_num))
163 |
164 | if neg_num == 0:
165 | selected_mask = training_mask
166 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
167 | return selected_mask
168 |
169 | neg_score = score[gt_text <= 0.5]
170 | neg_score_sorted = np.sort(-neg_score)
171 | threshold = -neg_score_sorted[neg_num - 1]
172 | selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
173 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
174 | return selected_mask
175 |
176 | def ohem_batch(self, scores, gt_texts, training_masks):
177 | scores = scores.data.cpu().numpy()
178 | gt_texts = gt_texts.data.cpu().numpy()
179 | training_masks = training_masks.data.cpu().numpy()
180 |
181 | selected_masks = []
182 | for i in range(scores.shape[0]):
183 | selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :]))
184 |
185 | selected_masks = np.concatenate(selected_masks, 0)
186 | selected_masks = torch.from_numpy(selected_masks).float()
187 |
188 | return selected_masks
189 |
--------------------------------------------------------------------------------
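For reference, the quantities computed above are a transcription of the PAN objective; with G(K_i) the mean similarity vector over kernel i, F(p) the similarity vector at text pixel p, and N the number of text instances:

    L = L_{text} + \alpha L_{ker} + \beta (L_{agg} + L_{dis})

    L_{agg} = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{|T_i|} \sum_{p \in T_i}
              \ln\Big(1 + \max\big(\lVert F(p) - G(K_i) \rVert_2 - \delta_{agg},\ 0\big)^2\Big)

    L_{dis} = \frac{1}{N(N-1)} \sum_{i \neq j}
              \ln\Big(1 + \max\big(\delta_{dis} - \lVert G(K_i) - G(K_j) \rVert_2,\ 0\big)^2\Big)

L_text and L_ker are the dice losses with OHEM masking and text-region masking respectively; note the code iterates over unordered kernel pairs while dividing by N(N-1), so its L_dis comes out at half the symmetric-sum value above.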
/models/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:57
3 | # @Author : zhoujun
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from models.modules import *
9 |
10 | backbone_dict = {'resnet18': {'models': resnet18, 'out': [64, 128, 256, 512]},
11 | 'resnet34': {'models': resnet34, 'out': [64, 128, 256, 512]},
12 | 'resnet50': {'models': resnet50, 'out': [256, 512, 1024, 2048]},
13 | 'resnet101': {'models': resnet101, 'out': [256, 512, 1024, 2048]},
14 | 'resnet152': {'models': resnet152, 'out': [256, 512, 1024, 2048]},
15 | 'resnext50_32x4d': {'models': resnext50_32x4d, 'out': [256, 512, 1024, 2048]},
16 | 'resnext101_32x8d': {'models': resnext101_32x8d, 'out': [256, 512, 1024, 2048]},
17 | 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}
18 | }
19 |
20 | segmentation_head_dict = {'FPN': FPN, 'FPEM_FFM': FPEM_FFM}
21 |
22 |
23 | # 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
24 | # 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
25 | # 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}}
26 |
27 |
28 | class Model(nn.Module):
29 | def __init__(self, model_config: dict):
30 | """
31 | PANnet
32 | :param model_config: model configuration dict
33 | """
34 | super().__init__()
35 | backbone = model_config['backbone']
36 | pretrained = model_config['pretrained']
37 | segmentation_head = model_config['segmentation_head']
38 |
39 | assert backbone in backbone_dict, 'backbone must be one of: {}'.format(list(backbone_dict))
40 | assert segmentation_head in segmentation_head_dict, 'segmentation_head must be one of: {}'.format(
41 | list(segmentation_head_dict))
42 |
43 | backbone_model, backbone_out = backbone_dict[backbone]['models'], backbone_dict[backbone]['out']
44 | self.backbone = backbone_model(pretrained=pretrained)
45 | self.segmentation_head = segmentation_head_dict[segmentation_head](backbone_out, **model_config)
46 | self.name = '{}_{}'.format(backbone, segmentation_head)
47 |
48 | def forward(self, x):
49 | _, _, H, W = x.size()
50 | backbone_out = self.backbone(x)
51 | segmentation_head_out = self.segmentation_head(backbone_out)
52 | y = F.interpolate(segmentation_head_out, size=(H, W), mode='bilinear', align_corners=True)
53 | return y
54 |
55 |
56 | if __name__ == '__main__':
57 | device = torch.device('cpu')
58 | x = torch.zeros(1, 3, 640, 640).to(device)
59 |
60 | model_config = {
61 | 'backbone': 'shufflenetv2',
62 | 'fpem_repeat': 4,  # number of times the FPEM module is stacked
63 | 'pretrained': True,  # whether the backbone loads ImageNet-pretrained weights
64 | 'result_num': 7,
65 | 'segmentation_head': 'FPEM_FFM'  # segmentation head, FPN or FPEM_FFM
66 | }
67 | model = Model(model_config=model_config).to(device)
68 | y = model(x)
69 | print(y.shape)
70 | # print(model)
71 | # torch.save(model.state_dict(), 'PAN.pth')
72 |
--------------------------------------------------------------------------------
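With the FPEM_FFM head the network emits 6 channels at input resolution; the slicing below mirrors how PANLoss.forward and post_processing.decode consume the output (a sketch, not part of the repo):

    import torch

    y = torch.zeros(1, 6, 640, 640)      # stand-in for model(x)
    text_map = torch.sigmoid(y[:, 0])    # channel 0: text region confidence
    kernel_map = torch.sigmoid(y[:, 1])  # channel 1: shrunk kernel confidence
    similarity = y[:, 2:]                # channels 2-5: 4-dim similarity vectors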
/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:54
3 | # @Author : zhoujun
4 | from .resnet import *
5 | from .shufflenetv2 import *
6 | from .segmentation_head import FPEM_FFM, FPN
--------------------------------------------------------------------------------
/models/modules/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/8/23 21:55
3 | # @Author : zhoujun
4 | import torch.nn as nn
5 | from torchvision.models.utils import load_state_dict_from_url
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
17 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
18 | }
19 |
20 |
21 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
22 | """3x3 convolution with padding"""
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
24 | padding=dilation, groups=groups, bias=False, dilation=dilation)
25 |
26 |
27 | def conv1x1(in_planes, out_planes, stride=1):
28 | """1x1 convolution"""
29 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | expansion = 1
34 |
35 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
36 | base_width=64, dilation=1, norm_layer=None):
37 | super(BasicBlock, self).__init__()
38 | if norm_layer is None:
39 | norm_layer = nn.BatchNorm2d
40 | if groups != 1 or base_width != 64:
41 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
42 | if dilation > 1:
43 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
44 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
45 | self.conv1 = conv3x3(inplanes, planes, stride)
46 | self.bn1 = norm_layer(planes)
47 | self.relu = nn.ReLU(inplace=True)
48 | self.conv2 = conv3x3(planes, planes)
49 | self.bn2 = norm_layer(planes)
50 | self.downsample = downsample
51 | self.stride = stride
52 |
53 | def forward(self, x):
54 | identity = x
55 |
56 | out = self.conv1(x)
57 | out = self.bn1(out)
58 | out = self.relu(out)
59 |
60 | out = self.conv2(out)
61 | out = self.bn2(out)
62 |
63 | if self.downsample is not None:
64 | identity = self.downsample(x)
65 |
66 | out += identity
67 | out = self.relu(out)
68 |
69 | return out
70 |
71 |
72 | class Bottleneck(nn.Module):
73 | expansion = 4
74 |
75 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
76 | base_width=64, dilation=1, norm_layer=None):
77 | super(Bottleneck, self).__init__()
78 | if norm_layer is None:
79 | norm_layer = nn.BatchNorm2d
80 | width = int(planes * (base_width / 64.)) * groups
81 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
82 | self.conv1 = conv1x1(inplanes, width)
83 | self.bn1 = norm_layer(width)
84 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
85 | self.bn2 = norm_layer(width)
86 | self.conv3 = conv1x1(width, planes * self.expansion)
87 | self.bn3 = norm_layer(planes * self.expansion)
88 | self.relu = nn.ReLU(inplace=True)
89 | self.downsample = downsample
90 | self.stride = stride
91 |
92 | def forward(self, x):
93 | identity = x
94 |
95 | out = self.conv1(x)
96 | out = self.bn1(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv2(out)
100 | out = self.bn2(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv3(out)
104 | out = self.bn3(out)
105 |
106 | if self.downsample is not None:
107 | identity = self.downsample(x)
108 |
109 | out += identity
110 | out = self.relu(out)
111 |
112 | return out
113 |
114 |
115 | class ResNet(nn.Module):
116 |
117 | def __init__(self, block, layers, zero_init_residual=False,
118 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
119 | norm_layer=None):
120 | super(ResNet, self).__init__()
121 | if norm_layer is None:
122 | norm_layer = nn.BatchNorm2d
123 | self._norm_layer = norm_layer
124 |
125 | self.inplanes = 64
126 | self.dilation = 1
127 | if replace_stride_with_dilation is None:
128 | # each element in the tuple indicates if we should replace
129 | # the 2x2 stride with a dilated convolution instead
130 | replace_stride_with_dilation = [False, False, False]
131 | if len(replace_stride_with_dilation) != 3:
132 | raise ValueError("replace_stride_with_dilation should be None "
133 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
134 | self.groups = groups
135 | self.base_width = width_per_group
136 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
137 | bias=False)
138 | self.bn1 = norm_layer(self.inplanes)
139 | self.relu = nn.ReLU(inplace=True)
140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
141 | self.layer1 = self._make_layer(block, 64, layers[0])
142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
143 | dilate=replace_stride_with_dilation[0])
144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
145 | dilate=replace_stride_with_dilation[1])
146 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
147 | dilate=replace_stride_with_dilation[2])
148 |
149 | for m in self.modules():
150 | if isinstance(m, nn.Conv2d):
151 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
152 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
153 | nn.init.constant_(m.weight, 1)
154 | nn.init.constant_(m.bias, 0)
155 |
156 | # Zero-initialize the last BN in each residual branch,
157 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
158 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
159 | if zero_init_residual:
160 | for m in self.modules():
161 | if isinstance(m, Bottleneck):
162 | nn.init.constant_(m.bn3.weight, 0)
163 | elif isinstance(m, BasicBlock):
164 | nn.init.constant_(m.bn2.weight, 0)
165 |
166 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
167 | norm_layer = self._norm_layer
168 | downsample = None
169 | previous_dilation = self.dilation
170 | if dilate:
171 | self.dilation *= stride
172 | stride = 1
173 | if stride != 1 or self.inplanes != planes * block.expansion:
174 | downsample = nn.Sequential(
175 | conv1x1(self.inplanes, planes * block.expansion, stride),
176 | norm_layer(planes * block.expansion),
177 | )
178 |
179 | layers = []
180 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
181 | self.base_width, previous_dilation, norm_layer))
182 | self.inplanes = planes * block.expansion
183 | for _ in range(1, blocks):
184 | layers.append(block(self.inplanes, planes, groups=self.groups,
185 | base_width=self.base_width, dilation=self.dilation,
186 | norm_layer=norm_layer))
187 |
188 | return nn.Sequential(*layers)
189 |
190 | def forward(self, x):
191 | x = self.conv1(x)
192 | x = self.bn1(x)
193 | x = self.relu(x)
194 | x = self.maxpool(x)
195 |
196 | c2 = self.layer1(x)
197 | c3 = self.layer2(c2)
198 | c4 = self.layer3(c3)
199 | c5 = self.layer4(c4)
200 |
201 | return c2, c3, c4, c5
202 |
203 |
204 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
205 | model = ResNet(block, layers, **kwargs)
206 | if pretrained:
207 | state_dict = load_state_dict_from_url(model_urls[arch],
208 | progress=progress)
209 | model.load_state_dict(state_dict, strict=False)
210 | print('loaded pretrained weights from ImageNet')
211 | return model
212 |
213 |
214 | def resnet18(pretrained=False, progress=True, **kwargs):
215 | """Constructs a ResNet-18 model.
216 |
217 | Args:
218 | pretrained (bool): If True, returns a model pre-trained on ImageNet
219 | progress (bool): If True, displays a progress bar of the download to stderr
220 | """
221 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
222 | **kwargs)
223 |
224 |
225 | def resnet34(pretrained=False, progress=True, **kwargs):
226 | """Constructs a ResNet-34 model.
227 |
228 | Args:
229 | pretrained (bool): If True, returns a model pre-trained on ImageNet
230 | progress (bool): If True, displays a progress bar of the download to stderr
231 | """
232 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
233 | **kwargs)
234 |
235 |
236 | def resnet50(pretrained=False, progress=True, **kwargs):
237 | """Constructs a ResNet-50 model.
238 |
239 | Args:
240 | pretrained (bool): If True, returns a model pre-trained on ImageNet
241 | progress (bool): If True, displays a progress bar of the download to stderr
242 | """
243 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
244 | **kwargs)
245 |
246 |
247 | def resnet101(pretrained=False, progress=True, **kwargs):
248 | """Constructs a ResNet-101 model.
249 |
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | progress (bool): If True, displays a progress bar of the download to stderr
253 | """
254 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
255 | **kwargs)
256 |
257 |
258 | def resnet152(pretrained=False, progress=True, **kwargs):
259 | """Constructs a ResNet-152 model.
260 |
261 | Args:
262 | pretrained (bool): If True, returns a model pre-trained on ImageNet
263 | progress (bool): If True, displays a progress bar of the download to stderr
264 | """
265 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
266 | **kwargs)
267 |
268 |
269 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
270 | """Constructs a ResNeXt-50 32x4d model.
271 |
272 | Args:
273 | pretrained (bool): If True, returns a model pre-trained on ImageNet
274 | progress (bool): If True, displays a progress bar of the download to stderr
275 | """
276 | kwargs['groups'] = 32
277 | kwargs['width_per_group'] = 4
278 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
279 | pretrained, progress, **kwargs)
280 |
281 |
282 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
283 | """Constructs a ResNeXt-101 32x8d model.
284 |
285 | Args:
286 | pretrained (bool): If True, returns a model pre-trained on ImageNet
287 | progress (bool): If True, displays a progress bar of the download to stderr
288 | """
289 | kwargs['groups'] = 32
290 | kwargs['width_per_group'] = 8
291 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
292 | pretrained, progress, **kwargs)
293 |
294 | if __name__ == '__main__':
295 | import torch
296 | x = torch.zeros(1, 3, 640, 640)
297 | net = resnext101_32x8d(pretrained=False)
298 | y = net(x)
299 | for u in y:
300 | print(u.shape)
--------------------------------------------------------------------------------
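Unlike torchvision's classifiers, these constructors return the four pyramid features c2..c5 (strides 4/8/16/32) with the channel counts registered in models/model.py. A quick shape check inside this module (illustrative):

    import torch

    net = resnet18(pretrained=False)
    for f in net(torch.zeros(1, 3, 640, 640)):
        print(f.shape)
    # torch.Size([1, 64, 160, 160]) ... torch.Size([1, 512, 20, 20])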
/models/modules/segmentation_head.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/9/13 10:29
3 | # @Author : zhoujun
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class FPN(nn.Module):
10 | def __init__(self, backbone_out_channels, **kwargs):
11 | """
12 | :param backbone_out_channels: output channel dims of the backbone feature maps (c2..c5)
13 | :param kwargs:
14 | """
15 | super().__init__()
16 | result_num = kwargs.get('result_num', 6)
17 | inplace = True
18 | conv_out = 256
19 | # reduce layers
20 | self.reduce_conv_c2 = nn.Sequential(
21 | nn.Conv2d(backbone_out_channels[0], conv_out, kernel_size=1, stride=1, padding=0),
22 | nn.BatchNorm2d(conv_out),
23 | nn.ReLU(inplace=inplace)
24 | )
25 | self.reduce_conv_c3 = nn.Sequential(
26 | nn.Conv2d(backbone_out_channels[1], conv_out, kernel_size=1, stride=1, padding=0),
27 | nn.BatchNorm2d(conv_out),
28 | nn.ReLU(inplace=inplace)
29 | )
30 | self.reduce_conv_c4 = nn.Sequential(
31 | nn.Conv2d(backbone_out_channels[2], conv_out, kernel_size=1, stride=1, padding=0),
32 | nn.BatchNorm2d(conv_out),
33 | nn.ReLU(inplace=inplace)
34 | )
35 |
36 | self.reduce_conv_c5 = nn.Sequential(
37 | nn.Conv2d(backbone_out_channels[3], conv_out, kernel_size=1, stride=1, padding=0),
38 | nn.BatchNorm2d(conv_out),
39 | nn.ReLU(inplace=inplace)
40 | )
41 | # Smooth layers
42 | self.smooth_p4 = nn.Sequential(
43 | nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
44 | nn.BatchNorm2d(conv_out),
45 | nn.ReLU(inplace=inplace)
46 | )
47 | self.smooth_p3 = nn.Sequential(
48 | nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
49 | nn.BatchNorm2d(conv_out),
50 | nn.ReLU(inplace=inplace)
51 | )
52 | self.smooth_p2 = nn.Sequential(
53 | nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
54 | nn.BatchNorm2d(conv_out),
55 | nn.ReLU(inplace=inplace)
56 | )
57 |
58 | self.conv = nn.Sequential(
59 | nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1),
60 | nn.BatchNorm2d(conv_out),
61 | nn.ReLU(inplace=inplace)
62 | )
63 | self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1)
64 |
65 | def forward(self, x):
66 | c2, c3, c4, c5 = x
67 | # Top-down
68 | p5 = self.reduce_conv_c5(c5)
69 | p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
70 | p4 = self.smooth_p4(p4)
71 | p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
72 | p3 = self.smooth_p3(p3)
73 | p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
74 | p2 = self.smooth_p2(p2)
75 |
76 | x = self._upsample_cat(p2, p3, p4, p5)
77 | x = self.conv(x)
78 | x = self.out_conv(x)
79 | return x
80 |
81 | def _upsample_add(self, x, y):
82 | return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y
83 |
84 | def _upsample_cat(self, p2, p3, p4, p5):
85 | h, w = p2.size()[2:]
86 | p3 = F.interpolate(p3, size=(h, w), mode='bilinear')
87 | p4 = F.interpolate(p4, size=(h, w), mode='bilinear')
88 | p5 = F.interpolate(p5, size=(h, w), mode='bilinear')
89 | return torch.cat([p2, p3, p4, p5], dim=1)
90 |
91 |
92 | class FPEM_FFM(nn.Module):
93 | def __init__(self, backbone_out_channels, **kwargs):
94 | """
95 | PANnet
96 | :param backbone_out_channels: output channel dims of the backbone feature maps (c2..c5)
97 | """
98 | super().__init__()
99 | fpem_repeat = kwargs.get('fpem_repeat', 2)
100 | conv_out = 128
101 | # reduce layers
102 | self.reduce_conv_c2 = nn.Sequential(
103 | nn.Conv2d(in_channels=backbone_out_channels[0], out_channels=conv_out, kernel_size=1),
104 | nn.BatchNorm2d(conv_out),
105 | nn.ReLU()
106 | )
107 | self.reduce_conv_c3 = nn.Sequential(
108 | nn.Conv2d(in_channels=backbone_out_channels[1], out_channels=conv_out, kernel_size=1),
109 | nn.BatchNorm2d(conv_out),
110 | nn.ReLU()
111 | )
112 | self.reduce_conv_c4 = nn.Sequential(
113 | nn.Conv2d(in_channels=backbone_out_channels[2], out_channels=conv_out, kernel_size=1),
114 | nn.BatchNorm2d(conv_out),
115 | nn.ReLU()
116 | )
117 | self.reduce_conv_c5 = nn.Sequential(
118 | nn.Conv2d(in_channels=backbone_out_channels[3], out_channels=conv_out, kernel_size=1),
119 | nn.BatchNorm2d(conv_out),
120 | nn.ReLU()
121 | )
122 | self.fpems = nn.ModuleList()
123 | for i in range(fpem_repeat):
124 | self.fpems.append(FPEM(conv_out))
125 | self.out_conv = nn.Conv2d(in_channels=conv_out * 4, out_channels=6, kernel_size=1)
126 |
127 | def forward(self, x):
128 | c2, c3, c4, c5 = x
129 | # reduce channel
130 | c2 = self.reduce_conv_c2(c2)
131 | c3 = self.reduce_conv_c3(c3)
132 | c4 = self.reduce_conv_c4(c4)
133 | c5 = self.reduce_conv_c5(c5)
134 |
135 | # FPEM
136 | for i, fpem in enumerate(self.fpems):
137 | c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
138 | if i == 0:
139 | c2_ffm = c2
140 | c3_ffm = c3
141 | c4_ffm = c4
142 | c5_ffm = c5
143 | else:
144 | c2_ffm += c2
145 | c3_ffm += c3
146 | c4_ffm += c4
147 | c5_ffm += c5
148 |
149 | # FFM
150 | c5 = F.interpolate(c5_ffm, c2_ffm.size()[-2:], mode='bilinear')
151 | c4 = F.interpolate(c4_ffm, c2_ffm.size()[-2:], mode='bilinear')
152 | c3 = F.interpolate(c3_ffm, c2_ffm.size()[-2:], mode='bilinear')
153 | Fy = torch.cat([c2_ffm, c3, c4, c5], dim=1)
154 | y = self.out_conv(Fy)
155 | return y
156 |
157 |
158 | class FPEM(nn.Module):
159 | def __init__(self, in_channels=128):
160 | super().__init__()
161 | self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
162 | self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
163 | self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
164 | self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
165 | self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
166 | self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
167 |
168 | def forward(self, c2, c3, c4, c5):
169 | # up-scale enhancement stage
170 | c4 = self.up_add1(self._upsample_add(c5, c4))
171 | c3 = self.up_add2(self._upsample_add(c4, c3))
172 | c2 = self.up_add3(self._upsample_add(c3, c2))
173 |
174 | # down-scale enhancement stage
175 | c3 = self.down_add1(self._upsample_add(c3, c2))
176 | c4 = self.down_add2(self._upsample_add(c4, c3))
177 | c5 = self.down_add3(self._upsample_add(c5, c4))
178 | return c2, c3, c4, c5
179 |
180 | def _upsample_add(self, x, y):
181 | return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y
182 |
183 |
184 | class SeparableConv2d(nn.Module):
185 | def __init__(self, in_channels, out_channels, stride=1):
186 | super(SeparableConv2d, self).__init__()
187 |
188 | self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1,
189 | stride=stride, groups=in_channels)
190 | self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
191 | self.bn = nn.BatchNorm2d(out_channels)
192 | self.relu = nn.ReLU()
193 |
194 | def forward(self, x):
195 | x = self.depthwise_conv(x)
196 | x = self.pointwise_conv(x)
197 | x = self.bn(x)
198 | x = self.relu(x)
199 | return x
200 |
--------------------------------------------------------------------------------
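A minimal shape check for FPEM_FFM on a ShuffleNetV2-sized pyramid (sizes illustrative; any stride-4/8/16/32 pyramid with the registered channel counts works the same way):

    import torch

    head = FPEM_FFM([24, 116, 232, 464], fpem_repeat=2)
    feats = [torch.zeros(1, c, s, s)
             for c, s in zip([24, 116, 232, 464], [160, 80, 40, 20])]
    print(head(feats).shape)  # torch.Size([1, 6, 160, 160]): 6 maps at stride 4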
/models/modules/shufflenetv2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/11/1 15:31
3 | # @Author : zhoujun
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision.models.utils import load_state_dict_from_url
8 |
9 | __all__ = [
10 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
11 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
12 | ]
13 |
14 | model_urls = {
15 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
16 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
17 | 'shufflenetv2_x1.5': None,
18 | 'shufflenetv2_x2.0': None,
19 | }
20 |
21 |
22 | def channel_shuffle(x, groups):
23 | batchsize, num_channels, height, width = x.data.size()
24 | channels_per_group = num_channels // groups
25 |
26 | # reshape
27 | x = x.view(batchsize, groups,
28 | channels_per_group, height, width)
29 |
30 | x = torch.transpose(x, 1, 2).contiguous()
31 |
32 | # flatten
33 | x = x.view(batchsize, -1, height, width)
34 |
35 | return x
36 |
37 |
38 | class InvertedResidual(nn.Module):
39 | def __init__(self, inp, oup, stride):
40 | super(InvertedResidual, self).__init__()
41 |
42 | if not (1 <= stride <= 3):
43 | raise ValueError('illegal stride value')
44 | self.stride = stride
45 |
46 | branch_features = oup // 2
47 | assert (self.stride != 1) or (inp == branch_features << 1)
48 |
49 | if self.stride > 1:
50 | self.branch1 = nn.Sequential(
51 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
52 | nn.BatchNorm2d(inp),
53 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
54 | nn.BatchNorm2d(branch_features),
55 | nn.ReLU(inplace=True),
56 | )
57 |
58 | self.branch2 = nn.Sequential(
59 | nn.Conv2d(inp if (self.stride > 1) else branch_features,
60 | branch_features, kernel_size=1, stride=1, padding=0, bias=False),
61 | nn.BatchNorm2d(branch_features),
62 | nn.ReLU(inplace=True),
63 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
64 | nn.BatchNorm2d(branch_features),
65 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
66 | nn.BatchNorm2d(branch_features),
67 | nn.ReLU(inplace=True),
68 | )
69 |
70 | @staticmethod
71 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
72 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
73 |
74 | def forward(self, x):
75 | if self.stride == 1:
76 | x1, x2 = x.chunk(2, dim=1)
77 | out = torch.cat((x1, self.branch2(x2)), dim=1)
78 | else:
79 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
80 |
81 | out = channel_shuffle(out, 2)
82 |
83 | return out
84 |
85 |
86 | class ShuffleNetV2(nn.Module):
87 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
88 | super(ShuffleNetV2, self).__init__()
89 |
90 | if len(stages_repeats) != 3:
91 | raise ValueError('expected stages_repeats as list of 3 positive ints')
92 | if len(stages_out_channels) != 5:
93 | raise ValueError('expected stages_out_channels as list of 5 positive ints')
94 | self._stage_out_channels = stages_out_channels
95 |
96 | input_channels = 3
97 | output_channels = self._stage_out_channels[0]
98 | self.conv1 = nn.Sequential(
99 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
100 | nn.BatchNorm2d(output_channels),
101 | nn.ReLU(inplace=True),
102 | )
103 | input_channels = output_channels
104 |
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106 |
107 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
108 | for name, repeats, output_channels in zip(
109 | stage_names, stages_repeats, self._stage_out_channels[1:]):
110 | seq = [InvertedResidual(input_channels, output_channels, 2)]
111 | for i in range(repeats - 1):
112 | seq.append(InvertedResidual(output_channels, output_channels, 1))
113 | setattr(self, name, nn.Sequential(*seq))
114 | input_channels = output_channels
115 |
116 | output_channels = self._stage_out_channels[-1]
117 | self.conv5 = nn.Sequential(
118 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
119 | nn.BatchNorm2d(output_channels),
120 | nn.ReLU(inplace=True),
121 | )
122 |
123 | def forward(self, x):
124 | x = self.conv1(x)
125 | c2 = self.maxpool(x)
126 | c3 = self.stage2(c2)
127 | c4 = self.stage3(c3)
128 | c5 = self.stage4(c4)
129 | # c5 = self.conv5(c5)
130 | return c2, c3, c4, c5
131 |
132 |
133 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
134 | model = ShuffleNetV2(*args, **kwargs)
135 |
136 | if pretrained:
137 | model_url = model_urls[arch]
138 | if model_url is None:
139 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
140 | else:
141 | state_dict = load_state_dict_from_url(model_url, progress=progress)
142 | model.load_state_dict(state_dict, strict=False)
143 |
144 | return model
145 |
146 |
147 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
148 | """
149 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in
150 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
151 | <https://arxiv.org/abs/1807.11164>`_.
152 |
153 | Args:
154 | pretrained (bool): If True, returns a model pre-trained on ImageNet
155 | progress (bool): If True, displays a progress bar of the download to stderr
156 | """
157 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
158 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
159 |
160 |
161 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
162 | """
163 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in
164 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
165 | <https://arxiv.org/abs/1807.11164>`_.
166 |
167 | Args:
168 | pretrained (bool): If True, returns a model pre-trained on ImageNet
169 | progress (bool): If True, displays a progress bar of the download to stderr
170 | """
171 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
172 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
173 |
174 |
175 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
176 | """
177 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in
178 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
179 | <https://arxiv.org/abs/1807.11164>`_.
180 |
181 | Args:
182 | pretrained (bool): If True, returns a model pre-trained on ImageNet
183 | progress (bool): If True, displays a progress bar of the download to stderr
184 | """
185 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
186 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
187 |
188 |
189 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
190 | """
191 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in
192 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
193 | <https://arxiv.org/abs/1807.11164>`_.
194 |
195 | Args:
196 | pretrained (bool): If True, returns a model pre-trained on ImageNet
197 | progress (bool): If True, displays a progress bar of the download to stderr
198 | """
199 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
200 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
201 |
--------------------------------------------------------------------------------
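channel_shuffle is what lets information cross the two branches between blocks; a tiny worked example (illustrative):

    import torch

    x = torch.arange(8.).reshape(1, 8, 1, 1)  # channels 0..7
    y = channel_shuffle(x, groups=2)
    print(y.flatten().tolist())               # [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]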
/post_processing/Makefile:
--------------------------------------------------------------------------------
1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags)
2 | LDFLAGS = $(shell python3-config --ldflags)
3 |
4 | DEPS = $(shell find include -xtype f)
5 | CXX_SOURCES = pse.cpp
6 |
7 | LIB_SO = pse.so
8 |
9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS)
10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC
11 |
12 | clean:
13 | rm -rf $(LIB_SO)
14 |
--------------------------------------------------------------------------------
/post_processing/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/9/8 14:18
3 | # @Author : zhoujun
4 | import os
5 | import cv2
6 | import torch
7 | import time
8 | import subprocess
9 | import numpy as np
10 |
11 | from .pypse import pse_py
12 | from .kmeans import km
13 |
14 | BASE_DIR = os.path.dirname(os.path.realpath(__file__))
15 |
16 | if subprocess.call(['make', '-C', BASE_DIR]) != 0:  # build pse.so on import; a non-zero return code means make failed
17 | raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR))
18 |
19 |
20 | def decode(preds, scale=1, threshold=0.7311, min_area=5):
21 | """
22 | 在输出上使用sigmoid 将值转换为置信度,并使用阈值来进行文字和背景的区分
23 | :param preds: 网络输出
24 | :param scale: 网络的scale
25 | :param threshold: sigmoid的阈值
26 | :return: 最后的输出图和文本框
27 | """
28 | from .pse import pse_cpp, get_points, get_num
29 | preds[:2, :, :] = torch.sigmoid(preds[:2, :, :])
30 | preds = preds.detach().cpu().numpy()
31 | score = preds[0].astype(np.float32)
32 | text = preds[0] > threshold # text
33 | kernel = (preds[1] > threshold) * text # kernel
34 | similarity_vectors = preds[2:].transpose((1, 2, 0))
35 |
36 | label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4)
37 | label_values = []
38 | label_sum = get_num(label, label_num)
39 | for label_idx in range(1, label_num):
40 | if label_sum[label_idx] < min_area:
41 | continue
42 | label_values.append(label_idx)
43 |
44 | pred = pse_cpp(text.astype(np.uint8), similarity_vectors, label, label_num, 0.8)
45 | pred = pred.reshape(text.shape)
46 |
47 | bbox_list = []
48 | label_points = get_points(pred, score, label_num)
49 | for label_value, label_point in label_points.items():
50 | if label_value not in label_values:
51 | continue
52 | score_i = label_point[0]
53 | label_point = label_point[2:]
54 | points = np.array(label_point, dtype=int).reshape(-1, 2)
55 |
56 | if points.shape[0] < 100 / (scale * scale):
57 | continue
58 |
59 | if score_i < 0.93:
60 | continue
61 |
62 | rect = cv2.minAreaRect(points)
63 | bbox = cv2.boxPoints(rect)
64 | bbox_list.append([bbox[1], bbox[2], bbox[3], bbox[0]])
65 | return pred, np.array(bbox_list)
66 |
67 |
68 | def decode_dice(preds, scale=1, threshold=0.7311, min_area=5):
69 | import pyclipper
70 | preds[:2, :, :] = torch.sigmoid(preds[:2, :, :])
71 | preds = preds.detach().cpu().numpy()
72 | text = preds[0] > threshold # text
73 | kernel = (preds[1] > threshold) * text # kernel
74 |
75 | label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4)
76 | bbox_list = []
77 | for label_idx in range(1, label_num):
78 | points = np.array(np.where(label == label_idx)).transpose((1, 0))[:, ::-1]
79 |
80 | rect = cv2.minAreaRect(points)
81 | poly = cv2.boxPoints(rect).astype(int)
82 |
83 | d_i = cv2.contourArea(poly) * 1.5 / cv2.arcLength(poly, True)
84 | pco = pyclipper.PyclipperOffset()
85 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
86 | shrinked_poly = np.array(pco.Execute(-d_i))
87 |
88 | if cv2.contourArea(shrinked_poly) < 800 / (scale * scale):
89 | continue
90 |
91 | bbox_list.append([shrinked_poly[1], shrinked_poly[2], shrinked_poly[3], shrinked_poly[0]])
92 | return label, np.array(bbox_list)
93 |
--------------------------------------------------------------------------------
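A sketch of running the post-processing on one prediction (requires the compiled pse.so; the zero tensor is a stand-in for a real 6-channel model output):

    import torch
    from post_processing import decode

    preds = torch.zeros(6, 640, 640)  # model output with the batch dim removed
    label_map, boxes = decode(preds, scale=1, threshold=0.7311)
    # boxes: N x 4 x 2 array of box corner points, as consumed by eval.py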
/post_processing/include/pybind11/buffer_info.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/buffer_info.h: Python buffer object interface
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "detail/common.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 |
16 | /// Information record describing a Python buffer object
17 | struct buffer_info {
18 | void *ptr = nullptr; // Pointer to the underlying storage
19 | ssize_t itemsize = 0; // Size of individual items in bytes
20 | ssize_t size = 0; // Total number of entries
21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
22 | ssize_t ndim = 0; // Number of dimensions
23 | std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
24 | std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
25 |
26 | buffer_info() { }
27 |
28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
29 | detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
31 | shape(std::move(shape_in)), strides(std::move(strides_in)) {
32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
34 | for (size_t i = 0; i < (size_t) ndim; ++i)
35 | size *= shape[i];
36 | }
37 |
38 | template <typename T>
39 | buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
41 |
42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
44 |
45 | template <typename T>
46 | buffer_info(T *ptr, ssize_t size)
47 | : buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
48 |
49 | explicit buffer_info(Py_buffer *view, bool ownview = true)
50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim,
51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
52 | this->view = view;
53 | this->ownview = ownview;
54 | }
55 |
56 | buffer_info(const buffer_info &) = delete;
57 | buffer_info& operator=(const buffer_info &) = delete;
58 |
59 | buffer_info(buffer_info &&other) {
60 | (*this) = std::move(other);
61 | }
62 |
63 | buffer_info& operator=(buffer_info &&rhs) {
64 | ptr = rhs.ptr;
65 | itemsize = rhs.itemsize;
66 | size = rhs.size;
67 | format = std::move(rhs.format);
68 | ndim = rhs.ndim;
69 | shape = std::move(rhs.shape);
70 | strides = std::move(rhs.strides);
71 | std::swap(view, rhs.view);
72 | std::swap(ownview, rhs.ownview);
73 | return *this;
74 | }
75 |
76 | ~buffer_info() {
77 | if (view && ownview) { PyBuffer_Release(view); delete view; }
78 | }
79 |
80 | private:
81 | struct private_ctr_tag { };
82 |
83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
84 | detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
86 |
87 | Py_buffer *view = nullptr;
88 | bool ownview = false;
89 | };
90 |
91 | NAMESPACE_BEGIN(detail)
92 |
93 | template <typename T, typename SFINAE = void> struct compare_buffer_info {
94 | static bool compare(const buffer_info& b) {
95 | return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
96 | }
97 | };
98 |
99 | template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
100 | static bool compare(const buffer_info& b) {
101 | return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
104 | }
105 | };
106 |
107 | NAMESPACE_END(detail)
108 | NAMESPACE_END(PYBIND11_NAMESPACE)
109 |
--------------------------------------------------------------------------------
/post_processing/include/pybind11/chrono.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
3 |
4 | Copyright (c) 2016 Trent Houliston and
5 | Wenzel Jakob
6 |
7 | All rights reserved. Use of this source code is governed by a
8 | BSD-style license that can be found in the LICENSE file.
9 | */
10 |
11 | #pragma once
12 |
13 | #include "pybind11.h"
14 | #include <cmath>
15 | #include <ctime>
16 | #include <chrono>
17 | #include <datetime.h>
18 |
19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required
20 | #ifndef PyDateTime_DELTA_GET_DAYS
21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
22 | #endif
23 | #ifndef PyDateTime_DELTA_GET_SECONDS
24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
25 | #endif
26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS
27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
28 | #endif
29 |
30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
31 | NAMESPACE_BEGIN(detail)
32 |
33 | template <typename type> class duration_caster {
34 | public:
35 | typedef typename type::rep rep;
36 | typedef typename type::period period;
37 |
38 | typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
39 |
40 | bool load(handle src, bool) {
41 | using namespace std::chrono;
42 |
43 | // Lazy initialise the PyDateTime import
44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
45 |
46 | if (!src) return false;
47 | // If invoked with datetime.delta object
48 | if (PyDelta_Check(src.ptr())) {
49 | value = type(duration_cast<duration<rep, period>>(
50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
53 | return true;
54 | }
55 | // If invoked with a float we assume it is seconds and convert
56 | else if (PyFloat_Check(src.ptr())) {
57 | value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
58 | return true;
59 | }
60 | else return false;
61 | }
62 |
63 | // If this is a duration just return it back
64 | static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
65 | return src;
66 | }
67 |
68 | // If this is a time_point get the time_since_epoch
69 | template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
70 | return src.time_since_epoch();
71 | }
72 |
73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
74 | using namespace std::chrono;
75 |
76 | // Use overloaded function to get our duration from our source
77 | // Works out if it is a duration or time_point and get the duration
78 | auto d = get_duration(src);
79 |
80 | // Lazy initialise the PyDateTime import
81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
82 |
83 | // Declare these special duration types so the conversions happen with the correct primitive types (int)
84 | using dd_t = duration<int, std::ratio<86400>>;
85 | using ss_t = duration<int, std::ratio<1>>;
86 | using us_t = duration<int, std::micro>;
87 |
88 | auto dd = duration_cast<dd_t>(d);
89 | auto subd = d - dd;
90 | auto ss = duration_cast<ss_t>(subd);
91 | auto us = duration_cast<us_t>(subd - ss);
92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
93 | }
94 |
95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
96 | };
97 |
98 | // This is for casting times on the system clock into datetime.datetime instances
99 | template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
100 | public:
101 | typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
102 | bool load(handle src, bool) {
103 | using namespace std::chrono;
104 |
105 | // Lazy initialise the PyDateTime import
106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
107 |
108 | if (!src) return false;
109 | if (PyDateTime_Check(src.ptr())) {
110 | std::tm cal;
111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
117 | cal.tm_isdst = -1;
118 |
119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
120 | return true;
121 | }
122 | else return false;
123 | }
124 |
125 | static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
126 | using namespace std::chrono;
127 |
128 | // Lazy initialise the PyDateTime import
129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
130 |
131 | std::time_t tt = system_clock::to_time_t(src);
132 | // this function uses static memory so it's best to copy it out asap just in case
133 | // otherwise other code that is using localtime may break this (not just python code)
134 | std::tm localtime = *std::localtime(&tt);
135 |
136 | // Declare these special duration types so the conversions happen with the correct primitive types (int)
137 | using us_t = duration<int, std::micro>;
138 |
139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
140 | localtime.tm_mon + 1,
141 | localtime.tm_mday,
142 | localtime.tm_hour,
143 | localtime.tm_min,
144 | localtime.tm_sec,
145 | (duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
146 | }
147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
148 | };
149 |
150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects
151 | // since they are not measured on calendar time. So instead we just make them timedeltas
152 | // Or if they have passed us a time as a float we convert that
153 | template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
154 | : public duration_caster<std::chrono::time_point<Clock, Duration>> {
155 | };
156 |
157 | template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
158 | : public duration_caster<std::chrono::duration<Rep, Period>> {
159 | };
160 |
161 | NAMESPACE_END(detail)
162 | NAMESPACE_END(PYBIND11_NAMESPACE)
163 |
--------------------------------------------------------------------------------
/post_processing/include/pybind11/common.h:
--------------------------------------------------------------------------------
1 | #include "detail/common.h"
2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
3 |
--------------------------------------------------------------------------------
/post_processing/include/pybind11/complex.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/complex.h: Complex number support
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "pybind11.h"
13 | #include <complex>
14 |
15 | /// glibc defines I as a macro which breaks things, e.g., boost template names
16 | #ifdef I
17 | # undef I
18 | #endif
19 |
20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
21 |
22 | template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
23 | static constexpr const char c = format_descriptor<T>::c;
24 | static constexpr const char value[3] = { 'Z', c, '\0' };
25 | static std::string format() { return std::string(value); }
26 | };
27 |
28 | #ifndef PYBIND11_CPP17
29 |
30 | template <typename T> constexpr const char format_descriptor<
31 | std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
32 |
33 | #endif
34 |
35 | NAMESPACE_BEGIN(detail)
36 |
37 | template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
38 | static constexpr bool value = true;
39 | static constexpr int index = is_fmt_numeric<T>::index + 3;
40 | };
41 |
42 | template <typename T> class type_caster<std::complex<T>> {
43 | public:
44 | bool load(handle src, bool convert) {
45 | if (!src)
46 | return false;
47 | if (!convert && !PyComplex_Check(src.ptr()))
48 | return false;
49 | Py_complex result = PyComplex_AsCComplex(src.ptr());
50 | if (result.real == -1.0 && PyErr_Occurred()) {
51 | PyErr_Clear();
52 | return false;
53 | }
54 | value = std::complex<T>((T) result.real, (T) result.imag);
55 | return true;
56 | }
57 |
58 | static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
60 | }
61 |
62 | PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
63 | };
64 | NAMESPACE_END(detail)
65 | NAMESPACE_END(PYBIND11_NAMESPACE)
66 |
--------------------------------------------------------------------------------
/post_processing/include/pybind11/descr.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/descr.h: Helper type for concatenating type signatures
3 | either at runtime (C++11) or compile time (C++14)
4 |
5 | Copyright (c) 2016 Wenzel Jakob
6 |
7 | All rights reserved. Use of this source code is governed by a
8 | BSD-style license that can be found in the LICENSE file.
9 | */
10 |
11 | #pragma once
12 |
13 | #include "common.h"
14 |
15 | NAMESPACE_BEGIN(pybind11)
16 | NAMESPACE_BEGIN(detail)
17 |
18 | /* Concatenate type signatures at compile time using C++14 */
19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER)
20 | #define PYBIND11_CONSTEXPR_DESCR
21 |
22 | template <size_t Size1, size_t Size2> class descr {
23 | template <size_t Size1_, size_t Size2_> friend class descr;
24 | public:
25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1])
26 | : descr(text, types,
27 | make_index_sequence<Size1>(),
28 | make_index_sequence<Size2>()) { }
29 |
30 | constexpr const char *text() const { return m_text; }
31 | constexpr const std::type_info * const * types() const { return m_types; }
32 |
33 | template <size_t OtherSize1, size_t OtherSize2>
34 | constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2> operator+(const descr<OtherSize1, OtherSize2> &other) const {
35 | return concat(other,
36 | make_index_sequence<Size1>(),
37 | make_index_sequence<Size2>(),
38 | make_index_sequence<OtherSize1>(),
39 | make_index_sequence<OtherSize2>());
40 | }
41 |
42 | protected:
43 | template <size_t... Indices1, size_t... Indices2>
44 | constexpr descr(
45 | char const (&text) [Size1+1],
46 | const std::type_info * const (&types) [Size2+1],
47 | index_sequence<Indices1...>, index_sequence<Indices2...>)
48 | : m_text{text[Indices1]..., '\0'},
49 | m_types{types[Indices2]..., nullptr } {}
50 |
51 | template <size_t OtherSize1, size_t OtherSize2, size_t... Indices1,
52 | size_t... Indices2, size_t... OtherIndices1, size_t... OtherIndices2>
53 | constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2>
54 | concat(const descr<OtherSize1, OtherSize2> &other,
55 | index_sequence<Indices1...>, index_sequence<Indices2...>,
56 | index_sequence<OtherIndices1...>, index_sequence<OtherIndices2...>) const {
57 | return descr<Size1 + OtherSize1, Size2 + OtherSize2>(
58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' },
59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr }
60 | );
61 | }
62 |
63 | protected:
64 | char m_text[Size1 + 1];
65 | const std::type_info * m_types[Size2 + 1];
66 | };
67 |
68 | template <size_t Size> constexpr descr<Size - 1, 0> _(char const(&text)[Size]) {
69 | return descr<Size - 1, 0>(text, { nullptr });
70 | }
71 |
72 | template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
73 | template <size_t...Digits> struct int_to_str<0, Digits...> {
74 | static constexpr auto digits = descr<sizeof...(Digits), 0>({ ('0' + Digits)..., '\0' }, { nullptr });
75 | };
76 |
77 | // Ternary description (like std::conditional)
78 | template <bool B, size_t Size1, size_t Size2>
79 | constexpr enable_if_t<B, descr<Size1 - 1, 0>> _(char const(&text1)[Size1], char const(&)[Size2]) {
80 | return _(text1);
81 | }
82 | template <bool B, size_t Size1, size_t Size2>
83 | constexpr enable_if_t<!B, descr<Size2 - 1, 0>> _(char const(&)[Size1], char const(&text2)[Size2]) {
84 | return _(text2);
85 | }
86 | template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
87 | constexpr enable_if_t<B, descr<SizeA1, SizeA2>> _(descr<SizeA1, SizeA2> d, descr<SizeB1, SizeB2>) { return d; }
88 | template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
89 | constexpr enable_if_t<!B, descr<SizeB1, SizeB2>> _(descr<SizeA1, SizeA2>, descr<SizeB1, SizeB2> d) { return d; }
90 |
91 | template auto constexpr _() -> decltype(int_to_str::digits) {
92 | return int_to_str::digits;
93 | }
94 |
95 | template constexpr descr<1, 1> _() {
96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr });
97 | }
98 |
99 | inline constexpr descr<0, 0> concat() { return _(""); }
100 | template auto constexpr concat(descr descr) { return descr; }
101 | template auto constexpr concat(descr descr, Args&&... args) { return descr + _(", ") + concat(args...); }
102 | template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); }
103 |
104 | #define PYBIND11_DESCR constexpr auto
105 |
106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */
107 |
108 | class descr {
109 | public:
110 |     PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) {
111 |         size_t nChars = len(text), nTypes = len(types);
112 |         m_text  = new char[nChars];
113 |         m_types = new const std::type_info *[nTypes];
114 |         memcpy(m_text, text, nChars * sizeof(char));
115 |         memcpy(m_types, types, nTypes * sizeof(const std::type_info *));
116 |     }
117 |
118 |     PYBIND11_NOINLINE descr operator+(descr &&d2) && {
119 |         descr r;
120 |
121 |         size_t nChars1 = len(m_text),    nTypes1 = len(m_types);
122 |         size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types);
123 |
124 |         r.m_text  = new char[nChars1 + nChars2 - 1];
125 |         r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1];
126 |         memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char));
127 |         memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char));
128 |         memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *));
129 |         memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *));
130 |
131 |         delete[] m_text;    delete[] m_types;
132 |         delete[] d2.m_text; delete[] d2.m_types;
133 |
134 |         return r;
135 |     }
136 |
137 |     char *text() { return m_text; }
138 |     const std::type_info * * types() { return m_types; }
139 |
140 | protected:
141 |     PYBIND11_NOINLINE descr() { }
142 |
143 |     template <typename T> static size_t len(const T *ptr) { // return length including null termination
144 |         const T *it = ptr;
145 |         while (*it++ != (T) 0)
146 |             ;
147 |         return static_cast<size_t>(it - ptr);
148 |     }
149 |
150 |     const std::type_info **m_types = nullptr;
151 |     char *m_text = nullptr;
152 | };
153 |
154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */
155 |
156 | PYBIND11_NOINLINE inline descr _(const char *text) {
157 |     const std::type_info *types[1] = { nullptr };
158 |     return descr(text, types);
159 | }
160 |
161 | template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(const char *text1, const char *) { return _(text1); }
162 | template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(char const *, const char *text2) { return _(text2); }
163 | template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(descr d, descr) { return d; }
164 | template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(descr, descr d) { return d; }
165 |
166 | template <typename Type> PYBIND11_NOINLINE descr _() {
167 |     const std::type_info *types[2] = { &typeid(Type), nullptr };
168 |     return descr("%", types);
169 | }
170 |
171 | template <size_t Size> PYBIND11_NOINLINE descr _() {
172 |     const std::type_info *types[1] = { nullptr };
173 |     return descr(std::to_string(Size).c_str(), types);
174 | }
175 |
176 | PYBIND11_NOINLINE inline descr concat() { return _(""); }
177 | PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; }
178 | template <typename... Args> PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward<Args>(args)...); }
179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); }
180 |
181 | #define PYBIND11_DESCR ::pybind11::detail::descr
182 | #endif
183 |
184 | NAMESPACE_END(detail)
185 | NAMESPACE_END(pybind11)
186 |
--------------------------------------------------------------------------------
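
A sketch of how the `_` helpers above compose a signature description (illustrative only; assumes this header is included). Under C++14, `PYBIND11_DESCR` expands to `constexpr auto` and the concatenation happens at compile time; under C++11 it falls back to the heap-allocating `descr` class:

    using pybind11::detail::_;

    // Builds the text "(%, %) -> None"; each '%' is a placeholder that
    // pybind11 later substitutes with the Python name of the C++ type.
    PYBIND11_DESCR sig = _("(") + _<int>() + _(", ") + _<double>() + _(") -> None");
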
/post_processing/include/pybind11/detail/descr.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "common.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 | NAMESPACE_BEGIN(detail)
16 |
17 | #if !defined(_MSC_VER)
18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr
19 | #else
20 | # define PYBIND11_DESCR_CONSTEXPR const
21 | #endif
22 |
23 | /* Concatenate type signatures at compile time */
24 | template <size_t N, typename... Ts>
25 | struct descr {
26 |     char text[N + 1];
27 |
28 |     constexpr descr() : text{'\0'} { }
29 |     constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
30 |
31 |     template <size_t... Is>
32 |     constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
33 |
34 |     template <typename... Chars>
35 |     constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
36 |
37 |     static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
38 |         return {{&typeid(Ts)..., nullptr}};
39 |     }
40 | };
41 |
42 | template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
43 | constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
44 |                                                    index_sequence<Is1...>, index_sequence<Is2...>) {
45 |     return {a.text[Is1]..., b.text[Is2]...};
46 | }
47 |
48 | template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
49 | constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
50 |     return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
51 | }
52 |
53 | template <size_t N>
54 | constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
55 | constexpr descr<0> _(char const(&)[1]) { return {}; }
56 |
57 | template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
58 | template <size_t...Digits> struct int_to_str<0, Digits...> {
59 |     static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
60 | };
61 |
62 | // Ternary description (like std::conditional)
63 | template <bool B, size_t N1, size_t N2>
64 | constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
65 |     return _(text1);
66 | }
67 | template <bool B, size_t N1, size_t N2>
68 | constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
69 |     return _(text2);
70 | }
71 |
72 | template <bool B, typename T1, typename T2>
73 | constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
74 | template <bool B, typename T1, typename T2>
75 | constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
76 |
77 | template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
78 |     return int_to_str<Size / 10, Size % 10>::digits;
79 | }
80 |
81 | template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
82 |
83 | constexpr descr<0> concat() { return {}; }
84 |
85 | template <size_t N, typename... Ts>
86 | constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
87 |
88 | template <size_t N, typename... Ts, typename... Args>
89 | constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
90 |     -> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
91 |     return d + _(", ") + concat(args...);
92 | }
93 |
94 | template <size_t N, typename... Ts>
95 | constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
96 |     return _("{") + descr + _("}");
97 | }
98 |
99 | NAMESPACE_END(detail)
100 | NAMESPACE_END(PYBIND11_NAMESPACE)
101 |
--------------------------------------------------------------------------------
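
The newer `descr<N, Ts...>` above keeps the signature text in a fixed-size `char` array and carries the referenced C++ types as template parameters, so both are available at compile time. An illustrative sketch (not part of the header):

    using namespace pybind11::detail;

    constexpr auto sig = _("(") + _<int>() + _(") -> ") + _<double>();
    static_assert(sig.text[0] == '(', "signature text is built at compile time");
    // sig.types() returns {&typeid(int), &typeid(double), nullptr}.
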
/post_processing/include/pybind11/detail/internals.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/internals.h: Internal data structure and related functions
3 |
4 | Copyright (c) 2017 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "../pytypes.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 | NAMESPACE_BEGIN(detail)
16 | // Forward declarations
17 | inline PyTypeObject *make_static_property_type();
18 | inline PyTypeObject *make_default_metaclass();
19 | inline PyObject *make_object_base_type(PyTypeObject *metaclass);
20 |
21 | // The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
22 | // Thread Specific Storage (TSS) API.
23 | #if PY_VERSION_HEX >= 0x03070000
24 | # define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
25 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
26 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
27 | # define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
28 | #else
29 | // Usually an int but a long on Cygwin64 with Python 3.x
30 | # define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
31 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
32 | # if PY_MAJOR_VERSION < 3
33 | # define PYBIND11_TLS_DELETE_VALUE(key) \
34 | PyThread_delete_key_value(key)
35 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \
36 | do { \
37 | PyThread_delete_key_value((key)); \
38 | PyThread_set_key_value((key), (value)); \
39 | } while (false)
40 | # else
41 | # define PYBIND11_TLS_DELETE_VALUE(key) \
42 | PyThread_set_key_value((key), nullptr)
43 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \
44 | PyThread_set_key_value((key), (value))
45 | # endif
46 | #endif
47 |
48 | // Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
49 | // other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
50 | // even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
51 | // libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
52 | // which works. If not under a known-good stl, provide our own name-based hash and equality
53 | // functions that use the type name.
54 | #if defined(__GLIBCXX__)
55 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
56 | using type_hash = std::hash<std::type_index>;
57 | using type_equal_to = std::equal_to<std::type_index>;
58 | #else
59 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
60 |     return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
61 | }
62 |
63 | struct type_hash {
64 |     size_t operator()(const std::type_index &t) const {
65 |         size_t hash = 5381;
66 |         const char *ptr = t.name();
67 |         while (auto c = static_cast<unsigned char>(*ptr++))
68 |             hash = (hash * 33) ^ c;
69 |         return hash;
70 |     }
71 | };
72 |
73 | struct type_equal_to {
74 |     bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
75 |         return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
76 |     }
77 | };
78 | #endif
79 |
80 | template <typename value_type>
81 | using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
82 |
83 | struct overload_hash {
84 |     inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
85 |         size_t value = std::hash<const void *>()(v.first);
86 |         value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
87 |         return value;
88 |     }
89 | };
90 |
91 | /// Internal data structure used to track registered instances and types.
92 | /// Whenever binary incompatible changes are made to this structure,
93 | /// `PYBIND11_INTERNALS_VERSION` must be incremented.
94 | struct internals {
95 |     type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
96 |     std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
97 |     std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
98 |     std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
99 |     type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
100 |     std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
101 |     std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
102 |     std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
103 |     std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
104 |     std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
105 |     PyTypeObject *static_property_type;
106 |     PyTypeObject *default_metaclass;
107 |     PyObject *instance_base;
108 | #if defined(WITH_THREAD)
109 |     PYBIND11_TLS_KEY_INIT(tstate);
110 |     PyInterpreterState *istate = nullptr;
111 | #endif
112 | };
113 |
114 | /// Additional type information which does not fit into the PyTypeObject.
115 | /// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
116 | struct type_info {
117 |     PyTypeObject *type;
118 |     const std::type_info *cpptype;
119 |     size_t type_size, type_align, holder_size_in_ptrs;
120 |     void *(*operator_new)(size_t);
121 |     void (*init_instance)(instance *, const void *);
122 |     void (*dealloc)(value_and_holder &v_h);
123 |     std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
124 |     std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
125 |     std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
126 |     buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
127 |     void *get_buffer_data = nullptr;
128 |     void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
129 |     /* A simple type never occurs as a (direct or indirect) parent
130 |      * of a class that makes use of multiple inheritance */
131 |     bool simple_type : 1;
132 |     /* True if there is no multiple inheritance in this type's inheritance tree */
133 |     bool simple_ancestors : 1;
134 |     /* for base vs derived holder_type checks */
135 |     bool default_holder : 1;
136 |     /* true if this is a type registered with py::module_local */
137 |     bool module_local : 1;
138 | };
139 |
140 | /// Tracks the `internals` and `type_info` ABI version independent of the main library version
141 | #define PYBIND11_INTERNALS_VERSION 3
142 |
143 | #if defined(_DEBUG)
144 | # define PYBIND11_BUILD_TYPE "_debug"
145 | #else
146 | # define PYBIND11_BUILD_TYPE ""
147 | #endif
148 |
149 | #if defined(WITH_THREAD)
150 | # define PYBIND11_INTERNALS_KIND ""
151 | #else
152 | # define PYBIND11_INTERNALS_KIND "_without_thread"
153 | #endif
154 |
155 | #define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
156 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
157 |
158 | #define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
159 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
160 |
161 | /// Each module locally stores a pointer to the `internals` data. The data
162 | /// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
163 | inline internals **&get_internals_pp() {
164 | static internals **internals_pp = nullptr;
165 | return internals_pp;
166 | }
167 |
168 | /// Return a reference to the current `internals` data
169 | PYBIND11_NOINLINE inline internals &get_internals() {
170 | auto **&internals_pp = get_internals_pp();
171 | if (internals_pp && *internals_pp)
172 | return **internals_pp;
173 |
174 | constexpr auto *id = PYBIND11_INTERNALS_ID;
175 | auto builtins = handle(PyEval_GetBuiltins());
176 | if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
177 | internals_pp = static_cast<internals **>(capsule(builtins[id]));
178 |
179 | // We loaded builtins through python's builtins, which means that our `error_already_set`
180 | // and `builtin_exception` may be different local classes than the ones set up in the
181 | // initial exception translator, below, so add another for our local exception classes.
182 | //
183 | // libstdc++ doesn't require this (types there are identified only by name)
184 | #if !defined(__GLIBCXX__)
185 | (*internals_pp)->registered_exception_translators.push_front(
186 | [](std::exception_ptr p) -> void {
187 | try {
188 | if (p) std::rethrow_exception(p);
189 | } catch (error_already_set &e) { e.restore(); return;
190 | } catch (const builtin_exception &e) { e.set_error(); return;
191 | }
192 | }
193 | );
194 | #endif
195 | } else {
196 | if (!internals_pp) internals_pp = new internals*();
197 | auto *&internals_ptr = *internals_pp;
198 | internals_ptr = new internals();
199 | #if defined(WITH_THREAD)
200 | PyEval_InitThreads();
201 | PyThreadState *tstate = PyThreadState_Get();
202 | #if PY_VERSION_HEX >= 0x03070000
203 | internals_ptr->tstate = PyThread_tss_alloc();
204 | if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
205 | pybind11_fail("get_internals: could not successfully initialize the TSS key!");
206 | PyThread_tss_set(internals_ptr->tstate, tstate);
207 | #else
208 | internals_ptr->tstate = PyThread_create_key();
209 | if (internals_ptr->tstate == -1)
210 | pybind11_fail("get_internals: could not successfully initialize the TLS key!");
211 | PyThread_set_key_value(internals_ptr->tstate, tstate);
212 | #endif
213 | internals_ptr->istate = tstate->interp;
214 | #endif
215 | builtins[id] = capsule(internals_pp);
216 | internals_ptr->registered_exception_translators.push_front(
217 | [](std::exception_ptr p) -> void {
218 | try {
219 | if (p) std::rethrow_exception(p);
220 | } catch (error_already_set &e) { e.restore(); return;
221 | } catch (const builtin_exception &e) { e.set_error(); return;
222 | } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
223 | } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
224 | } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
225 | } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
226 | } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
227 | } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
228 | } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
229 | } catch (...) {
230 | PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
231 | return;
232 | }
233 | }
234 | );
235 | internals_ptr->static_property_type = make_static_property_type();
236 | internals_ptr->default_metaclass = make_default_metaclass();
237 | internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
238 | }
239 | return **internals_pp;
240 | }
241 |
242 | /// Works like `internals.registered_types_cpp`, but for module-local registered types:
243 | inline type_map<type_info *> &registered_local_types_cpp() {
244 |     static type_map<type_info *> locals{};
245 |     return locals;
246 | }
247 |
248 | /// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
249 | /// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
250 | /// cleared when the program exits or after interpreter shutdown (when embedding), and so are
251 | /// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
252 | template <typename... Args>
253 | const char *c_str(Args &&...args) {
254 |     auto &strings = get_internals().static_strings;
255 |     strings.emplace_front(std::forward<Args>(args)...);
256 |     return strings.front().c_str();
257 | }
258 |
259 | NAMESPACE_END(detail)
260 |
261 | /// Returns a named pointer that is shared among all extension modules (using the same
262 | /// pybind11 version) running in the current interpreter. Names starting with underscores
263 | /// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
264 | inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
265 | auto &internals = detail::get_internals();
266 | auto it = internals.shared_data.find(name);
267 | return it != internals.shared_data.end() ? it->second : nullptr;
268 | }
269 |
270 | /// Set the shared data that can be later recovered by `get_shared_data()`.
271 | inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
272 | detail::get_internals().shared_data[name] = data;
273 | return data;
274 | }
275 |
276 | /// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
277 | /// such entry exists. Otherwise, a new object of default-constructible type `T` is
278 | /// added to the shared data under the given name and a reference to it is returned.
279 | template <typename T>
280 | T &get_or_create_shared_data(const std::string &name) {
281 | auto &internals = detail::get_internals();
282 | auto it = internals.shared_data.find(name);
283 | T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
284 | if (!ptr) {
285 | ptr = new T();
286 | internals.shared_data[name] = ptr;
287 | }
288 | return *ptr;
289 | }
290 |
291 | NAMESPACE_END(PYBIND11_NAMESPACE)
292 |
--------------------------------------------------------------------------------
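
A hypothetical sketch of the shared-data API declared above: extension modules built against the same pybind11 version can exchange a named pointer through the interpreter-wide `internals` table (module and key names here are invented):

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    struct counters { int calls = 0; };

    PYBIND11_MODULE(example_shared, m) {
        // Created on first use; later imports (or other modules using the
        // same key) get the same instance back.
        auto &c = py::get_or_create_shared_data<counters>("_example_counters");
        m.def("bump", [&c]() { return ++c.calls; });
    }
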
/post_processing/include/pybind11/detail/typeid.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include <cstdio>
13 | #include <cstdlib>
14 |
15 | #if defined(__GNUG__)
16 | #include <cxxabi.h>
17 | #endif
18 |
19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
20 | NAMESPACE_BEGIN(detail)
21 | /// Erase all occurrences of a substring
22 | inline void erase_all(std::string &string, const std::string &search) {
23 | for (size_t pos = 0;;) {
24 | pos = string.find(search, pos);
25 | if (pos == std::string::npos) break;
26 | string.erase(pos, search.length());
27 | }
28 | }
29 |
30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
31 | #if defined(__GNUG__)
32 | int status = 0;
33 | std::unique_ptr<char, void (*)(void *)> res {
34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
35 | if (status == 0)
36 | name = res.get();
37 | #else
38 | detail::erase_all(name, "class ");
39 | detail::erase_all(name, "struct ");
40 | detail::erase_all(name, "enum ");
41 | #endif
42 | detail::erase_all(name, "pybind11::");
43 | }
44 | NAMESPACE_END(detail)
45 |
46 | /// Return a string representation of a C++ type
47 | template <typename T> static std::string type_id() {
48 | std::string name(typeid(T).name());
49 | detail::clean_type_id(name);
50 | return name;
51 | }
52 |
53 | NAMESPACE_END(PYBIND11_NAMESPACE)
54 |
--------------------------------------------------------------------------------
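
An illustrative use of `type_id<T>()` (not part of the header): it demangles via `abi::__cxa_demangle` on GCC/Clang, strips the `class `/`struct `/`enum ` keywords on MSVC, and removes the `pybind11::` prefix:

    #include <map>
    #include <string>

    void demo() {
        // Typically "std::map<std::string, int>"; the exact spelling is
        // compiler-dependent.
        std::string name = pybind11::type_id<std::map<std::string, int>>();
        (void) name;
    }
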
/post_processing/include/pybind11/embed.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/embed.h: Support for embedding the interpreter
3 |
4 | Copyright (c) 2017 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "pybind11.h"
13 | #include "eval.h"
14 |
15 | #if defined(PYPY_VERSION)
16 | # error Embedding the interpreter is not supported with PyPy
17 | #endif
18 |
19 | #if PY_MAJOR_VERSION >= 3
20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
21 | extern "C" PyObject *pybind11_init_impl_##name() { \
22 | return pybind11_init_wrapper_##name(); \
23 | }
24 | #else
25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
26 | extern "C" void pybind11_init_impl_##name() { \
27 | pybind11_init_wrapper_##name(); \
28 | }
29 | #endif
30 |
31 | /** \rst
32 | Add a new module to the table of builtins for the interpreter. Must be
33 | defined in global scope. The first macro parameter is the name of the
34 | module (without quotes). The second parameter is the variable which will
35 | be used as the interface to add functions and classes to the module.
36 |
37 | .. code-block:: cpp
38 |
39 | PYBIND11_EMBEDDED_MODULE(example, m) {
40 | // ... initialize functions and classes here
41 | m.def("foo", []() {
42 | return "Hello, World!";
43 | });
44 | }
45 | \endrst */
46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \
47 | static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
48 | static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
49 | auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
50 | try { \
51 | PYBIND11_CONCAT(pybind11_init_, name)(m); \
52 | return m.ptr(); \
53 | } catch (pybind11::error_already_set &e) { \
54 | PyErr_SetString(PyExc_ImportError, e.what()); \
55 | return nullptr; \
56 | } catch (const std::exception &e) { \
57 | PyErr_SetString(PyExc_ImportError, e.what()); \
58 | return nullptr; \
59 | } \
60 | } \
61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \
62 | pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
63 | PYBIND11_CONCAT(pybind11_init_impl_, name)); \
64 | void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
65 |
66 |
67 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
68 | NAMESPACE_BEGIN(detail)
69 |
70 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
71 | struct embedded_module {
72 | #if PY_MAJOR_VERSION >= 3
73 | using init_t = PyObject *(*)();
74 | #else
75 | using init_t = void (*)();
76 | #endif
77 | embedded_module(const char *name, init_t init) {
78 | if (Py_IsInitialized())
79 | pybind11_fail("Can't add new modules after the interpreter has been initialized");
80 |
81 | auto result = PyImport_AppendInittab(name, init);
82 | if (result == -1)
83 | pybind11_fail("Insufficient memory to add a new module");
84 | }
85 | };
86 |
87 | NAMESPACE_END(detail)
88 |
89 | /** \rst
90 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be
91 | called before this is done, with the exception of `PYBIND11_EMBEDDED_MODULE`. The
92 | optional parameter can be used to skip the registration of signal handlers (see the
93 | `Python documentation`_ for details). Calling this function again after the interpreter
94 | has already been initialized is a fatal error.
95 |
96 | If initializing the Python interpreter fails, then the program is terminated. (This
97 | is controlled by the CPython runtime and is an exception to pybind11's normal behavior
98 | of throwing exceptions on errors.)
99 |
100 | .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
101 | \endrst */
102 | inline void initialize_interpreter(bool init_signal_handlers = true) {
103 | if (Py_IsInitialized())
104 | pybind11_fail("The interpreter is already running");
105 |
106 | Py_InitializeEx(init_signal_handlers ? 1 : 0);
107 |
108 | // Make .py files in the working directory available by default
109 | module::import("sys").attr("path").cast<list>().append(".");
110 | }
111 |
112 | /** \rst
113 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called
114 | after this. In addition, pybind11 objects must not outlive the interpreter:
115 |
116 | .. code-block:: cpp
117 |
118 | { // BAD
119 | py::initialize_interpreter();
120 | auto hello = py::str("Hello, World!");
121 | py::finalize_interpreter();
122 | } // <-- BOOM, hello's destructor is called after interpreter shutdown
123 |
124 | { // GOOD
125 | py::initialize_interpreter();
126 | { // scoped
127 | auto hello = py::str("Hello, World!");
128 | } // <-- OK, hello is cleaned up properly
129 | py::finalize_interpreter();
130 | }
131 |
132 | { // BETTER
133 | py::scoped_interpreter guard{};
134 | auto hello = py::str("Hello, World!");
135 | }
136 |
137 | .. warning::
138 |
139 | The interpreter can be restarted by calling `initialize_interpreter` again.
140 | Modules created using pybind11 can be safely re-initialized. However, Python
141 | itself cannot completely unload binary extension modules and there are several
142 | caveats with regard to interpreter restarting. All the details can be found
143 | in the CPython documentation. In short, not all interpreter memory may be
144 | freed, either due to reference cycles or user-created global data.
145 |
146 | \endrst */
147 | inline void finalize_interpreter() {
148 | handle builtins(PyEval_GetBuiltins());
149 | const char *id = PYBIND11_INTERNALS_ID;
150 |
151 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the
152 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
153 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
154 | detail::internals **internals_ptr_ptr = detail::get_internals_pp();
155 | // It could also be stashed in builtins, so look there too:
156 | if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
157 | internals_ptr_ptr = capsule(builtins[id]);
158 |
159 | Py_Finalize();
160 |
161 | if (internals_ptr_ptr) {
162 | delete *internals_ptr_ptr;
163 | *internals_ptr_ptr = nullptr;
164 | }
165 | }
166 |
167 | /** \rst
168 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
169 | This is a move-only guard and only a single instance can exist.
170 |
171 | .. code-block:: cpp
172 |
173 | #include <pybind11/embed.h>
174 |
175 | int main() {
176 | py::scoped_interpreter guard{};
177 | py::print("Hello, World!");
178 | } // <-- interpreter shutdown
179 | \endrst */
180 | class scoped_interpreter {
181 | public:
182 | scoped_interpreter(bool init_signal_handlers = true) {
183 | initialize_interpreter(init_signal_handlers);
184 | }
185 |
186 | scoped_interpreter(const scoped_interpreter &) = delete;
187 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
188 | scoped_interpreter &operator=(const scoped_interpreter &) = delete;
189 | scoped_interpreter &operator=(scoped_interpreter &&) = delete;
190 |
191 | ~scoped_interpreter() {
192 | if (is_valid)
193 | finalize_interpreter();
194 | }
195 |
196 | private:
197 | bool is_valid = true;
198 | };
199 |
200 | NAMESPACE_END(PYBIND11_NAMESPACE)
201 |
--------------------------------------------------------------------------------
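
A minimal embedding sketch based on the API above (the module name `fast_calc` is hypothetical):

    #include <pybind11/embed.h>
    namespace py = pybind11;

    PYBIND11_EMBEDDED_MODULE(fast_calc, m) {
        m.def("add", [](int a, int b) { return a + b; });
    }

    int main() {
        py::scoped_interpreter guard{};              // start the interpreter
        auto calc = py::module::import("fast_calc");
        int n = calc.attr("add")(1, 2).cast<int>();
        py::print("1 + 2 =", n);                     // prints "1 + 2 = 3"
    }                                                // interpreter shuts down here
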
/post_processing/include/pybind11/eval.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/eval.h: Support for evaluating Python expressions and statements
3 | from strings and files
4 |
5 | Copyright (c) 2016 Klemens Morgenstern and
6 | Wenzel Jakob
7 |
8 | All rights reserved. Use of this source code is governed by a
9 | BSD-style license that can be found in the LICENSE file.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "pybind11.h"
15 |
16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
17 |
18 | enum eval_mode {
19 | /// Evaluate a string containing an isolated expression
20 | eval_expr,
21 |
22 | /// Evaluate a string containing a single statement. Returns \c none
23 | eval_single_statement,
24 |
25 | /// Evaluate a string containing a sequence of statements. Returns \c none
26 | eval_statements
27 | };
28 |
29 | template <eval_mode mode = eval_expr>
30 | object eval(str expr, object global = globals(), object local = object()) {
31 | if (!local)
32 | local = global;
33 |
34 | /* PyRun_String does not accept a PyObject / encoding specifier,
35 | this seems to be the only alternative */
36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
37 |
38 | int start;
39 | switch (mode) {
40 | case eval_expr: start = Py_eval_input; break;
41 | case eval_single_statement: start = Py_single_input; break;
42 | case eval_statements: start = Py_file_input; break;
43 | default: pybind11_fail("invalid evaluation mode");
44 | }
45 |
46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
47 | if (!result)
48 | throw error_already_set();
49 | return reinterpret_steal<object>(result);
50 | }
51 |
52 | template <eval_mode mode = eval_expr, size_t N>
53 | object eval(const char (&s)[N], object global = globals(), object local = object()) {
54 | /* Support raw string literals by removing common leading whitespace */
55 | auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
56 | : str(s);
57 | return eval<mode>(expr, global, local);
58 | }
59 |
60 | inline void exec(str expr, object global = globals(), object local = object()) {
61 | eval<eval_statements>(expr, global, local);
62 | }
63 |
64 | template <size_t N>
65 | void exec(const char (&s)[N], object global = globals(), object local = object()) {
66 | eval<eval_statements>(s, global, local);
67 | }
68 |
69 | template <eval_mode mode = eval_statements>
70 | object eval_file(str fname, object global = globals(), object local = object()) {
71 | if (!local)
72 | local = global;
73 |
74 | int start;
75 | switch (mode) {
76 | case eval_expr: start = Py_eval_input; break;
77 | case eval_single_statement: start = Py_single_input; break;
78 | case eval_statements: start = Py_file_input; break;
79 | default: pybind11_fail("invalid evaluation mode");
80 | }
81 |
82 | int closeFile = 1;
83 | std::string fname_str = (std::string) fname;
84 | #if PY_VERSION_HEX >= 0x03040000
85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r");
86 | #elif PY_VERSION_HEX >= 0x03000000
87 | FILE *f = _Py_fopen(fname.ptr(), "r");
88 | #else
89 | /* No unicode support in open() :( */
90 | auto fobj = reinterpret_steal<object>(PyFile_FromString(
91 | const_cast<char *>(fname_str.c_str()),
92 | const_cast<char *>("r")));
93 | FILE *f = nullptr;
94 | if (fobj)
95 | f = PyFile_AsFile(fobj.ptr());
96 | closeFile = 0;
97 | #endif
98 | if (!f) {
99 | PyErr_Clear();
100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!");
101 | }
102 |
103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
105 | local.ptr());
106 | (void) closeFile;
107 | #else
108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
109 | local.ptr(), closeFile);
110 | #endif
111 |
112 | if (!result)
113 | throw error_already_set();
114 | return reinterpret_steal<object>(result);
115 | }
116 |
117 | NAMESPACE_END(PYBIND11_NAMESPACE)
118 |
--------------------------------------------------------------------------------
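
A usage sketch for the `eval`/`exec` helpers above, assuming an embedded interpreter (see embed.h):

    #include <pybind11/embed.h>
    namespace py = pybind11;

    int main() {
        py::scoped_interpreter guard{};
        auto locals = py::dict();
        py::exec("x = 6 * 7", py::globals(), locals);              // eval_statements
        int x = py::eval("x", py::globals(), locals).cast<int>();  // eval_expr
        py::print(x);                                              // 42
    }
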
/post_processing/include/pybind11/functional.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/functional.h: std::function<> support
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "pybind11.h"
13 | #include
14 |
15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
16 | NAMESPACE_BEGIN(detail)
17 |
18 | template <typename Return, typename... Args>
19 | struct type_caster<std::function<Return(Args...)>> {
20 | using type = std::function<Return(Args...)>;
21 | using retval_type = conditional_t<std::is_void<Return>::value, void_type, Return>;
22 | using function_type = Return (*) (Args...);
23 |
24 | public:
25 | bool load(handle src, bool convert) {
26 | if (src.is_none()) {
27 | // Defer accepting None to other overloads (if we aren't in convert mode):
28 | if (!convert) return false;
29 | return true;
30 | }
31 |
32 | if (!isinstance<function>(src))
33 | return false;
34 |
35 | auto func = reinterpret_borrow<function>(src);
36 |
37 | /*
38 | When passing a C++ function as an argument to another C++
39 | function via Python, every function call would normally involve
40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
41 | Here, we try to at least detect the case where the function is
42 | stateless (i.e. function pointer or lambda function without
43 | captured variables), in which case the roundtrip can be avoided.
44 | */
45 | if (auto cfunc = func.cpp_function()) {
46 | auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
47 | auto rec = (function_record *) c;
48 |
49 | if (rec && rec->is_stateless &&
50 | same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
51 | struct capture { function_type f; };
52 | value = ((capture *) &rec->data)->f;
53 | return true;
54 | }
55 | }
56 |
57 | value = [func](Args... args) -> Return {
58 | gil_scoped_acquire acq;
59 | object retval(func(std::forward<Args>(args)...));
60 | /* Visual studio 2015 parser issue: need parentheses around this expression */
61 | return (retval.template cast<Return>());
62 | };
63 | return true;
64 | }
65 |
66 | template <typename Func>
67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
68 | if (!f_)
69 | return none().inc_ref();
70 |
71 | auto result = f_.template target<function_type>();
72 | if (result)
73 | return cpp_function(*result, policy).release();
74 | else
75 | return cpp_function(std::forward