├── Image_TIA.ipynb ├── LICENSE ├── README.md ├── __pycache__ ├── client.cpython-37.pyc └── test.cpython-37.pyc ├── cdistnet ├── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── data.cpython-37.pyc │ ├── data.py │ ├── hdf5converter.py │ ├── hdf5loader.py │ └── transform.py ├── engine │ ├── __init__.py │ ├── beam_search.py │ └── trainer.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── model.cpython-37.pyc │ │ └── translator.cpython-37.pyc │ ├── blocks.py │ ├── model.py │ ├── stage │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── backbone.cpython-37.pyc │ │ │ ├── multiheadAttention.cpython-37.pyc │ │ │ └── tps.cpython-37.pyc │ │ ├── backbone.py │ │ ├── multiheadAttention.py │ │ └── tps.py │ └── translator.py ├── optim │ ├── __init__.py │ ├── loss.py │ └── optim.py └── utils │ ├── __init__.py │ ├── dict_36.txt │ ├── gen_img.py │ ├── init.py │ ├── submit_with_lexicon.py │ └── tensorboardx.py ├── cdistnet_env.yaml ├── configs ├── CDistNet_config.py └── debug_config.py ├── eval.py ├── requirements.txt ├── test.py ├── train.py └── utils ├── Evaluation_TextRecog ├── constrain_select.py ├── gt.txt ├── readme.txt ├── readme_sff ├── rrc_evaluation_funcs.py ├── rrc_evaluation_funcs.pyc ├── script.py └── submit.txt ├── fig2_00.png ├── fig5_00.png ├── label_proc.py └── warp_mls.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition 2 | 3 | The official code of CDistNet. 
4 | 5 | Paper Link : [Arxiv Link](http://arxiv.org/abs/2111.11011) 6 | 7 | ## What's New 8 | 9 | - [2023-08]🌟 Our paper was accepted by IJCV 10 | - [2022-01]🌟 Our code was released on GitHub 11 | - [2021-11]🌟 The paper is available on arXiv: http://arxiv.org/abs/2111.11011 12 | 13 | ![pipeline](./utils/fig2_00.png) 14 | ## To Do List 15 | * [x] HA-IC13 & CA-IC13 16 | * [x] Pre-trained model 17 | * [x] Cleaned Code 18 | * [ ] Documentation 19 | * [ ] Distributed Training 20 | 21 | ## Two New Datasets 22 | We test other SOTA methods on the HA-IC13 and CA-IC13 datasets. 23 | 24 | ![HA_CA](./utils/fig5_00.png) 25 | **CDistNet has a performance advantage over other SOTA methods as the character distance increases (1-6)** 26 | ### HA-IC13 27 | |Method |1 | 2 | 3 | 4 | 5 | 6 | Code & Pretrained model| 28 | |- | - | - | - | - | - | - | - | 29 | |VisionLAN (ICCV 2021) | 93.58 | 92.88 | 89.97 | 82.26 | 72.23 | 61.03 | [Official Code](https://github.com/wangyuxin87/VisionLAN)| 30 | |ABINet (CVPR 2021) | 95.92 | 95.22 | 91.95 | 85.76 | 73.75 | 64.99 | [Official Code](https://github.com/FangShancheng/ABINet)| 31 | |RobustScanner* (ECCV 2020) | 96.15 | 95.33 | 93.23 | 88.91 | 81.10 | 71.53 | -- | 32 | | Transformer-baseline* | 96.27 | 95.45 | 92.42 | 86.46 | 79.35 | 72.46 | -- | 33 | |CDistNet |**96.62**| **96.15** | **94.28** | **89.96** | **83.43** | **77.71** | -- | 34 | 35 | ### CA-IC13 36 | |Method |1 | 2 | 3 | 4 | 5 | 6 | Code & Pretrained model| 37 | |- | - | - | - | - | - | - | - | 38 | |VisionLAN (ICCV 2021) | 94.87 | 92.77 | 84.01 | 75.03 | 64.29 | 52.74 | [Official Code](https://github.com/wangyuxin87/VisionLAN)| 39 | |ABINet (CVPR 2021) | **96.62** | **95.92** | 87.86 | 76.31 | 65.46 | 54.49 | [Official Code](https://github.com/FangShancheng/ABINet)| 40 | |RobustScanner* (ECCV 2020) | 95.22 | 94.87 | 85.30 | 76.55 | 68.38 | 60.79 | -- | 41 | | Transformer-baseline* | 95.68 | 94.40 | 85.88 | 75.85 | 65.93 | 58.58 | -- | 42 | |CDistNet | 96.27 | 95.57 | **88.45** | **79.58** | **70.36** | **63.13** | -- | 43 | 44 | 45 | ## Datasets 46 | **The datasets are the same as ABINet** 47 | - Training datasets 48 | 49 | 1. [MJSynth](http://www.robots.ox.ac.uk/~vgg/data/text/) (MJ): 50 | - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ) 51 | 2. [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) (ST): 52 | - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ) 53 | 54 | - Evaluation & Test datasets: LMDB datasets can be downloaded from [BaiduNetdisk(passwd:1dbv)](https://pan.baidu.com/s/1RUg3Akwp7n8kZYJ55rU5LQ), [GoogleDrive](https://drive.google.com/file/d/1dTI0ipu14Q1uuK4s4z32DqbqF3dJPdkk/view?usp=sharing). 55 | 1. ICDAR 2013 (IC13) 56 | 2. ICDAR 2015 (IC15) 57 | 3. IIIT5K Words (IIIT) 58 | 4. Street View Text (SVT) 59 | 5. Street View Text-Perspective (SVTP) 60 | 6. CUTE80 (CUTE) 61 | 62 | - Augmented IC13 63 | - HA-IC13 & CA-IC13 : [BaiduNetdisk(passwd:d6jd)](https://pan.baidu.com/s/1s0oNmd5jQJCvoH1efjfBdg), [GoogleDrive](https://drive.google.com/drive/folders/1PTPFjDdx2Ky0KsZdgn0p9x5fqyrdxKWF?usp=sharing) 64 | 65 | - The structure of the `dataset` directory is 66 | ``` 67 | dataset 68 | ├── eval 69 | │   ├── CUTE80 70 | │   ├── IC13_857 71 | │   ├── IC15_1811 72 | │   ├── IIIT5k_3000 73 | │   ├── SVT 74 | │   └── SVTP 75 | ├── train 76 | │   ├── MJ 77 | │   │   ├── MJ_test 78 | │   │   ├── MJ_train 79 | │   │   └── MJ_valid 80 | │   └── ST 81 | ``` 82 | ## Environment 83 | The full package list can be found in `cdistnet_env.yaml`. 
84 | ``` 85 | # Install 86 | conda create -n CDistNet python=3.7 87 | conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=9.2 -c pytorch 88 | pip install opencv-python mmcv notebook numpy einops tensorboardX Pillow thop timm tornado tqdm matplotlib lmdb 89 | ``` 90 | ## Pretrained Models 91 | 92 | Get the pretrained models from [BaiduNetdisk(passwd:d6jd)](https://pan.baidu.com/s/1s0oNmd5jQJCvoH1efjfBdg), [GoogleDrive](https://drive.google.com/drive/folders/1PTPFjDdx2Ky0KsZdgn0p9x5fqyrdxKWF?usp=sharing). 93 | (The training log and result.csv are both provided in the same folder.) 94 | The pretrained model should be placed in `models/reconstruct_CDistNetv3_3_10`. 95 | 96 | Performances of the pretrained models are summarized as follows: 97 | 98 | [comment]: <> (|Model|GPUs|IC13|SVT|IIIT|IC15|SVTP|CUTE|AVG|) 99 | 100 | [comment]: <> (|-|-|-|-|-|-|-|-|-|) 101 | 102 | [comment]: <> (|CDistNet(paper)|6|97.67|93.82|96.57|86.25|89.77|89.58|92.28|) 103 | 104 | [comment]: <> (|CDistNet(rebuild)|4|97.43|93.51|96.37|86.03|88.68|93.4|92.57|) 105 | 106 | ## Train 107 | `CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --config=configs/CDistNet_config.py` 108 | ## Eval 109 | `CUDA_VISIBLE_DEVICES=0 python eval.py --config=configs/CDistNet_config.py` 110 | ## Citation 111 | ```bibtex 112 | @article{Zheng2021CDistNetPM, 113 | title={CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition}, 114 | author={Tianlun Zheng and Zhineng Chen and Shancheng Fang and Hongtao Xie and Yu-Gang Jiang}, 115 | journal={ArXiv}, 116 | year={2021}, 117 | volume={abs/2111.11011} 118 | } 119 | ``` 120 | -------------------------------------------------------------------------------- /__pycache__/client.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/__pycache__/client.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/__pycache__/test.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/__init__.py: -------------------------------------------------------------------------------- 1 | def test(): 2 | return None -------------------------------------------------------------------------------- /cdistnet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/data/__init__.py -------------------------------------------------------------------------------- /cdistnet/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/data/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /cdistnet/data/__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/data/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/data/hdf5converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import glob 4 | import numpy as np 5 | import h5py 6 | import cv2 7 | import codecs 8 | from tqdm import tqdm 9 | from PIL import Image, ImageFile 10 | 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | 14 | def load_vocab(vocab=None, vocab_size=None): 15 | """ 16 | Load vocab from disk. The fisrt four items in the vocab should be , , , 17 | """ 18 | # print('Load set vocabularies as %s.' % vocab) 19 | vocab = [' ' if len(line.split()) == 0 else line.split()[0] for line in codecs.open(vocab, 'r', 'utf-8')] 20 | vocab = vocab[:vocab_size] 21 | assert len(vocab) == vocab_size 22 | word2idx = {word: idx for idx, word in enumerate(vocab)} 23 | idx2word = {idx: word for idx, word in enumerate(vocab)} 24 | return word2idx, idx2word 25 | 26 | 27 | def get_train_addrs(hdf5_path, img_path, gt_path, en_vocab_path, max_text_len, dict_size): 28 | addrs, labels = [], [] 29 | word2idx, idx2word = load_vocab(en_vocab_path, dict_size) 30 | 31 | with open(gt_path, 'r', encoding='UTF-8-sig') as f: 32 | all = f.readlines() 33 | max_len = -1 34 | for each in tqdm(all): 35 | each = each.strip().split(' ') 36 | path = os.path.join(img_path, each[0]) 37 | text = " ".join(each[1:]) 38 | # text = [word2idx.get(ch.lower(), 1) for ch in text] 39 | text = [word2idx.get(ch, 1) for ch in text] 40 | text.insert(0, 2) 41 | text.append(3) 42 | max_len = max(max_len, len(text)) 43 | text = np.array(text) 44 | text = np.pad(text, (0, max_text_len - text.size), 'constant') 45 | labels.append(text) 46 | addrs.append(path) 47 | # print(max_len) 48 | 49 | c = list(zip(addrs, labels)) 50 | addrs, labels = zip(*c) 51 | train_addrs = addrs 52 | train_labels = labels 53 | return hdf5_path, train_addrs, train_labels 54 | 55 | 56 | def create_hdf5_file(hdf5_path, train_addrs, train_labels, keep_aspect_ratio=False, height=32, max_width=180, max_text_len=35): 57 | train_shape = (len(train_addrs), 1, height, max_width if keep_aspect_ratio else 100) 58 | hdf5_file = h5py.File(hdf5_path, mode='w') 59 | hdf5_file.create_dataset("image", train_shape, np.float32) 60 | hdf5_file.create_dataset("label", (len(train_addrs), max_text_len), np.int) 61 | hdf5_file["label"][...] = train_labels 62 | return hdf5_file, train_shape 63 | 64 | 65 | def load_and_save_image(train_addrs, hdf5_file, keep_aspect_ratio=False, max_width=180, height=32): 66 | for i in tqdm(range(len(train_addrs))): 67 | addr = train_addrs[i] 68 | try: 69 | if keep_aspect_ratio: 70 | img = Image.open(addr).convert('L') 71 | # img = cv2.imread(addr, cv2.IMREAD_GRAYSCALE) 72 | h, w = img.height, img.width 73 | # h, w = img.shape 74 | r = w * 1.0 / h 75 | r_h, r_w = height, min(max(height * r, height), max_width) 76 | 77 | img = img.resize((int(r_w), r_h), Image.ANTIALIAS) 78 | # img = cv2.resize(img, (int(r_w), r_h), interpolation=cv2.INTER_CUBIC) 79 | img = np.array(img, dtype=np.uint8) 80 | img = np.expand_dims(img, -1) 81 | img = img.transpose((2, 0, 1)) 82 | img = img.astype(np.float32) / 128. - 1. 
83 | d = max_width - img.shape[-1] 84 | img = np.pad(img, ((0, 0), (0, 0), (0, d)), 'constant') 85 | else: 86 | img = Image.open(addr).convert('L').resize((100, 32), Image.ANTIALIAS) 87 | img = np.array(img, dtype=np.uint8) 88 | # img = cv2.imread(addr, cv2.IMREAD_GRAYSCALE) 89 | # img = cv2.resize(img, (100, 32), interpolation=cv2.INTER_CUBIC) 90 | img = np.expand_dims(img, -1) 91 | img = img.transpose((2, 0, 1)) 92 | img = img.astype(np.float32) / 128. - 1. 93 | except: 94 | print(addr) 95 | img = np.zeros((1, height, max_width if keep_aspect_ratio else 100), dtype=np.float32) 96 | hdf5_file["image"][i, ...] = img[None] 97 | hdf5_file.close() 98 | 99 | 100 | def main(): 101 | parser = argparse.ArgumentParser(description='Train NRTR') 102 | parser.add_argument('--hdf5_path', type=str, default='') 103 | parser.add_argument('--img_path', type=str, default='') 104 | parser.add_argument('--gt_path', type=str, default='') 105 | parser.add_argument('--keep_aspect_ratio', action='store_true') 106 | parser.add_argument('--max_width', type=int, default=180) 107 | parser.add_argument('--height', type=int, default=32) 108 | parser.add_argument('--en_vocab_path', type=str, default='') 109 | parser.add_argument('--max_text_len', type=int, default=35) 110 | parser.add_argument('--dict_size', type=int, default=40) 111 | args = parser.parse_args() 112 | # hdf5_path = '../datasets/train_two_keep.hdf5' 113 | # img_path = '/home/zhengsheng/dataset/reg' 114 | # gt_path = '/home/zhengsheng/dataset/reg/annotation_train_clean.txt' 115 | # hdf5_path = '../datasets/train_two_keep_aspect_ratio.hdf5' 116 | # img_path = '../datasets/image' 117 | # gt_path = '../datasets/gt/new_gt.txt' 118 | # keep_aspect_ratio = True 119 | # max_width = 180 120 | # height = 32 121 | # en_vocab_path = '/home/zs/zs/code/NRTR/datasets/en_vocab' 122 | 123 | hdf5_path = args.hdf5_path 124 | img_path = args.img_path 125 | gt_path = args.gt_path 126 | keep_aspect_ratio = args.keep_aspect_ratio 127 | height = args.height 128 | max_width = args.max_width 129 | en_vocab_path = args.en_vocab_path 130 | max_text_len = args.max_text_len 131 | dict_size = args.dict_size 132 | print("hdf5_path: ", hdf5_path) 133 | print("img_path: ", img_path) 134 | print("gt_path: ", gt_path) 135 | print("keep_aspect_ratio: ", keep_aspect_ratio) 136 | print("height: ", height) 137 | print("max_width:", max_width) 138 | print("en_vocab_path: ", en_vocab_path) 139 | print("max_text_len: ", max_text_len) 140 | print("dict_size: ", dict_size) 141 | hdf5_path, train_addrs, train_labels = get_train_addrs(hdf5_path, img_path, gt_path, en_vocab_path, max_text_len, dict_size) 142 | hdf5_file, train_shape = create_hdf5_file(hdf5_path, train_addrs, train_labels, keep_aspect_ratio, height, max_width, max_text_len) 143 | load_and_save_image(train_addrs, hdf5_file, keep_aspect_ratio, max_width, height) 144 | 145 | 146 | # def test(): 147 | # hdf5_file = h5py.File("/home/psdz/datasets/train_three_and_chinese.hdf5", "r") 148 | # print(len(hdf5_file['label'])) 149 | # for i in range(len(hdf5_file['label'])): 150 | # label = hdf5_file['label'][i] 151 | # print(label) 152 | # if i == 100: 153 | # break 154 | # # Image.fromarray(hdf5_file['image'][i].astype('uint8')).save('./tmp.jpg') 155 | # # print(hdf5_file['flag'][i]) 156 | # # break 157 | 158 | 159 | if __name__ == '__main__': 160 | main() -------------------------------------------------------------------------------- /cdistnet/data/hdf5loader.py: -------------------------------------------------------------------------------- 
1 | import os 2 | import time 3 | import codecs 4 | import pickle 5 | import h5py 6 | import numpy as np 7 | from PIL import Image 8 | from PIL import ImageFile 9 | import argparse 10 | from tqdm import tqdm 11 | from mmcv import Config 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import transforms 15 | from prefetch_generator import BackgroundGenerator 16 | 17 | 18 | class DataLoaderX(DataLoader): 19 | def __iter__(self): 20 | return BackgroundGenerator(super().__iter__()) 21 | 22 | 23 | class NRTRDataset_hdf5(Dataset): 24 | def __init__(self, hdf5_file, transform=None): 25 | self.data = dict() 26 | self._transform = transform 27 | self.hdf5_file = hdf5_file 28 | 29 | def __len__(self): 30 | with h5py.File(self.hdf5_file, 'r') as data: 31 | lens = len(data['label']) 32 | return lens 33 | 34 | def __getitem__(self, idx): 35 | with h5py.File(self.hdf5_file, 'r') as data: 36 | image = data['image'][idx] 37 | image = torch.from_numpy(image) 38 | image = image.to(torch.float32) 39 | target = data['label'][idx] 40 | target = torch.from_numpy(target) 41 | target = target.to(torch.int64) 42 | return image, target 43 | 44 | 45 | def make_data_loader(cfg, is_train=True): 46 | dataset = NRTRDataset_hdf5( 47 | hdf5_file=cfg.train.hdf5 if is_train else cfg.val.hdf5, 48 | ) 49 | dataloader = DataLoaderX( 50 | dataset=dataset, 51 | batch_size=cfg.train.batch_size if is_train else cfg.val.batch_size, 52 | shuffle=True if is_train else False, 53 | num_workers=cfg.train.num_worker if is_train else cfg.val.num_worker, 54 | pin_memory=False, 55 | ) 56 | return dataloader 57 | 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser(description='Train NRTR') 61 | parser.add_argument('--config', type=str, help='train config file path') 62 | args = parser.parse_args() 63 | cfg = Config.fromfile(args.config) 64 | data_loader = make_data_loader(cfg) 65 | for _ in range(1): 66 | for idx, batch in enumerate(tqdm(data_loader)): 67 | print(batch[0].shape) 68 | # print(batch[1].shape) 69 | # if idx == 10: 70 | # break 71 | # print(image.shape) 72 | # print(batch) 73 | 74 | -------------------------------------------------------------------------------- /cdistnet/data/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision import transforms 9 | from torchvision.transforms import Compose 10 | 11 | 12 | def sample_asym(magnitude, size=None): 13 | return np.random.beta(1, 4, size) * magnitude 14 | 15 | 16 | def sample_sym(magnitude, size=None): 17 | return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude 18 | 19 | 20 | def sample_uniform(low, high, size=None): 21 | return np.random.uniform(low, high, size=size) 22 | 23 | 24 | def get_interpolation(type='random'): 25 | if type == 'random': 26 | choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA] 27 | interpolation = choice[random.randint(0, len(choice) - 1)] 28 | elif type == 'nearest': 29 | interpolation = cv2.INTER_NEAREST 30 | elif type == 'linear': 31 | interpolation = cv2.INTER_LINEAR 32 | elif type == 'cubic': 33 | interpolation = cv2.INTER_CUBIC 34 | elif type == 'area': 35 | interpolation = cv2.INTER_AREA 36 | else: 37 | raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!') 38 | return interpolation 39 | 40 | 41 | class CVRandomRotation(object): 42 | def 
__init__(self, degrees=15): 43 | assert isinstance(degrees, numbers.Number), "degree should be a single number." 44 | assert degrees >= 0, "degree must be positive." 45 | self.degrees = degrees 46 | 47 | @staticmethod 48 | def get_params(degrees): 49 | return sample_sym(degrees) 50 | 51 | def __call__(self, img): 52 | angle = self.get_params(self.degrees) 53 | src_h, src_w = img.shape[:2] 54 | M = cv2.getRotationMatrix2D(center=(src_w / 2, src_h / 2), angle=angle, scale=1.0) 55 | abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1]) 56 | dst_w = int(src_h * abs_sin + src_w * abs_cos) 57 | dst_h = int(src_h * abs_cos + src_w * abs_sin) 58 | M[0, 2] += (dst_w - src_w) / 2 59 | M[1, 2] += (dst_h - src_h) / 2 60 | 61 | flags = get_interpolation() 62 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) 63 | 64 | 65 | class CVRandomAffine(object): 66 | def __init__(self, degrees, translate=None, scale=None, shear=None): 67 | assert isinstance(degrees, numbers.Number), "degree should be a single number." 68 | assert degrees >= 0, "degree must be positive." 69 | self.degrees = degrees 70 | 71 | if translate is not None: 72 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 73 | "translate should be a list or tuple and it must be of length 2." 74 | for t in translate: 75 | if not (0.0 <= t <= 1.0): 76 | raise ValueError("translation values should be between 0 and 1") 77 | self.translate = translate 78 | 79 | if scale is not None: 80 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 81 | "scale should be a list or tuple and it must be of length 2." 82 | for s in scale: 83 | if s <= 0: 84 | raise ValueError("scale values should be positive") 85 | self.scale = scale 86 | 87 | if shear is not None: 88 | if isinstance(shear, numbers.Number): 89 | if shear < 0: 90 | raise ValueError("If shear is a single number, it must be positive.") 91 | self.shear = [shear] 92 | else: 93 | assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \ 94 | "shear should be a list or tuple and it must be of length 2." 95 | self.shear = shear 96 | else: 97 | self.shear = shear 98 | 99 | def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear): 100 | # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717 101 | from numpy import sin, cos, tan 102 | 103 | if isinstance(shear, numbers.Number): 104 | shear = [shear, 0] 105 | 106 | if not isinstance(shear, (tuple, list)) and len(shear) == 2: 107 | raise ValueError( 108 | "Shear should be a single value or a tuple/list containing " + 109 | "two values. 
Got {}".format(shear)) 110 | 111 | rot = math.radians(angle) 112 | sx, sy = [math.radians(s) for s in shear] 113 | 114 | cx, cy = center 115 | tx, ty = translate 116 | 117 | # RSS without scaling 118 | a = cos(rot - sy) / cos(sy) 119 | b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) 120 | c = sin(rot - sy) / cos(sy) 121 | d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) 122 | 123 | # Inverted rotation matrix with scale and shear 124 | # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 125 | M = [d, -b, 0, 126 | -c, a, 0] 127 | M = [x / scale for x in M] 128 | 129 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 130 | M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) 131 | M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) 132 | 133 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 134 | M[2] += cx 135 | M[5] += cy 136 | return M 137 | 138 | @staticmethod 139 | def get_params(degrees, translate, scale_ranges, shears, height): 140 | angle = sample_sym(degrees) 141 | if translate is not None: 142 | max_dx = translate[0] * height 143 | max_dy = translate[1] * height 144 | translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy))) 145 | else: 146 | translations = (0, 0) 147 | 148 | if scale_ranges is not None: 149 | scale = sample_uniform(scale_ranges[0], scale_ranges[1]) 150 | else: 151 | scale = 1.0 152 | 153 | if shears is not None: 154 | if len(shears) == 1: 155 | shear = [sample_sym(shears[0]), 0.] 156 | elif len(shears) == 2: 157 | shear = [sample_sym(shears[0]), sample_sym(shears[1])] 158 | else: 159 | shear = 0.0 160 | 161 | return angle, translations, scale, shear 162 | 163 | def __call__(self, img): 164 | src_h, src_w = img.shape[:2] 165 | angle, translate, scale, shear = self.get_params( 166 | self.degrees, self.translate, self.scale, self.shear, src_h) 167 | 168 | M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle, (0, 0), scale, shear) 169 | M = np.array(M).reshape(2, 3) 170 | 171 | startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)] 172 | project = lambda x, y, a, b, c: int(a * x + b * y + c) 173 | endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints] 174 | 175 | rect = cv2.minAreaRect(np.array(endpoints)) 176 | bbox = cv2.boxPoints(rect).astype(dtype=np.int) 177 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() 178 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() 179 | 180 | dst_w = int(max_x - min_x) 181 | dst_h = int(max_y - min_y) 182 | M[0, 2] += (dst_w - src_w) / 2 183 | M[1, 2] += (dst_h - src_h) / 2 184 | 185 | # add translate 186 | dst_w += int(abs(translate[0])) 187 | dst_h += int(abs(translate[1])) 188 | if translate[0] < 0: M[0, 2] += abs(translate[0]) 189 | if translate[1] < 0: M[1, 2] += abs(translate[1]) 190 | 191 | flags = get_interpolation() 192 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) 193 | 194 | 195 | class CVRandomPerspective(object): 196 | def __init__(self, distortion=0.5): 197 | self.distortion = distortion 198 | 199 | def get_params(self, width, height, distortion): 200 | offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int) 201 | offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int) 202 | topleft = (offset_w[0], offset_h[0]) 203 | topright = (width - 1 - offset_w[1], offset_h[1]) 204 | botright = (width - 1 - offset_w[2], height - 1 - offset_h[2]) 205 | botleft = (offset_w[3], height - 1 - offset_h[3]) 206 | 207 | 
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] 208 | endpoints = [topleft, topright, botright, botleft] 209 | return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32) 210 | 211 | def __call__(self, img): 212 | height, width = img.shape[:2] 213 | startpoints, endpoints = self.get_params(width, height, self.distortion) 214 | M = cv2.getPerspectiveTransform(startpoints, endpoints) 215 | 216 | # TODO: more robust way to crop image 217 | rect = cv2.minAreaRect(endpoints) 218 | bbox = cv2.boxPoints(rect).astype(dtype=np.int) 219 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() 220 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() 221 | min_x, min_y = max(min_x, 0), max(min_y, 0) 222 | 223 | flags = get_interpolation() 224 | img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE) 225 | img = img[min_y:, min_x:] 226 | return img 227 | 228 | 229 | class CVRescale(object): 230 | 231 | def __init__(self, factor=4, base_size=(128, 512)): 232 | """ Define image scales using gaussian pyramid and rescale image to target scale. 233 | 234 | Args: 235 | factor: the decayed factor from base size, factor=4 keeps target scale by default. 236 | base_size: base size the build the bottom layer of pyramid 237 | """ 238 | if isinstance(factor, numbers.Number): 239 | self.factor = round(sample_uniform(0, factor)) 240 | elif isinstance(factor, (tuple, list)) and len(factor) == 2: 241 | self.factor = round(sample_uniform(factor[0], factor[1])) 242 | else: 243 | raise Exception('factor must be number or list with length 2') 244 | # assert factor is valid 245 | self.base_h, self.base_w = base_size[:2] 246 | 247 | def __call__(self, img): 248 | if self.factor == 0: return img 249 | src_h, src_w = img.shape[:2] 250 | cur_w, cur_h = self.base_w, self.base_h 251 | scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation()) 252 | for _ in range(self.factor): 253 | scale_img = cv2.pyrDown(scale_img) 254 | scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation()) 255 | return scale_img 256 | 257 | 258 | class CVGaussianNoise(object): 259 | def __init__(self, mean=0, var=20): 260 | self.mean = mean 261 | if isinstance(var, numbers.Number): 262 | self.var = max(int(sample_asym(var)), 1) 263 | elif isinstance(var, (tuple, list)) and len(var) == 2: 264 | self.var = int(sample_uniform(var[0], var[1])) 265 | else: 266 | raise Exception('degree must be number or list with length 2') 267 | 268 | def __call__(self, img): 269 | noise = np.random.normal(self.mean, self.var ** 0.5, img.shape) 270 | img = np.clip(img + noise, 0, 255).astype(np.uint8) 271 | return img 272 | 273 | 274 | class CVMotionBlur(object): 275 | def __init__(self, degrees=12, angle=90): 276 | if isinstance(degrees, numbers.Number): 277 | self.degree = max(int(sample_asym(degrees)), 1) 278 | elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: 279 | self.degree = int(sample_uniform(degrees[0], degrees[1])) 280 | else: 281 | raise Exception('degree must be number or list with length 2') 282 | self.angle = sample_uniform(-angle, angle) 283 | 284 | def __call__(self, img): 285 | M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1) 286 | motion_blur_kernel = np.zeros((self.degree, self.degree)) 287 | motion_blur_kernel[self.degree // 2, :] = 1 288 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree)) 289 | motion_blur_kernel = motion_blur_kernel / 
self.degree 290 | img = cv2.filter2D(img, -1, motion_blur_kernel) 291 | img = np.clip(img, 0, 255).astype(np.uint8) 292 | return img 293 | 294 | 295 | class CVGeometry(object): 296 | def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.), 297 | shear=(45, 15), distortion=0.5, p=0.5): 298 | self.p = p 299 | type_p = random.random() 300 | if type_p < 0.33: 301 | self.transforms = CVRandomRotation(degrees=degrees) 302 | elif type_p < 0.66: 303 | self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) 304 | else: 305 | self.transforms = CVRandomPerspective(distortion=distortion) 306 | 307 | def __call__(self, img): 308 | if random.random() < self.p: 309 | img = np.array(img) 310 | return Image.fromarray(self.transforms(img)) 311 | else: 312 | return img 313 | 314 | 315 | class CVDeterioration(object): 316 | def __init__(self, var, degrees, factor, p=0.5): 317 | self.p = p 318 | transforms = [] 319 | if var is not None: 320 | transforms.append(CVGaussianNoise(var=var)) 321 | if degrees is not None: 322 | transforms.append(CVMotionBlur(degrees=degrees)) 323 | if factor is not None: 324 | transforms.append(CVRescale(factor=factor)) 325 | 326 | random.shuffle(transforms) 327 | transforms = Compose(transforms) 328 | self.transforms = transforms 329 | 330 | def __call__(self, img): 331 | if random.random() < self.p: 332 | img = np.array(img) 333 | return Image.fromarray(self.transforms(img)) 334 | else: 335 | return img 336 | 337 | 338 | class CVColorJitter(object): 339 | def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5): 340 | self.p = p 341 | self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast, 342 | saturation=saturation, hue=hue) 343 | 344 | def __call__(self, img): 345 | if random.random() < self.p: 346 | return self.transforms(img) 347 | else: 348 | return img 349 | -------------------------------------------------------------------------------- /cdistnet/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/engine/__init__.py -------------------------------------------------------------------------------- /cdistnet/engine/beam_search.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | import numpy as np 4 | 5 | 6 | class BeamEntry: 7 | "information about one single beam at specific time-step" 8 | def __init__(self): 9 | self.prTotal = 0 # blank and non-blank 10 | self.prNonBlank = 0 # non-blank 11 | self.prBlank = 0 # blank 12 | self.prText = 1 # LM score 13 | self.lmApplied = False # flag if LM was already applied to this beam 14 | self.labeling = () # beam-labeling 15 | 16 | 17 | class BeamState: 18 | "information about the beams at specific time-step" 19 | def __init__(self): 20 | self.entries = {} 21 | 22 | def norm(self): 23 | "length-normalise LM score" 24 | for (k, _) in self.entries.items(): 25 | labelingLen = len(self.entries[k].labeling) 26 | self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0)) 27 | 28 | def sort(self): 29 | "return beam-labelings, sorted by probability" 30 | beams = [v for (_, v) in self.entries.items()] 31 | sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText) 32 | return [x.labeling for x in sortedBeams] 33 | 34 | 35 | def 
applyLM(parentBeam, childBeam, classes, lm): 36 | "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars" 37 | if lm and not childBeam.lmApplied: 38 | c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')] # first char 39 | c2 = classes[childBeam.labeling[-1]] # second char 40 | lmFactor = 0.01 # influence of language model 41 | bigramProb = lm.getCharBigram(c1, c2) ** lmFactor # probability of seeing first and second char next to each other 42 | childBeam.prText = parentBeam.prText * bigramProb # probability of char sequence 43 | childBeam.lmApplied = True # only apply LM once per beam entry 44 | 45 | 46 | def addBeam(beamState, labeling): 47 | "add beam if it does not yet exist" 48 | if labeling not in beamState.entries: 49 | beamState.entries[labeling] = BeamEntry() 50 | 51 | 52 | def ctcBeamSearch(mat, classes, lm, beamWidth=10): 53 | "beam search as described by the paper of Hwang et al. and the paper of Graves et al." 54 | 55 | blankIdx = len(classes) 56 | maxT, maxC = mat.shape 57 | 58 | # initialise beam state 59 | last = BeamState() 60 | labeling = () 61 | last.entries[labeling] = BeamEntry() 62 | last.entries[labeling].prBlank = 1 63 | last.entries[labeling].prTotal = 1 64 | 65 | # go over all time-steps 66 | for t in range(maxT): 67 | curr = BeamState() 68 | 69 | # get beam-labelings of best beams 70 | bestLabelings = last.sort()[0:beamWidth] 71 | 72 | # go over best beams 73 | for labeling in bestLabelings: 74 | 75 | # probability of paths ending with a non-blank 76 | prNonBlank = 0 77 | # in case of non-empty beam 78 | if labeling: 79 | # probability of paths with repeated last char at the end 80 | prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]] 81 | 82 | # probability of paths ending with a blank 83 | prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx] 84 | 85 | # add beam at current time-step if needed 86 | addBeam(curr, labeling) 87 | 88 | # fill in data 89 | curr.entries[labeling].labeling = labeling 90 | curr.entries[labeling].prNonBlank += prNonBlank 91 | curr.entries[labeling].prBlank += prBlank 92 | curr.entries[labeling].prTotal += prBlank + prNonBlank 93 | curr.entries[labeling].prText = last.entries[labeling].prText # beam-labeling not changed, therefore also LM score unchanged from 94 | curr.entries[labeling].lmApplied = True # LM already applied at previous time-step for this beam-labeling 95 | 96 | # extend current beam-labeling 97 | for c in range(maxC - 1): 98 | # add new char to current beam-labeling 99 | newLabeling = labeling + (c,) 100 | 101 | # if new labeling contains duplicate char at the end, only consider paths ending with a blank 102 | if labeling and labeling[-1] == c: 103 | prNonBlank = mat[t, c] * last.entries[labeling].prBlank 104 | else: 105 | prNonBlank = mat[t, c] * last.entries[labeling].prTotal 106 | 107 | # add beam at current time-step if needed 108 | addBeam(curr, newLabeling) 109 | 110 | # fill in data 111 | curr.entries[newLabeling].labeling = newLabeling 112 | curr.entries[newLabeling].prNonBlank += prNonBlank 113 | curr.entries[newLabeling].prTotal += prNonBlank 114 | 115 | # apply LM 116 | applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm) 117 | 118 | # set new beam state 119 | last = curr 120 | 121 | # normalise LM scores according to beam-labeling-length 122 | last.norm() 123 | 124 | # sort by probability 125 | bestLabeling = last.sort()[0] # get most probable labeling 126 | 127 | # map labels to 
chars 128 | res = '' 129 | for l in bestLabeling: 130 | res += classes[l] 131 | 132 | return res 133 | 134 | 135 | def testBeamSearch(): 136 | "test decoder" 137 | classes = 'ab' 138 | mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]]) 139 | print('Test beam search') 140 | expected = 'a' 141 | actual = ctcBeamSearch(mat, classes, None) 142 | print('Expected: "' + expected + '"') 143 | print('Actual: "' + actual + '"') 144 | print('OK' if expected == actual else 'ERROR') 145 | 146 | 147 | if __name__ == '__main__': 148 | testBeamSearch() -------------------------------------------------------------------------------- /cdistnet/engine/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | import datetime 5 | from cdistnet.optim.loss import cal_performance 6 | import os 7 | # os.environ['CUDA_VISIBLE_DEVICES']="0,1,2,3" 8 | 9 | def train(model, 10 | train_data_loader, 11 | val_data_loader, 12 | optimizer, 13 | device, 14 | epoch, 15 | logger, 16 | meter, 17 | save_iter, 18 | display_iter, 19 | tfboard_iter, 20 | eval_iter, 21 | model_dir, 22 | label_smoothing, 23 | grads_clip, 24 | cfg, 25 | best_eval, 26 | best_epoch, 27 | best_iteration): 28 | model.train() 29 | total_loss = 0 30 | n_word_total = 0 31 | n_word_correct = 0 32 | 33 | max_iter = len(train_data_loader) 34 | end = time.time() 35 | total_time = 0. 36 | count = 0 37 | for iteration, batch in enumerate(train_data_loader, 0): 38 | if not batch: 39 | print('Error') 40 | continue 41 | meter.update_iter(max_iter * epoch + iteration) 42 | 43 | if cfg.train_method=='dist': 44 | images = batch[0].cuda(device,non_blocking=True) 45 | tgt = batch[1].cuda(device,non_blocking=True) 46 | else: 47 | images = batch[0].to(device) 48 | tgt = batch[1].to(device) 49 | 50 | optimizer.zero_grad() 51 | pred = model(images, tgt) 52 | #pred(b*tgt_len,vacab_size) 53 | 54 | tgt = tgt[:, 1:] 55 | # tgt(b,max_len) 56 | loss, n_correct = cal_performance(pred, tgt, smoothing=label_smoothing,local_rank=device) 57 | 58 | # torch.distributed.barrier() 59 | loss.backward() 60 | 61 | # clip gradients 62 | torch.nn.utils.clip_grad_norm_(model.parameters(), grads_clip) 63 | 64 | # optimizer.step() 65 | optimizer.step_and_update_lr(epoch) 66 | 67 | total_loss += loss.item() 68 | non_pad_mask = tgt.ne(0) 69 | n_word = non_pad_mask.sum().item() 70 | n_word_total += n_word 71 | n_word_correct += n_correct 72 | 73 | batch_time = time.time() - end 74 | end = time.time() 75 | total_time += batch_time 76 | count += 1 77 | avg_time = total_time / count 78 | eta_seconds = avg_time * (max_iter - iteration) 79 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 80 | 81 | acc = n_word_correct / n_word_total 82 | 83 | if tfboard_iter and (max_iter * epoch + iteration) % tfboard_iter == 0: 84 | meter.add_scalar(lr=optimizer._optimizer.param_groups[0]['lr'], loss=loss.item(), acc=acc) 85 | 86 | # if (max_iter * epoch + iteration) % 5000 == 0: 87 | # # meter.add_graph(model.module,(images,tgt)) 88 | # for key in model.module.state_dict(): 89 | # meter.add_histogram(key,model.module.state_dict()[key]) 90 | if cfg.train_method!='dist' or device ==0: 91 | if iteration % display_iter == 0: 92 | msg = 'epoch: {epoch} iter: {iter} loss: {loss: .6f} lr: {lr: .6f} eta: {eta}'.format( 93 | epoch=epoch, 94 | iter='{}/{}'.format(iteration, max_iter), 95 | loss=loss.item(), 96 | lr=optimizer._optimizer.param_groups[0]['lr'], 97 | # lr=optimizer.param_groups[0]['lr'], 98 | eta=eta_string 99 | ) 100 
| logger.info(msg) 101 | 102 | # if save_iter and (max_iter * epoch + iteration) % save_iter == 0: 103 | # logger.info("Saving model ...") 104 | # torch.save(model.module.state_dict(), '{}/model_epoch_{}_iter_{}.pth'.format(model_dir, epoch, iteration)) 105 | # logger.info("Saved!") 106 | 107 | if epoch >= 6 and iteration % eval_iter == 0: 108 | eval_loss, eval_acc = eval( 109 | model=model, 110 | data_loader=val_data_loader, 111 | device=device, 112 | label_smoothing=label_smoothing, 113 | cfg=cfg 114 | ) 115 | meter.add_scalar(eval_loss=eval_loss, eval_acc=eval_acc) 116 | logger.info('eval_loss:{:.4f},eval_acc:{:.4f}--------\n'.format(eval_loss,eval_acc)) 117 | if eval_acc > best_eval: 118 | best_eval = eval_acc 119 | best_epoch = epoch 120 | best_iteration = iteration 121 | logger.info("Saving model: best_acc in epoch:{},iteration:{}".format(best_epoch,best_iteration)) 122 | torch.save(model.module.state_dict(), '{}/epoch{}_best_acc.pth'.format(model_dir, epoch)) 123 | logger.info("Saved!") 124 | if epoch > 8: 125 | logger.info("Saving last epoch model in epoch:{},iteration:{}".format(epoch, iteration)) 126 | torch.save(model.module.state_dict(), '{}/epoch{}_iter{}.pth'.format(model_dir, epoch,iteration)) 127 | logger.info("Saved!") 128 | model.train() 129 | 130 | loss_per_word = total_loss / max_iter 131 | accuracy = n_word_correct / n_word_total 132 | logger.info("Now: best_acc in epoch:{},iteration:{}".format(best_epoch, best_iteration)) 133 | return loss_per_word, accuracy 134 | 135 | 136 | def eval(model, data_loader, device, label_smoothing,cfg): 137 | model.eval() 138 | total_loss = 0 139 | n_word_total = 0 140 | n_word_correct = 0 141 | avg_acc = .0 142 | datasets_len = len(data_loader) 143 | with torch.no_grad(): 144 | for dataset in data_loader: 145 | data_len=len(dataset) 146 | for iteration, batch in enumerate(dataset, 0): 147 | if cfg.train_method=='dist': 148 | images = batch[0].cuda(device,non_blocking=True) 149 | tgt = batch[1].cuda(device,non_blocking=True) 150 | else: 151 | images = batch[0].to(device) 152 | tgt = batch[1].to(device) 153 | pred = model(images, tgt) 154 | tgt = tgt[:, 1:] 155 | loss, n_correct = cal_performance(pred, tgt, smoothing=label_smoothing,local_rank=device) 156 | 157 | total_loss += loss.item() 158 | non_pad_mask = tgt.ne(0) 159 | n_word = non_pad_mask.sum().item() 160 | n_word_total += n_word 161 | n_word_correct += n_correct 162 | 163 | loss_per_word = total_loss / data_len 164 | accuracy = n_word_correct / n_word_total 165 | # print("accuracy:{}".format(accuracy)) 166 | avg_acc +=accuracy 167 | 168 | return loss_per_word, avg_acc/datasets_len 169 | 170 | 171 | def do_train(model, 172 | train_dataloader, 173 | val_dataloader, 174 | optimizer, 175 | device, 176 | num_epochs, 177 | current_epoch, 178 | logger, 179 | meter, 180 | save_iter, 181 | display_iter, 182 | tfboard_iter, 183 | eval_iter, 184 | model_dir, 185 | label_smoothing, 186 | grads_clip,cfg): 187 | # meter.add_graph(model.module,(images,tgt)) 188 | best_eval = 0. 
189 | best_epoch = 0 190 | best_iteration = 0 191 | for epoch in range(current_epoch, num_epochs): 192 | if cfg.train_method=='dist': 193 | train_dataloader.sampler.set_epoch(epoch) 194 | val_dataloader.sampler.set_epoch(epoch) 195 | start = time.time() 196 | train_loss, train_accu = train( 197 | model=model, 198 | train_data_loader=train_dataloader, 199 | val_data_loader=val_dataloader, 200 | optimizer=optimizer, 201 | device=device, 202 | epoch=epoch, 203 | logger=logger, 204 | meter=meter, 205 | save_iter=save_iter, 206 | display_iter=display_iter, 207 | tfboard_iter=tfboard_iter, 208 | eval_iter=eval_iter, 209 | model_dir=model_dir, 210 | label_smoothing=label_smoothing, 211 | grads_clip=grads_clip, 212 | cfg=cfg, 213 | best_eval=best_eval, 214 | best_epoch=best_epoch, 215 | best_iteration=best_iteration 216 | ) 217 | 218 | logger.info(' - (Training) loss: {loss: 8.5f}, accuracy: {accu:3.3f} %, time: {time:3.3f} min' 219 | .format(loss=train_loss, accu=100 * train_accu, time=(time.time() - start) / 60)) 220 | 221 | # eval & save 222 | start = time.time() 223 | if epoch >= 6: 224 | logger.info("Start eval ...") 225 | val_loss, val_accu = eval( 226 | model=model, 227 | data_loader=val_dataloader, 228 | device=device, 229 | label_smoothing=label_smoothing, 230 | cfg=cfg, 231 | ) 232 | logger.info(' - (Validation) loss: {loss: 8.5f}, accuracy: {accu:3.3f} %, time: {time:3.3f} min' 233 | .format(loss=val_loss, accu=100 * val_accu, time=(time.time() - start) / 60)) 234 | 235 | if cfg.train_method != 'dist' or device == 0: 236 | logger.info("Saving model ...") 237 | torch.save(model.module.state_dict(), '{}/model_epoch_{}.pth'.format(model_dir, epoch)) 238 | logger.info("Saved!") -------------------------------------------------------------------------------- /cdistnet/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/__init__.py -------------------------------------------------------------------------------- /cdistnet/model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/model/__pycache__/translator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/__pycache__/translator.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/model/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import copy 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.nn.modules import Module 7 | # from torch.nn.modules import MultiheadAttention 8 | from torch.nn.modules import ModuleList 9 | from 
torch.nn.init import xavier_uniform_ 10 | from torch.nn.modules import Dropout 11 | from torch.nn.modules import Linear 12 | from torch.nn.modules import LayerNorm 13 | from torch.nn.modules import Conv2d 14 | 15 | from cdistnet.model.stage.multiheadAttention import MultiheadAttention 16 | 17 | def _get_clones(module, N): 18 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 19 | 20 | class MDCDP(Module): 21 | r""" 22 | Multi-Domain CharacterDistance Perception 23 | """ 24 | 25 | def __init__(self, decoder_layer, num_layers): 26 | super(MDCDP, self).__init__() 27 | 28 | d_model = 512 29 | self.num_layers = num_layers 30 | 31 | # step 1 SAE: 32 | self.layers_pos = _get_clones(decoder_layer, num_layers) 33 | 34 | # step 2 CBI: 35 | self.layers2 = _get_clones(decoder_layer, num_layers) 36 | self.layers3 = _get_clones(decoder_layer, num_layers) 37 | 38 | # step 3 :DSF 39 | self.dynamic_shared_fusion = DSF(d_model,2) 40 | 41 | def forward(self, sem, vis, pos, tgt_mask=None, 42 | memory_mask=None, tgt_key_padding_mask=None, 43 | memory_key_padding_mask=None): 44 | 45 | # print("start!!:pos{},\n sem{}\n vis{} \n".format(pos.shape,sem.shape,vis.shape)) 46 | for i in range(self.num_layers): 47 | # step 1 : SAE 48 | # pos 49 | pos = self.layers_pos[i](pos, pos, pos, 50 | memory_mask=tgt_mask, 51 | memory_key_padding_mask=tgt_key_padding_mask) 52 | # print("pos:{}".format(pos.shape)) 53 | 54 | 55 | #----------step 2 -----------: CBI 56 | # CBI-V : pos_vis 57 | pos_vis = self.layers2[i](pos, vis, vis, 58 | memory_mask=memory_mask, 59 | memory_key_padding_mask=memory_key_padding_mask) 60 | # print("pos_vis:{}".format(pos_vis.shape)) 61 | 62 | # CBI-S : pos_sem 63 | pos_sem = self.layers3[i](pos, sem, sem, 64 | memory_mask=tgt_mask, 65 | memory_key_padding_mask=tgt_key_padding_mask) 66 | # print("pos_sem:{}".format(pos_sem.shape)) 67 | 68 | # ----------step 3 -----------: DSF 69 | pos = self.dynamic_shared_fusion(pos_vis, pos_sem) 70 | 71 | output = pos 72 | return output 73 | 74 | class TransformerEncoder(Module): 75 | r"""TransformerEncoder is a stack of N encoder layers 76 | 77 | Args: 78 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 79 | num_layers: the number of sub-encoder-layers in the encoder (required). 80 | norm: the layer normalization component (optional). 81 | 82 | Examples:: 83 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead) 84 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers) 85 | """ 86 | 87 | def __init__(self, encoder_layer, num_layers, norm=None): 88 | super(TransformerEncoder, self).__init__() 89 | self.layers = _get_clones(encoder_layer, num_layers) 90 | self.pos_embeding = nn.Parameter(torch.zeros(48,1, 512)) 91 | self.pos_encoding = PositionalEncoding(dropout=0.0, dim=512) 92 | self.num_layers = num_layers 93 | self.norm = norm 94 | 95 | def forward(self, src, mask=None, src_key_padding_mask=None,pos_test=False): 96 | r"""Pass the input through the endocder layers in turn. 97 | 98 | Args: 99 | src: the sequnce to the encoder (required). 100 | mask: the mask for the src sequence (optional). 101 | src_key_padding_mask: the mask for the src keys per batch (optional). 102 | 103 | Shape: 104 | see the docs in Transformer class. 
105 | output: math(S,N,E) 106 | """ 107 | # pos message for encoder 108 | # pos = src.new_zeros(*src.shape) 109 | # pos = self.pos_encoding(pos) 110 | if pos_test == True: 111 | pos = self.pos_embeding 112 | # print("src:{}".format(src.shape)) 113 | # print("pos:{}".format(pos.shape)) 114 | output = src + pos 115 | else: 116 | output = src 117 | for i in range(self.num_layers): 118 | output = self.layers[i](output, src_mask=mask, 119 | src_key_padding_mask=src_key_padding_mask) 120 | if self.norm: 121 | output = self.norm(output) 122 | if src_key_padding_mask is not None: 123 | # only show no mask seq value 124 | output = output.permute(1, 0, 2) * torch.unsqueeze(~src_key_padding_mask, dim=-1).to(torch.float) 125 | output = output.permute(1, 0, 2) 126 | return output 127 | return output 128 | 129 | 130 | class TransformerEncoderLayer(Module): 131 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 132 | This standard encoder layer is based on the paper "Attention Is All You Need". 133 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 134 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 135 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 136 | in a different way during application. 137 | 138 | Args: 139 | d_model: the number of expected features in the input (required). 140 | nhead: the number of heads in the multiheadattention models (required). 141 | dim_feedforward: the dimension of the feedforward network model (default=2048). 142 | dropout: the dropout value (default=0.1). 143 | 144 | Examples:: 145 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead) 146 | """ 147 | 148 | def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1): 149 | super(TransformerEncoderLayer, self).__init__() 150 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=attention_dropout_rate) 151 | # # # Implementation of Feedforward model 152 | # self.linear1 = Linear(d_model, dim_feedforward) 153 | # self.dropout = Dropout(attention_dropout_rate) 154 | # self.linear2 = Linear(dim_feedforward, d_model) 155 | 156 | self.conv1 = Conv2d(in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1)) 157 | self.conv2 = Conv2d(in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1)) 158 | # torch.nn.init.xavier_uniform_(self.conv1.weight.data) 159 | # torch.nn.init.xavier_uniform_(self.conv2.weight.data) 160 | # if self.conv1.bias is not None: 161 | # if self.conv1.bias is not None: 162 | # self.conv1.bias.data.zero_() 163 | # if self.conv2.bias is not None: 164 | # if self.conv2.bias is not None: 165 | # self.conv2.bias.data.zero_() 166 | 167 | self.norm1 = LayerNorm(d_model) 168 | self.norm2 = LayerNorm(d_model) 169 | self.dropout1 = Dropout(residual_dropout_rate) 170 | self.dropout2 = Dropout(residual_dropout_rate) 171 | 172 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 173 | r"""Pass the input through the endocder layer. 174 | 175 | Args: 176 | src: the sequence to the encoder layer (required). 177 | src_mask: the mask for the src sequence (optional). 178 | src_key_padding_mask: the mask for the src keys per batch (optional). 179 | 180 | Shape: 181 | see the docs in Transformer class. 
182 | """ 183 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 184 | key_padding_mask=src_key_padding_mask)[0] 185 | src = src + self.dropout1(src2) 186 | src = self.norm1(src) 187 | 188 | # default 189 | # src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 190 | 191 | src = src.permute(1, 2, 0) 192 | src = torch.unsqueeze(src, 2) 193 | src2 = self.conv2(F.relu(self.conv1(src))) 194 | src2 = torch.squeeze(src2, 2) 195 | src2 = src2.permute(2, 0, 1) 196 | src = torch.squeeze(src, 2) 197 | src = src.permute(2, 0, 1) 198 | 199 | src = src + self.dropout2(src2) 200 | src = self.norm2(src) 201 | return src 202 | 203 | class DSF(nn.Module): 204 | def __init__(self, d_model,fusion_num): 205 | super(DSF, self).__init__() 206 | self.w_att = nn.Linear(fusion_num * d_model, d_model) 207 | 208 | def forward(self, l_feature, v_feature): 209 | """ 210 | Args: 211 | l_feature: (N, T, E) where T is length, N is batch size and d is dim of model 212 | v_feature: (N, T, E) shape the same as l_feature 213 | l_lengths: (N,) 214 | v_lengths: (N,) 215 | """ 216 | f = torch.cat((l_feature, v_feature), dim=2) 217 | f_att = torch.sigmoid(self.w_att(f)) 218 | output = f_att * v_feature + (1 - f_att) * l_feature 219 | 220 | return output 221 | 222 | class CommonAttentionLayer(Module): 223 | def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1): 224 | super(CommonAttentionLayer, self).__init__() 225 | 226 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=attention_dropout_rate) 227 | # Implementation of Feedforward model 228 | self.linear1 = Linear(d_model, dim_feedforward) 229 | self.dropout = Dropout(attention_dropout_rate) 230 | self.linear2 = Linear(dim_feedforward, d_model) 231 | 232 | self.norm2 = LayerNorm(d_model) 233 | self.norm3 = LayerNorm(d_model) 234 | 235 | self.dropout2 = Dropout(residual_dropout_rate) 236 | self.dropout3 = Dropout(residual_dropout_rate) 237 | 238 | def forward(self, query, key, value, memory_mask=None, 239 | memory_key_padding_mask=None): 240 | """Pass the inputs (and mask) through the decoder layer. 
241 | """ 242 | 243 | out = self.multihead_attn(query, key, value, attn_mask=memory_mask, 244 | key_padding_mask=memory_key_padding_mask)[0] 245 | out = query + self.dropout2(out) 246 | out = self.norm2(out) 247 | 248 | out2 = self.linear2(self.dropout(F.relu(self.linear1(out)))) 249 | out = out + self.dropout3(out2) 250 | out = self.norm3(out) 251 | return out 252 | 253 | class CommonDecoderLayer(Module): 254 | def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1): 255 | super(CommonDecoderLayer, self).__init__() 256 | 257 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=attention_dropout_rate) 258 | # Implementation of Feedforward model 259 | self.linear1 = Linear(d_model, dim_feedforward) 260 | self.dropout = Dropout(attention_dropout_rate) 261 | self.linear2 = Linear(dim_feedforward, d_model) 262 | 263 | self.conv1 = Conv2d(in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1)) 264 | self.conv2 = Conv2d(in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1)) 265 | 266 | self.norm2 = LayerNorm(d_model) 267 | self.norm3 = LayerNorm(d_model) 268 | 269 | self.dropout2 = Dropout(residual_dropout_rate) 270 | self.dropout3 = Dropout(residual_dropout_rate) 271 | 272 | def forward(self, query, key, value, tgt_mask=None, memory_mask=None, 273 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 274 | """Pass the inputs (and mask) through the decoder layer. 275 | """ 276 | 277 | out = self.multihead_attn(query, key, value, attn_mask=memory_mask, 278 | key_padding_mask=memory_key_padding_mask)[0] 279 | out = query + self.dropout2(out) 280 | out = self.norm2(out) 281 | 282 | # default 283 | # tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt)))) 284 | tgt = out.permute(1, 2, 0) 285 | tgt = torch.unsqueeze(tgt, 2) 286 | tgt2 = self.conv2(F.relu(self.conv1(tgt))) 287 | tgt2 = torch.squeeze(tgt2, 2) 288 | tgt2 = tgt2.permute(2, 0, 1) 289 | tgt = torch.squeeze(tgt, 2) 290 | tgt = tgt.permute(2, 0, 1) 291 | 292 | tgt = tgt + self.dropout3(tgt2) 293 | tgt = self.norm3(tgt) 294 | return tgt 295 | 296 | class PositionalEncoding(nn.Module): 297 | r"""Inject some information about the relative or absolute position of the tokens 298 | in the sequence. The positional encodings have the same dimension as 299 | the, so that the two can be summed. Here, we use sine and cosine 300 | functions of different frequencies. 301 | .. math:: 302 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 303 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 304 | \text{where pos is the word position and i is the embed idx) 305 | Args: 306 | d_model: the embed dim (required). 307 | dropout: the dropout value (default=0.1). 308 | max_len: the max. length of the incoming sequence (default=5000). 
309 | Examples: 310 | >>> pos_encoder = PositionalEncoding(d_model) 311 | """ 312 | 313 | def __init__(self, dropout, dim, max_len=5000): 314 | super(PositionalEncoding, self).__init__() 315 | self.dropout = nn.Dropout(p=dropout) 316 | 317 | pe = torch.zeros(max_len, dim) 318 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 319 | div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) 320 | pe[:, 0::2] = torch.sin(position * div_term) 321 | pe[:, 1::2] = torch.cos(position * div_term) 322 | pe = pe.unsqueeze(0).transpose(0, 1) 323 | self.register_buffer('pe', pe) 324 | 325 | def forward(self, x): 326 | r"""Inputs of forward function 327 | Args: 328 | x: the sequence fed to the positional encoder model (required). 329 | Shape: 330 | x: [sequence length, batch size, embed dim] 331 | output: [sequence length, batch size, embed dim] 332 | Examples: 333 | >>> output = pos_encoder(x) 334 | """ 335 | # x(w,b,h*c) 336 | x = x + self.pe[:x.size(0), :] 337 | return self.dropout(x) 338 | 339 | 340 | class PositionalEncoding_2d(nn.Module): 341 | r"""Inject some information about the relative or absolute position of the tokens 342 | in the sequence. The positional encodings have the same dimension as 343 | the embeddings, so that the two can be summed. Here, we use sine and cosine 344 | functions of different frequencies. 345 | .. math:: 346 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 347 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 348 | \text{where pos is the word position and i is the embed idx) 349 | Args: 350 | d_model: the embed dim (required). 351 | dropout: the dropout value (default=0.1). 352 | max_len: the max. length of the incoming sequence (default=5000). 353 | Examples: 354 | >>> pos_encoder = PositionalEncoding(d_model) 355 | """ 356 | 357 | def __init__(self, dropout, dim, max_len=5000): 358 | super(PositionalEncoding_2d, self).__init__() 359 | self.dropout = nn.Dropout(p=dropout) 360 | 361 | pe = torch.zeros(max_len, dim) 362 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 363 | div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) 364 | pe[:, 0::2] = torch.sin(position * div_term) 365 | pe[:, 1::2] = torch.cos(position * div_term) 366 | pe = pe.unsqueeze(0).transpose(0, 1) 367 | self.register_buffer('pe', pe) 368 | 369 | self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1)) 370 | self.linear1 = nn.Linear(dim, dim) 371 | self.linear1.weight.data.fill_(1.) 372 | self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1)) 373 | self.linear2 = nn.Linear(dim, dim) 374 | self.linear2.weight.data.fill_(1.) 375 | 376 | def forward(self, x): 377 | r"""Inputs of forward function 378 | Args: 379 | x: the sequence fed to the positional encoder model (required). 
380 | Shape: 381 | x: [sequence length, batch size, embed dim] 382 | output: [sequence length, batch size, embed dim] 383 | Examples: 384 | >>> output = pos_encoder(x) 385 | """ 386 | 387 | # x = x + self.pe[:x.size(0), :] 388 | w_pe = self.pe[:x.size(-1), :] 389 | w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0) 390 | w_pe = w_pe * w1 391 | w_pe = w_pe.permute(1, 2, 0) 392 | w_pe = w_pe.unsqueeze(2) 393 | 394 | h_pe = self.pe[:x.size(-2), :] 395 | w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0) 396 | h_pe = h_pe * w2 397 | h_pe = h_pe.permute(1, 2, 0) 398 | h_pe = h_pe.unsqueeze(3) 399 | 400 | x = x + w_pe + h_pe 401 | x = x.contiguous().view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) 402 | 403 | return self.dropout(x) 404 | 405 | 406 | class Embeddings(nn.Module): 407 | def __init__(self, d_model, vocab, padding_idx, scale_embedding): 408 | super(Embeddings, self).__init__() 409 | self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx) 410 | self.embedding.weight.data.normal_(mean=0.0, std=d_model**-0.5) 411 | self.d_model = d_model 412 | self.scale_embedding = scale_embedding 413 | 414 | def forward(self, x): 415 | if self.scale_embedding: 416 | return self.embedding(x) * math.sqrt(self.d_model) 417 | return self.embedding(x) 418 | -------------------------------------------------------------------------------- /cdistnet/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.init import xavier_uniform_ 4 | from torch.nn import functional as F 5 | from torch.nn.modules import Dropout 6 | from torch.nn.modules import Linear 7 | from torch.nn.modules import LayerNorm 8 | from torch.nn.modules import Conv2d 9 | 10 | from cdistnet.model.blocks import PositionalEncoding, Embeddings, TransformerEncoderLayer, \ 11 | TransformerEncoder, CommonDecoderLayer, MDCDP, CommonAttentionLayer 12 | # from cdistnet.utils.init import init_weights_xavier 13 | from cdistnet.model.stage.backbone import ResNet45, ResNet31, MTB_nrtr 14 | from cdistnet.model.stage.tps import TPS_SpatialTransformerNetwork 15 | 16 | 17 | def generate_square_subsequent_mask(sz): 18 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 19 | Unmasked positions are filled with float(0.0). 
20 | """ 21 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 22 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 23 | return mask 24 | 25 | 26 | def generate_padding_mask(x): 27 | padding_mask = x.eq(0) 28 | return padding_mask 29 | 30 | class VIS_Pre(nn.Module): 31 | def __init__(self,cfg,dim_feedforward): 32 | super(VIS_Pre, self).__init__() 33 | self.keep_aspect_ratio = cfg.keep_aspect_ratio 34 | if cfg.tps_block == 'TPS': 35 | self.transform = TPS_SpatialTransformerNetwork( 36 | F=cfg.num_fiducial, I_size=(cfg.height, cfg.width), I_r_size=(cfg.height, cfg.width), 37 | I_channel_num=1 if cfg.rgb2gray else 3) 38 | 39 | if cfg.feature_block == 'Resnet45': 40 | self.backbone = ResNet45() 41 | elif cfg.feature_block == 'Resnet31': 42 | self.backbone = ResNet31() 43 | elif cfg.feature_block == 'MTB': 44 | self.backbone = MTB_nrtr() 45 | self.positional_encoding = PositionalEncoding( 46 | dropout=cfg.residual_dropout_rate, 47 | dim=cfg.hidden_units, 48 | ) 49 | 50 | encoder_layer = TransformerEncoderLayer(cfg.hidden_units, cfg.num_heads, dim_feedforward, cfg.attention_dropout_rate, cfg.residual_dropout_rate) 51 | self.trans_encoder = TransformerEncoder(encoder_layer, cfg.num_encoder_blocks, None) 52 | 53 | def forward(self,image): 54 | x = image 55 | 56 | src = torch.sum(torch.abs(image).view(image.shape[0], -1, image.shape[-1]), dim=1) 57 | src_padding_mask = generate_padding_mask(src) 58 | 59 | x = self.transform(x) 60 | x = self.backbone(x) 61 | 62 | # # x(b,c,h,w) 63 | if self.keep_aspect_ratio: 64 | r = round(src_padding_mask.shape[1] / x.shape[1]) 65 | src_key_padding_mask = src_padding_mask[:, ::r] 66 | memory_key_padding_mask = src_key_padding_mask 67 | else: 68 | src_key_padding_mask, memory_key_padding_mask = None, None 69 | 70 | # 1d 71 | x = self.positional_encoding(x.permute(1, 0, 2)) 72 | # memory: S N E 73 | memory = self.trans_encoder(x, mask=None, src_key_padding_mask=src_key_padding_mask) 74 | return memory,memory_key_padding_mask 75 | 76 | 77 | class SEM_Pre(nn.Module): 78 | def __init__(self,cfg): 79 | super(SEM_Pre, self).__init__() 80 | self.embedding = Embeddings( 81 | d_model=cfg.hidden_units, 82 | vocab=cfg.dst_vocab_size, 83 | padding_idx=0, 84 | scale_embedding=cfg.scale_embedding 85 | ) 86 | 87 | self.positional_encoding = PositionalEncoding( 88 | dropout=cfg.residual_dropout_rate, 89 | dim=cfg.hidden_units, 90 | ) 91 | def forward(self,tgt): 92 | # tgt = tgt[:, :-1] 93 | # tgt(b,max_len) 94 | # # image(b,c,h,w) 95 | 96 | tgt_key_padding_mask = generate_padding_mask(tgt) 97 | # tgt_key_padding_mask:record 0 padding for true in tgt.same as image 98 | tgt = self.embedding(tgt).permute(1, 0, 2) 99 | # print("tgt.shape{}".format(tgt.shape)) 100 | tgt = self.positional_encoding(tgt) 101 | # tgt(max_len,b,d_model(hidden_unit)) 102 | tgt_mask = generate_square_subsequent_mask(tgt.shape[0]).to(device=tgt.device) 103 | return tgt,tgt_mask,tgt_key_padding_mask 104 | 105 | class POS_Pre(nn.Module): 106 | def __init__(self, cfg): 107 | super(POS_Pre, self).__init__() 108 | d_model = cfg.hidden_units 109 | self.pos_encoding = PositionalEncoding( 110 | dropout=cfg.residual_dropout_rate, 111 | dim=d_model, 112 | ) 113 | self.linear1 = Linear(d_model, d_model) 114 | self.linear2 = Linear(d_model, d_model) 115 | 116 | self.norm2 = LayerNorm(d_model) 117 | 118 | 119 | def forward(self,tgt): 120 | pos = tgt.new_zeros(*tgt.shape) 121 | pos = self.pos_encoding(pos) 122 | 123 | pos2 = self.linear2(F.relu(self.linear1(pos))) 124 
| pos = self.norm2(pos + pos2) 125 | return pos 126 | 127 | class CDistNet(nn.Module): 128 | def __init__(self, dim_feedforward=2048, cfg=None): 129 | super(CDistNet, self).__init__() 130 | 131 | self.d_model = cfg.hidden_units 132 | self.nhead = cfg.num_heads 133 | self.keep_aspect_ratio = cfg.keep_aspect_ratio 134 | 135 | self.visual_branch = VIS_Pre(cfg,dim_feedforward) 136 | self.semantic_branch = SEM_Pre(cfg) 137 | self.positional_branch = POS_Pre(cfg) 138 | 139 | decoder_layer = CommonAttentionLayer(cfg.hidden_units, cfg.num_heads, dim_feedforward // 2, cfg.attention_dropout_rate, 140 | cfg.residual_dropout_rate) 141 | self.mdcdp = MDCDP(decoder_layer, cfg.num_decoder_blocks) 142 | self._reset_parameters() 143 | 144 | self.tgt_word_prj = nn.Linear(cfg.hidden_units, cfg.dst_vocab_size, bias=False) 145 | self.tgt_word_prj.weight.data.normal_(mean=0.0, std=cfg.hidden_units ** -0.5) 146 | 147 | def forward(self, image, tgt): 148 | tgt = tgt[:, :-1] 149 | vis_feat,vis_key_padding_mask = self.visual_branch(image) 150 | sem_feat,sem_mask,sem_key_padding_mask = self.semantic_branch(tgt) 151 | pos_feat = self.positional_branch(sem_feat) 152 | # if x.size(1) != tgt.size(1): 153 | # raise RuntimeError("the batch number of src and tgt must be equal") 154 | # 155 | # if x.size(2) != self.d_model or tgt.size(2) != self.d_model: 156 | # raise RuntimeError("the feature number of src and tgt must be equal to d_model") 157 | 158 | output = self.mdcdp(sem_feat, vis_feat, pos_feat, 159 | tgt_mask=sem_mask, 160 | memory_mask=None, 161 | tgt_key_padding_mask=sem_key_padding_mask, 162 | memory_key_padding_mask=vis_key_padding_mask) 163 | 164 | output = output.permute(1, 0, 2) 165 | 166 | logit = self.tgt_word_prj(output) 167 | return logit.view(-1, logit.shape[2]) 168 | 169 | def _reset_parameters(self): 170 | r"""Initiate parameters in the transformer model.""" 171 | 172 | for p in self.parameters(): 173 | if p.dim() > 1: 174 | xavier_uniform_(p) 175 | 176 | def build_CDistNet(cfg): 177 | net = CDistNet(dim_feedforward=2048, cfg=cfg) 178 | return net 179 | 180 | -------------------------------------------------------------------------------- /cdistnet/model/stage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/stage/__init__.py -------------------------------------------------------------------------------- /cdistnet/model/stage/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/stage/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/model/stage/__pycache__/backbone.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/stage/__pycache__/backbone.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/model/stage/__pycache__/multiheadAttention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/stage/__pycache__/multiheadAttention.cpython-37.pyc 
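
For orientation, the sketch below shows how the pieces above (CDistNet, VIS_Pre / SEM_Pre / POS_Pre, MDCDP) and the stage modules that follow (backbone, TPS, multi-head attention) are wired together through build_CDistNet. It is a minimal, illustrative example only: the attribute names are the ones the model code reads from cfg, while the numeric values (image size, vocabulary size, block counts) are assumptions and not taken from this repository's configs.

from types import SimpleNamespace

import torch

from cdistnet.model.model import build_CDistNet

# Illustrative config object; the values below are assumptions, not the repo defaults.
cfg = SimpleNamespace(
    hidden_units=512,          # MDCDP hard-codes d_model = 512, so keep this at 512
    num_heads=8,
    num_encoder_blocks=3,
    num_decoder_blocks=3,
    attention_dropout_rate=0.0,
    residual_dropout_rate=0.1,
    dst_vocab_size=40,         # assumed: 36 characters plus special symbols
    scale_embedding=True,
    keep_aspect_ratio=False,   # skip the width-derived padding mask
    tps_block='TPS',
    feature_block='Resnet45',
    num_fiducial=20,
    rgb2gray=False,
    height=32,
    width=128,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_CDistNet(cfg).to(device)

images = torch.randn(2, 3, cfg.height, cfg.width, device=device)   # (B, C, H, W)
tgt = torch.zeros(2, 10, dtype=torch.long, device=device)          # PAD = 0
tgt[:, 0] = 2                                                      # BOS = 2 (see translator.py)
logits = model(images, tgt)   # shape: (B * (10 - 1), cfg.dst_vocab_size)
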
-------------------------------------------------------------------------------- /cdistnet/model/stage/__pycache__/tps.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/model/stage/__pycache__/tps.cpython-37.pyc -------------------------------------------------------------------------------- /cdistnet/model/stage/backbone.py: -------------------------------------------------------------------------------- 1 | '''FPN in PyTorch. 2 | See the paper "Feature Pyramid Networks for Object Detection" for more details. 3 | ''' 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | 11 | from torch.autograd import Variable 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv1x1(inplanes, planes) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes, stride) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | # class BasicBlock2(nn.Module): 55 | # expansion = 1 56 | # 57 | # def __init__(self, inplanes, planes, stride=1, downsample=False): 58 | # super().__init__() 59 | # self.conv1 = conv3x3(inplanes, planes, stride) 60 | # self.bn1 = nn.BatchNorm2d(planes) 61 | # self.relu = nn.ReLU(inplace=True) 62 | # self.conv2 = conv3x3(planes, planes) 63 | # self.bn2 = nn.BatchNorm2d(planes) 64 | # self.downsample = downsample 65 | # if downsample: 66 | # self.downsample = nn.Sequential( 67 | # nn.Conv2d( 68 | # inplanes, planes * self.expansion, 1, stride, bias=False), 69 | # nn.BatchNorm2d(planes * self.expansion), 70 | # ) 71 | # else: 72 | # self.downsample = nn.Sequential() 73 | # self.stride = stride 74 | # 75 | # def forward(self, x): 76 | # residual = x 77 | # 78 | # out = self.conv1(x) 79 | # out = self.bn1(out) 80 | # out = self.relu(out) 81 | # 82 | # out = self.conv2(out) 83 | # out = self.bn2(out) 84 | # 85 | # if self.downsample: 86 | # residual = self.downsample(x) 87 | # 88 | # out += residual 89 | # out = self.relu(out) 90 | # 91 | # return out 92 | 93 | class ConvBnRelu(nn.Module): 94 | # adapt padding for kernel_size change 95 | def __init__(self, in_channels, out_channels, kernel_size, conv = nn.Conv2d,stride=2, inplace=True): 96 | super().__init__() 97 | p_size = [int(k//2) for k in kernel_size] 98 | # p_size = int(kernel_size//2) 99 | self.conv = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=p_size) 100 | self.bn = nn.BatchNorm2d(out_channels) 101 | self.relu = nn.ReLU(inplace=inplace) 102 | 103 
| def forward(self, x): 104 | x = self.conv(x) 105 | x = self.bn(x) 106 | x = self.relu(x) 107 | return x 108 | class ConvBlock(nn.Module): 109 | def __init__(self, in_chan,out_chan,kernel_size,downsample=True,conv=nn.Conv2d): 110 | super().__init__() 111 | if downsample == True: 112 | h_dim = in_chan//2 if in_chan>64 else 64 113 | else: 114 | h_dim = out_chan 115 | self.block = nn.Sequential( 116 | ConvBnRelu(in_chan,h_dim,kernel_size,conv=conv,stride=2), 117 | ConvBnRelu(h_dim,h_dim,3,conv=conv,stride=1), 118 | ConvBnRelu(h_dim,h_dim,3,conv=conv,stride=1), 119 | ) 120 | if h_dim != out_chan: 121 | self.block.add_module('down_sample',ConvBnRelu(h_dim,out_chan,1,stride=1)) 122 | 123 | def forward(self, x): 124 | x = self.block(x) 125 | return x 126 | 127 | class Bottleneck(nn.Module): 128 | expansion = 4 129 | 130 | def __init__(self, in_planes, planes, stride=1): 131 | super(Bottleneck, self).__init__() 132 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 133 | self.bn1 = nn.BatchNorm2d(planes) 134 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 135 | self.bn2 = nn.BatchNorm2d(planes) 136 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 137 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 138 | 139 | self.shortcut = nn.Sequential() 140 | if stride != 1 or in_planes != self.expansion*planes: 141 | self.shortcut = nn.Sequential( 142 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 143 | nn.BatchNorm2d(self.expansion*planes) 144 | ) 145 | 146 | def forward(self, x): 147 | out = F.relu(self.bn1(self.conv1(x))) 148 | out = F.relu(self.bn2(self.conv2(out))) 149 | out = self.bn3(self.conv3(out)) 150 | out += self.shortcut(x) 151 | out = F.relu(out) 152 | return out 153 | 154 | class NRTRModalityTransform(nn.Module): 155 | 156 | def __init__(self, input_channels=3, input_height=32): 157 | super().__init__() 158 | 159 | self.conv_1 = nn.Conv2d( 160 | in_channels=input_channels, 161 | out_channels=32, 162 | kernel_size=3, 163 | stride=2, 164 | padding=1) 165 | self.relu_1 = nn.ReLU(True) 166 | self.bn_1 = nn.BatchNorm2d(32) 167 | 168 | self.conv_2 = nn.Conv2d( 169 | in_channels=32, 170 | out_channels=64, 171 | kernel_size=(3,1), 172 | stride=(2,1), 173 | padding=(1,0)) 174 | self.relu_2 = nn.ReLU(True) 175 | self.bn_2 = nn.BatchNorm2d(64) 176 | 177 | feat_height = input_height // 4 178 | 179 | self.linear = nn.Linear(512, 512) 180 | 181 | def init_weights(self, pretrained=None): 182 | for m in self.modules(): 183 | if isinstance(m, nn.Conv2d): 184 | kaiming_init(m) 185 | elif isinstance(m, nn.BatchNorm2d): 186 | uniform_init(m) 187 | 188 | def forward(self, x): 189 | x = self.conv_1(x) 190 | x = self.relu_1(x) 191 | x = self.bn_1(x) 192 | 193 | x = self.conv_2(x) 194 | x = self.relu_2(x) 195 | x = self.bn_2(x) 196 | 197 | n, c, h, w = x.size() 198 | x = x.permute(0, 3, 2, 1).contiguous().view(n, w, h * c) 199 | x = self.linear(x) 200 | # print(x.shape) 201 | # x = x.permute(0, 2, 1).contiguous() 202 | # print(x.shape) 203 | return x 204 | 205 | class ResNet31OCR(nn.Module): 206 | """Implement ResNet backbone for text recognition, modified from 207 | `ResNet `_ 208 | Args: 209 | base_channels (int): Number of channels of input image tensor. 210 | layers (list[int]): List of BasicBlock number for each stage. 211 | channels (list[int]): List of out_channels of Conv2d layer. 212 | out_indices (None | Sequence[int]): Indices of output stages. 
213 | stage4_pool_cfg (dict): Dictionary to construct and configure 214 | pooling layer in stage 4. 215 | last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage. 216 | """ 217 | 218 | def __init__(self, 219 | base_channels=3, 220 | layers=[1, 2, 5, 3], 221 | channels=[64, 128, 256, 256, 512, 512, 512], 222 | out_indices=None, 223 | stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)), 224 | last_stage_pool=False): 225 | super().__init__() 226 | assert isinstance(base_channels, int) 227 | # assert utils.is_type_list(layers, int) 228 | # assert utils.is_type_list(channels, int) 229 | assert out_indices is None or (isinstance(out_indices, list) 230 | or isinstance(out_indices, tuple)) 231 | assert isinstance(last_stage_pool, bool) 232 | 233 | self.out_indices = out_indices 234 | self.last_stage_pool = last_stage_pool 235 | 236 | # conv 1 (Conv, Conv) 237 | self.conv1_1 = nn.Conv2d( 238 | base_channels, channels[0], kernel_size=3, stride=1, padding=1) 239 | self.bn1_1 = nn.BatchNorm2d(channels[0]) 240 | self.relu1_1 = nn.ReLU(inplace=True) 241 | 242 | self.conv1_2 = nn.Conv2d( 243 | channels[0], channels[1], kernel_size=3, stride=1, padding=1) 244 | self.bn1_2 = nn.BatchNorm2d(channels[1]) 245 | self.relu1_2 = nn.ReLU(inplace=True) 246 | 247 | # conv 2 (Max-pooling, Residual block, Conv) 248 | self.pool2 = nn.MaxPool2d( 249 | kernel_size=2, stride=2, padding=0, ceil_mode=True) 250 | self.block2 = self._make_layer(channels[1], channels[2], layers[0]) 251 | self.conv2 = nn.Conv2d( 252 | channels[2], channels[2], kernel_size=3, stride=1, padding=1) 253 | self.bn2 = nn.BatchNorm2d(channels[2]) 254 | self.relu2 = nn.ReLU(inplace=True) 255 | 256 | # conv 3 (Max-pooling, Residual block, Conv) 257 | self.pool3 = nn.MaxPool2d( 258 | kernel_size=2, stride=2, padding=0, ceil_mode=True) 259 | self.block3 = self._make_layer(channels[2], channels[3], layers[1]) 260 | self.conv3 = nn.Conv2d( 261 | channels[3], channels[3], kernel_size=3, stride=1, padding=1) 262 | self.bn3 = nn.BatchNorm2d(channels[3]) 263 | self.relu3 = nn.ReLU(inplace=True) 264 | 265 | # conv 4 (Max-pooling, Residual block, Conv) 266 | self.pool4 = nn.MaxPool2d(padding=0, ceil_mode=True, **stage4_pool_cfg) 267 | self.block4 = self._make_layer(channels[3], channels[4], layers[2]) 268 | self.conv4 = nn.Conv2d( 269 | channels[4], channels[4], kernel_size=3, stride=1, padding=1) 270 | self.bn4 = nn.BatchNorm2d(channels[4]) 271 | self.relu4 = nn.ReLU(inplace=True) 272 | 273 | # conv 5 ((Max-pooling), Residual block, Conv) 274 | self.pool5 = None 275 | if self.last_stage_pool: 276 | self.pool5 = nn.MaxPool2d( 277 | kernel_size=2, stride=2, padding=0, ceil_mode=True) # 1/16 278 | self.block5 = self._make_layer(channels[4], channels[5], layers[3]) 279 | self.conv5 = nn.Conv2d( 280 | channels[5], channels[5], kernel_size=3, stride=1, padding=1) 281 | self.bn5 = nn.BatchNorm2d(channels[5]) 282 | self.relu5 = nn.ReLU(inplace=True) 283 | self.convbnrelu = ConvBnRelu(in_channels=512, out_channels=512, kernel_size=(1,3), stride=(1,2)) 284 | 285 | def init_weights(self, pretrained=None): 286 | # initialize weight and bias 287 | for m in self.modules(): 288 | if isinstance(m, nn.Conv2d): 289 | kaiming_init(m) 290 | elif isinstance(m, nn.BatchNorm2d): 291 | uniform_init(m) 292 | 293 | def _make_layer(self, input_channels, output_channels, blocks): 294 | layers = [] 295 | for _ in range(blocks): 296 | downsample = None 297 | if input_channels != output_channels: 298 | downsample = nn.Sequential( 299 | nn.Conv2d( 300 | input_channels, 301 | 
output_channels, 302 | kernel_size=1, 303 | stride=1, 304 | bias=False), 305 | nn.BatchNorm2d(output_channels), 306 | ) 307 | layers.append( 308 | BasicBlock( 309 | input_channels, output_channels, downsample=downsample)) 310 | input_channels = output_channels 311 | 312 | return nn.Sequential(*layers) 313 | 314 | def forward(self, x): 315 | 316 | x = self.conv1_1(x) 317 | x = self.bn1_1(x) 318 | x = self.relu1_1(x) 319 | 320 | x = self.conv1_2(x) 321 | x = self.bn1_2(x) 322 | x = self.relu1_2(x) 323 | 324 | outs = [] 325 | for i in range(4): 326 | layer_index = i + 2 327 | pool_layer = getattr(self, f'pool{layer_index}') 328 | block_layer = getattr(self, f'block{layer_index}') 329 | conv_layer = getattr(self, f'conv{layer_index}') 330 | bn_layer = getattr(self, f'bn{layer_index}') 331 | relu_layer = getattr(self, f'relu{layer_index}') 332 | 333 | if pool_layer is not None: 334 | x = pool_layer(x) 335 | x = block_layer(x) 336 | x = conv_layer(x) 337 | x = bn_layer(x) 338 | x = relu_layer(x) 339 | 340 | # outs.append(x) 341 | x = self.convbnrelu(x) 342 | x = rearrange(x, 'b c h w -> b (w h) c') 343 | # if self.out_indices is not None: 344 | # return tuple([outs[i] for i in self.out_indices]) 345 | # print(x.shape) 346 | return x 347 | 348 | class ABI_ResNet(nn.Module): 349 | 350 | def __init__(self, block, layers): 351 | self.inplanes = 32 352 | super(ABI_ResNet, self).__init__() 353 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, 354 | bias=False) 355 | self.bn1 = nn.BatchNorm2d(32) 356 | self.relu = nn.ReLU(inplace=True) 357 | self.convbnrelu = ConvBnRelu(in_channels=512, out_channels=512, kernel_size=(3,3),stride=(2,2)) 358 | # self.convbnrelu2 = ConvBnRelu(in_channels=512, out_channels=512, kernel_size=3,stride=2) 359 | # self.convbnrelu2 = ConvBnRelu(in_channels=1024, out_channels=256, kernel_size=3) 360 | 361 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 362 | self.layer2 = self._make_layer(block, 64, layers[1], stride=1) 363 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 364 | self.layer4 = self._make_layer(block, 256, layers[3], stride=1) 365 | self.layer5 = self._make_layer(block, 512, layers[4], stride=1)#stride = 1 366 | 367 | for m in self.modules(): 368 | if isinstance(m, nn.Conv2d): 369 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 370 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 371 | elif isinstance(m, nn.BatchNorm2d): 372 | m.weight.data.fill_(1) 373 | m.bias.data.zero_() 374 | 375 | def _make_layer(self, block, planes, blocks, stride=1): 376 | downsample = None 377 | if stride != 1 or self.inplanes != planes * block.expansion: 378 | downsample = nn.Sequential( 379 | nn.Conv2d(self.inplanes, planes * block.expansion, 380 | kernel_size=1, stride=stride, bias=False), 381 | nn.BatchNorm2d(planes * block.expansion), 382 | ) 383 | 384 | layers = [] 385 | layers.append(block(self.inplanes, planes, stride, downsample)) 386 | self.inplanes = planes * block.expansion 387 | for i in range(1, blocks): 388 | layers.append(block(self.inplanes, planes)) 389 | 390 | return nn.Sequential(*layers) 391 | 392 | def forward(self, x): 393 | x = self.conv1(x) 394 | x = self.bn1(x) 395 | x = self.relu(x) 396 | x = self.layer1(x) 397 | x = self.layer2(x) 398 | x = self.layer3(x) 399 | x = self.layer4(x) 400 | x = self.layer5(x) 401 | x = self.convbnrelu(x) 402 | # x = self.convbnrelu2(x) 403 | x = rearrange(x, 'b c h w -> b (w h) c') 404 | # print(x.shape) 405 | return x 406 | 407 | 408 | def ResNet45(): 409 | return ABI_ResNet(BasicBlock, [3, 4, 6, 6, 3]) 410 | def ResNet31(): 411 | return ResNet31OCR() 412 | def MTB_nrtr(): 413 | return NRTRModalityTransform() 414 | 415 | 416 | def test(): 417 | pass 418 | # net = FPN101() 419 | # fms = net(Variable(torch.randn(1,3,600,900))) 420 | # for fm in fms: 421 | # print(fm.size()) 422 | # layer1 = _make_layer(block=Bottleneck, planes=64, num_blocks=2, stride=2) 423 | # layer1(torch.randn(1,64,600,900)) 424 | # layer2 = _make_layer(Bottleneck, 128, 2, stride=2) -------------------------------------------------------------------------------- /cdistnet/model/stage/multiheadAttention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import Linear 6 | from torch.nn.init import xavier_uniform_ 7 | from torch.nn.init import constant_ 8 | from torch.nn.init import xavier_normal_ 9 | from torch.nn.parameter import Parameter 10 | 11 | 12 | 13 | class MultiheadAttention(nn.Module): 14 | r"""Allows the model to jointly attend to information 15 | from different representation subspaces. 16 | See reference: Attention Is All You Need 17 | 18 | .. 
math:: 19 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 20 | \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 21 | 22 | Args: 23 | embed_dim: total dimension of the model 24 | num_heads: parallel attention layers, or heads 25 | 26 | Examples:: 27 | 28 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 29 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 30 | """ 31 | 32 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False): 33 | super(MultiheadAttention, self).__init__() 34 | self.embed_dim = embed_dim 35 | self.num_heads = num_heads 36 | self.dropout = dropout 37 | self.head_dim = embed_dim // num_heads 38 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 39 | self.scaling = self.head_dim ** -0.5 40 | 41 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 42 | if bias: 43 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) 44 | else: 45 | self.register_parameter('in_proj_bias', None) 46 | self.out_proj = Linear(embed_dim, embed_dim, bias=bias) 47 | 48 | if add_bias_kv: 49 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) 50 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) 51 | else: 52 | self.bias_k = self.bias_v = None 53 | 54 | self.add_zero_attn = add_zero_attn 55 | 56 | 57 | self.conv1 = torch.nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) 58 | self.conv2 = torch.nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1)) 59 | self.conv3 = torch.nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1)) 60 | # torch.nn.init.xavier_uniform_(self.conv1.weight.data) 61 | # torch.nn.init.xavier_uniform_(self.conv2.weight.data) 62 | # torch.nn.init.xavier_uniform_(self.conv3.weight.data) 63 | # if self.conv1.bias is not None: 64 | # if self.conv1.bias is not None: 65 | # self.conv1.bias.data.zero_() 66 | # if self.conv2.bias is not None: 67 | # if self.conv2.bias is not None: 68 | # self.conv2.bias.data.zero_() 69 | # if self.conv3.bias is not None: 70 | # if self.conv3.bias is not None: 71 | # self.conv3.bias.data.zero_() 72 | self._reset_parameters() 73 | 74 | 75 | 76 | def _reset_parameters(self): 77 | xavier_uniform_(self.in_proj_weight[:self.embed_dim, :]) 78 | xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim * 2), :]) 79 | xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :]) 80 | 81 | xavier_uniform_(self.out_proj.weight) 82 | if self.in_proj_bias is not None: 83 | constant_(self.in_proj_bias, 0.) 84 | constant_(self.out_proj.bias, 0.) 
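        # Note: at forward time the q/k/v projections actually used are the 1x1
        # convolutions conv1/conv2/conv3 (see the _in_proj_* helpers below); the linear
        # _in_proj path that reads in_proj_weight is only referenced from commented-out
        # code. The convolution weights themselves are initialised by the Kaiming loop
        # that follows.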
85 | if self.bias_k is not None: 86 | xavier_normal_(self.bias_k) 87 | if self.bias_v is not None: 88 | xavier_normal_(self.bias_v) 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 93 | elif isinstance(m, nn.BatchNorm2d): 94 | nn.init.constant_(m.weight, 1) 95 | nn.init.constant_(m.bias, 0) 96 | 97 | def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, 98 | need_weights=True, static_kv=False, attn_mask=None): 99 | """ 100 | Inputs of forward function 101 | query: [target length, batch size, embed dim] 102 | key: [sequence length, batch size, embed dim] 103 | value: [sequence length, batch size, embed dim] 104 | key_padding_mask: if True, mask padding based on batch size 105 | attn_mask : triu mask for [T,T] or [T,S] 106 | incremental_state: if provided, previous time steps are cashed 107 | need_weights: output attn_output_weights 108 | static_kv: key and value are static 109 | 110 | Outputs of forward function 111 | attn_output: [target length, batch size, embed dim] 112 | attn_output_weights: [batch size, target length, sequence length] 113 | """ 114 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 115 | kv_same = key.data_ptr() == value.data_ptr() 116 | 117 | tgt_len, bsz, embed_dim = query.size() 118 | assert embed_dim == self.embed_dim 119 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 120 | assert key.size() == value.size() 121 | 122 | if incremental_state is not None: 123 | saved_state = self._get_input_buffer(incremental_state) 124 | if 'prev_key' in saved_state: 125 | # previous time steps are cached - no need to recompute 126 | # key and value if they are static 127 | if static_kv: 128 | assert kv_same and not qkv_same 129 | key = value = None 130 | else: 131 | saved_state = None 132 | 133 | if qkv_same: 134 | # self-attention 135 | q, k, v = self._in_proj_qkv(query) 136 | elif kv_same: 137 | # encoder-decoder attention 138 | q = self._in_proj_q(query) 139 | if key is None: 140 | assert value is None 141 | k = v = None 142 | else: 143 | k, v = self._in_proj_kv(key) 144 | else: 145 | q = self._in_proj_q(query) 146 | k = self._in_proj_k(key) 147 | v = self._in_proj_v(value) 148 | # q *= self.scaling 149 | q = q*self.scaling 150 | 151 | if self.bias_k is not None: 152 | assert self.bias_v is not None 153 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 154 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 155 | if attn_mask is not None: 156 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 157 | if key_padding_mask is not None: 158 | key_padding_mask = torch.cat( 159 | [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) 160 | 161 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 162 | if k is not None: 163 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 164 | if v is not None: 165 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 166 | # q([batch*head,s_len,head_dim]) 167 | # k([batch*head,t_len,head_dim]) 168 | if saved_state is not None: 169 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) 170 | if 'prev_key' in saved_state: 171 | prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) 172 | if static_kv: 173 | k = prev_key 174 | else: 175 | k = torch.cat((prev_key, k), dim=1) 176 | if 'prev_value' in 
saved_state: 177 | prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) 178 | if static_kv: 179 | v = prev_value 180 | else: 181 | v = torch.cat((prev_value, v), dim=1) 182 | saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) 183 | saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) 184 | 185 | self._set_input_buffer(incremental_state, saved_state) 186 | 187 | src_len = k.size(1) 188 | 189 | if key_padding_mask is not None: 190 | assert key_padding_mask.size(0) == bsz 191 | assert key_padding_mask.size(1) == src_len 192 | 193 | if self.add_zero_attn: 194 | src_len += 1 195 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 196 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 197 | if attn_mask is not None: 198 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 199 | if key_padding_mask is not None: 200 | key_padding_mask = torch.cat( 201 | [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) 202 | 203 | # step: q*k^T [batch*head,t_len,src_len] 204 | # attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 205 | # sqrt(q*k^T) 206 | attn_output_weights = torch.bmm(q, k.transpose(1, 2))/math.sqrt(self.head_dim) 207 | assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 208 | 209 | # step: mask_triu 210 | if attn_mask is not None: 211 | attn_mask = attn_mask.unsqueeze(0) 212 | # [1,t_len,s_len] 213 | attn_output_weights += attn_mask 214 | 215 | # step: key_padding 216 | if key_padding_mask is not None: 217 | attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) 218 | attn_output_weights = attn_output_weights.masked_fill( 219 | key_padding_mask.unsqueeze(1).unsqueeze(2), 220 | float('-inf'), 221 | ) 222 | # key_padding[batch,1,1,s_len] 223 | attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len) 224 | 225 | attn_output_weights = F.softmax( 226 | attn_output_weights.float(), dim=-1, 227 | dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype) 228 | attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) 229 | 230 | attn_output = torch.bmm(attn_output_weights, v) 231 | assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 232 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 233 | attn_output = self.out_proj(attn_output) 234 | 235 | if need_weights: 236 | # average attention weights over heads 237 | attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) 238 | attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads 239 | else: 240 | attn_output_weights = None 241 | 242 | return attn_output, attn_output_weights 243 | 244 | def _in_proj_qkv(self, query): 245 | # return self._in_proj(query).chunk(3, dim=-1) 246 | query = query.permute(1, 2, 0) 247 | query = torch.unsqueeze(query, dim=2) 248 | res = self.conv3(query) 249 | res = torch.squeeze(res, dim=2) 250 | res = res.permute(2, 0, 1) 251 | return res.chunk(3, dim=-1) 252 | 253 | def _in_proj_kv(self, key): 254 | # return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 255 | key = key.permute(1, 2, 0) 256 | key = torch.unsqueeze(key, dim=2) 257 | res = self.conv2(key) 258 | res = torch.squeeze(res, dim=2) 259 | res = res.permute(2, 0, 1) 260 | return res.chunk(2, 
dim=-1) 261 | 262 | def _in_proj_q(self, query): 263 | # return self._in_proj(query, end=self.embed_dim) 264 | query = query.permute(1, 2, 0) 265 | query = torch.unsqueeze(query, dim=2) 266 | res = self.conv1(query) 267 | res = torch.squeeze(res, dim=2) 268 | res = res.permute(2, 0, 1) 269 | return res 270 | 271 | def _in_proj_k(self, key): 272 | # return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 273 | key = key.permute(1, 2, 0) 274 | key = torch.unsqueeze(key, dim=2) 275 | res = self.conv1(key) 276 | res = torch.squeeze(res, dim=2) 277 | res = res.permute(2, 0, 1) 278 | return res 279 | 280 | def _in_proj_v(self, value): 281 | # return self._in_proj(value, start=2 * self.embed_dim) 282 | value = value.permute(1, 2, 0) 283 | value = torch.unsqueeze(value, dim=2) 284 | res = self.conv1(value) 285 | res = torch.squeeze(res, dim=2) 286 | res = res.permute(2, 0, 1) 287 | return res 288 | 289 | def _in_proj(self, input, start=0, end=None): 290 | weight = self.in_proj_weight 291 | bias = self.in_proj_bias 292 | weight = weight[start:end, :] 293 | # if bias is not None: 294 | # bias = bias[start:end] 295 | return F.linear(input, weight, bias) -------------------------------------------------------------------------------- /cdistnet/model/stage/tps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | class TPS_SpatialTransformerNetwork(nn.Module): 9 | """ Rectification Network of RARE, namely TPS based STN """ 10 | 11 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 12 | """ Based on RARE TPS 13 | input: 14 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 15 | I_size : (height, width) of the input image I 16 | I_r_size : (height, width) of the rectified image I_r 17 | I_channel_num : the number of channels of the input image I 18 | output: 19 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 20 | """ 21 | super(TPS_SpatialTransformerNetwork, self).__init__() 22 | self.F = F 23 | self.I_size = I_size 24 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 25 | self.I_channel_num = I_channel_num 26 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 27 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 28 | 29 | def forward(self, batch_I): 30 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 31 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 32 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 33 | 34 | if torch.__version__ > "1.2.0": 35 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 36 | else: 37 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 38 | 39 | return batch_I_r 40 | 41 | 42 | class LocalizationNetwork(nn.Module): 43 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 44 | 45 | def __init__(self, F, I_channel_num): 46 | super(LocalizationNetwork, self).__init__() 47 | self.F = F 48 | self.I_channel_num = I_channel_num 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 51 | bias=False), 
nn.BatchNorm2d(64), nn.ReLU(True), 52 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 53 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 54 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 55 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 56 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 57 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 58 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 59 | ) 60 | 61 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 62 | self.localization_fc2 = nn.Linear(256, self.F * 2) 63 | 64 | # Init fc2 in LocalizationNetwork 65 | self.localization_fc2.weight.data.fill_(0) 66 | """ see RARE paper Fig. 6 (a) """ 67 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 68 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 69 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 70 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 71 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 72 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 73 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 74 | 75 | def forward(self, batch_I): 76 | """ 77 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 78 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 79 | """ 80 | batch_size = batch_I.size(0) 81 | features = self.conv(batch_I).view(batch_size, -1) 82 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 83 | return batch_C_prime 84 | 85 | 86 | class GridGenerator(nn.Module): 87 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 88 | 89 | def __init__(self, F, I_r_size): 90 | """ Generate P_hat and inv_delta_C for later """ 91 | super(GridGenerator, self).__init__() 92 | self.eps = 1e-6 93 | self.I_r_height, self.I_r_width = I_r_size 94 | self.F = F 95 | self.C = self._build_C(self.F) # F x 2 96 | self.P = self._build_P(self.I_r_width, self.I_r_height) 97 | ## for multi-gpu, you need register buffer 98 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 99 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 100 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 101 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 102 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 103 | 104 | def _build_C(self, F): 105 | """ Return coordinates of fiducial points in I_r; C """ 106 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 107 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 108 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 109 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 110 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 111 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 112 | return C # F x 2 113 | 114 | def _build_inv_delta_C(self, F, C): 115 | """ Return inv_delta_C which is needed to calculate T """ 116 | hat_C = np.zeros((F, F), dtype=float) # F x F 117 | for i in range(0, F): 118 | for j in range(i, F): 119 | r = np.linalg.norm(C[i] - C[j]) 120 | 
hat_C[i, j] = r 121 | hat_C[j, i] = r 122 | np.fill_diagonal(hat_C, 1) 123 | hat_C = (hat_C ** 2) * np.log(hat_C) 124 | # print(C.shape, hat_C.shape) 125 | delta_C = np.concatenate( # F+3 x F+3 126 | [ 127 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 128 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 129 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 130 | ], 131 | axis=0 132 | ) 133 | inv_delta_C = np.linalg.inv(delta_C) 134 | return inv_delta_C # F+3 x F+3 135 | 136 | def _build_P(self, I_r_width, I_r_height): 137 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 138 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 139 | P = np.stack( # self.I_r_width x self.I_r_height x 2 140 | np.meshgrid(I_r_grid_x, I_r_grid_y), 141 | axis=2 142 | ) 143 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 144 | 145 | def _build_P_hat(self, F, C, P): 146 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 147 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 148 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 149 | P_diff = P_tile - C_tile # n x F x 2 150 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 151 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 152 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 153 | return P_hat # n x F+3 154 | 155 | def build_P_prime(self, batch_C_prime): 156 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 157 | batch_size = batch_C_prime.size(0) 158 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 159 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 160 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 161 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 162 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 163 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 164 | return batch_P_prime # batch_size x n x 2 165 | -------------------------------------------------------------------------------- /cdistnet/model/translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class Beam(): 6 | ''' Beam search ''' 7 | 8 | def __init__(self, size, device=False): 9 | 10 | self.size = size 11 | self._done = False 12 | 13 | # The score for each translation on the beam. 14 | self.scores = torch.zeros((size,), dtype=torch.float, device=device) 15 | self.all_scores = [] 16 | 17 | # The backpointers at each time-step. 18 | self.prev_ks = [] 19 | 20 | # The outputs at each time-step. 21 | self.next_ys = [torch.full((size,), 0, dtype=torch.long, device=device)] 22 | self.next_ys[0][0] = 2 23 | 24 | def get_current_state(self): 25 | "Get the outputs for the current timestep." 26 | return self.get_tentative_hypothesis() 27 | 28 | def get_current_origin(self): 29 | "Get the backpointers for the current timestep." 30 | return self.prev_ks[-1] 31 | 32 | @property 33 | def done(self): 34 | return self._done 35 | 36 | def advance(self, word_prob): 37 | "Update beam status and check if finished or not." 38 | num_words = word_prob.size(1) 39 | 40 | # Sum the previous scores. 
41 | if len(self.prev_ks) > 0: 42 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 43 | else: 44 | beam_lk = word_prob[0] 45 | 46 | flat_beam_lk = beam_lk.view(-1) 47 | 48 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 49 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 50 | 51 | self.all_scores.append(self.scores) 52 | self.scores = best_scores 53 | 54 | # bestScoresId is flattened as a (beam x word) array, 55 | # so we need to calculate which word and beam each score came from 56 | # print("print:{}....{}".format(best_scores_id,num_words)) 57 | # print("prev_k:{}".format(best_scores_id/ num_words)) 58 | prev_k = best_scores_id // num_words 59 | self.prev_ks.append(prev_k) 60 | self.next_ys.append(best_scores_id - prev_k * num_words) 61 | 62 | # End condition is when top-of-beam is EOS. 63 | if self.next_ys[-1][0].item() == 3: 64 | self._done = True 65 | self.all_scores.append(self.scores) 66 | 67 | return self._done 68 | 69 | def sort_scores(self): 70 | "Sort the scores." 71 | return torch.sort(self.scores, 0, True) 72 | 73 | def get_the_best_score_and_idx(self): 74 | "Get the score of the best in the beam." 75 | scores, ids = self.sort_scores() 76 | return scores[1], ids[1] 77 | 78 | def get_tentative_hypothesis(self): 79 | "Get the decoded sequence for the current timestep." 80 | 81 | if len(self.next_ys) == 1: 82 | dec_seq = self.next_ys[0].unsqueeze(1) 83 | else: 84 | _, keys = self.sort_scores() 85 | # print("self.prev_ks:{}\n".format(self.prev_ks)) 86 | # print("keys :{} type:{}\n".format(keys, type(keys))) 87 | hyps = [self.get_hypothesis(k) for k in keys] 88 | hyps = [[2] + h for h in hyps] 89 | dec_seq = torch.LongTensor(hyps) 90 | 91 | return dec_seq 92 | 93 | def get_hypothesis(self, k): 94 | """ Walk back to construct the full hypothesis. """ 95 | hyp = [] 96 | for j in range(len(self.prev_ks) - 1, -1, -1): 97 | # print("j :{} k:{}\n".format(j,k)) 98 | # print("k_pre_ks :{} type:{} \n".format(self.prev_ks[j][k],type(self.prev_ks[j][k]))) 99 | hyp.append(self.next_ys[j+1][k]) 100 | k = self.prev_ks[j][k] 101 | 102 | return list(map(lambda x: x.item(), hyp[::-1])) 103 | 104 | 105 | class Translator(object): 106 | def __init__(self, cfg, model): 107 | self.cfg = cfg 108 | self.device = torch.device(cfg.test.device) 109 | self.model = model 110 | self.keep_aspect_ratio = cfg.keep_aspect_ratio 111 | self.stages = {'TPS': cfg.tps_block, 'Feat': cfg.feature_block} 112 | 113 | def translate_batch(self, images): 114 | ''' Translation work in one batch ''' 115 | 116 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 117 | ''' Indicate the position of an instance in a tensor. ''' 118 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 119 | 120 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 121 | ''' Collect tensor parts associated to active instances. 
''' 122 | 123 | _, *d_hs = beamed_tensor.size() 124 | n_curr_active_inst = len(curr_active_inst_idx) 125 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 126 | 127 | beamed_tensor = beamed_tensor.contiguous().view(n_prev_active_inst, -1) 128 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) 129 | beamed_tensor = beamed_tensor.view(*new_shape) 130 | 131 | return beamed_tensor 132 | 133 | def collate_active_info( 134 | src_enc, inst_idx_to_position_map, active_inst_idx_list): 135 | # Sentences which are still active are collected, 136 | # so the decoder will not run on completed sentences. 137 | n_prev_active_inst = len(inst_idx_to_position_map) 138 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 139 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 140 | 141 | # active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 142 | active_src_enc = collect_active_part(src_enc.permute(1, 0, 2), active_inst_idx, n_prev_active_inst, n_bm).permute(1, 0, 2) 143 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 144 | 145 | return active_src_enc, active_inst_idx_to_position_map 146 | 147 | def beam_decode_step( 148 | inst_dec_beams, len_dec_seq, enc_output, inst_idx_to_position_map, n_bm, memory_key_padding_mask): 149 | ''' Decode and update beam status, and then return active beam idx ''' 150 | 151 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 152 | ''' 153 | prepare beam_tgt for decoder 154 | :param inst_dec_beams: class beam(num:batch_size) 155 | :param len_dec_seq: max_len(beam search len) 156 | :return: beam_tgt 157 | ''' 158 | # init beams as 2 for start 159 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 160 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 161 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 162 | return dec_partial_seq 163 | 164 | def prepare_beam_memory_key_padding_mask(inst_dec_beams, memory_key_padding_mask, n_bm): 165 | keep = [] 166 | for idx, each in enumerate(memory_key_padding_mask): 167 | if not inst_dec_beams[idx].done: 168 | keep.append(idx) 169 | memory_key_padding_mask = memory_key_padding_mask[torch.tensor(keep)] 170 | len_s = memory_key_padding_mask.shape[-1] 171 | n_inst = memory_key_padding_mask.shape[0] 172 | memory_key_padding_mask = memory_key_padding_mask.repeat(1, n_bm).view(n_inst * n_bm, len_s) 173 | return memory_key_padding_mask 174 | 175 | 176 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 177 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 178 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 179 | return dec_partial_pos 180 | 181 | def predict_word(dec_seq, enc_output, n_active_inst, n_bm, memory_key_padding_mask): 182 | # ------ decoder predict word------- 183 | sem_seq,sem_mask,sem_key_padding_mask = self.model.semantic_branch(dec_seq) 184 | pos_seq = self.model.positional_branch(sem_seq) 185 | dec_output = self.model.mdcdp(sem_seq,enc_output,pos_seq, 186 | tgt_mask=sem_mask, 187 | tgt_key_padding_mask=sem_key_padding_mask, 188 | memory_key_padding_mask=memory_key_padding_mask, 189 | ).permute(1, 0, 2) 190 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h 191 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) 192 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 193 | 194 | return word_prob 195 | 196 
| def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 197 | active_inst_idx_list = [] 198 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 199 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 200 | if not is_inst_complete: 201 | active_inst_idx_list += [inst_idx] 202 | 203 | return active_inst_idx_list 204 | 205 | # --- beam decoder step start --- 206 | # dec_seq : decoder word num for one tgt 207 | n_active_inst = len(inst_idx_to_position_map) 208 | 209 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 210 | if self.keep_aspect_ratio: 211 | memory_key_padding_mask = prepare_beam_memory_key_padding_mask(inst_dec_beams, memory_key_padding_mask, n_bm) 212 | else: 213 | memory_key_padding_mask = None 214 | # dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 215 | word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, memory_key_padding_mask) 216 | 217 | # Update the beam with predicted word prob information and collect incomplete instances 218 | active_inst_idx_list = collect_active_inst_idx_list( 219 | inst_dec_beams, word_prob, inst_idx_to_position_map) 220 | 221 | return active_inst_idx_list 222 | 223 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 224 | all_hyp, all_scores = [], [] 225 | for inst_idx in range(len(inst_dec_beams)): 226 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 227 | all_scores += [scores[:n_best]] 228 | # print("best :{} type:{}\n".format(tail_idxs[:n_best])) 229 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 230 | all_hyp += [hyps] 231 | return all_hyp, all_scores 232 | 233 | # ---start predict--- 234 | with torch.no_grad(): 235 | #-- Encode 236 | images = images.to(self.device) 237 | src_enc,memory_key_padding_mask = self.model.visual_branch(images) 238 | # print(src_enc.shape) 239 | # -------- note delete ------ 240 | #-- Repeat data for beam search 241 | 242 | n_bm = self.cfg.beam_size 243 | 244 | src_enc = src_enc.permute(1, 0, 2) 245 | n_inst, len_s, d_h = src_enc.size() 246 | # src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 247 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h).permute(1, 0, 2) 248 | # memory_key_padding_mask = memory_key_padding_mask.repeat(1, n_bm).view(n_inst * n_bm, len_s) 249 | #-- Prepare beams 250 | # n_inst == batch_size 251 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 252 | 253 | #-- Bookkeeping for active or not 254 | active_inst_idx_list = list(range(n_inst)) 255 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 256 | 257 | #-- Decode 258 | for len_dec_seq in range(1, 50): 259 | # word iter for sentense 260 | 261 | active_inst_idx_list = beam_decode_step( 262 | inst_dec_beams, len_dec_seq, src_enc, inst_idx_to_position_map, n_bm, memory_key_padding_mask) 263 | 264 | if not active_inst_idx_list: 265 | break # all instances have finished their path to 266 | 267 | src_enc, inst_idx_to_position_map = collate_active_info( 268 | src_enc, inst_idx_to_position_map, active_inst_idx_list) 269 | # each decoder word transform to vocab 270 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.cfg.n_best) 271 | 272 | return batch_hyp, batch_scores 273 | 274 | -------------------------------------------------------------------------------- /cdistnet/optim/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/optim/__init__.py -------------------------------------------------------------------------------- /cdistnet/optim/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import distributed 4 | 5 | def reduce_tensor(tensor): 6 | rt = tensor.clone() 7 | distributed.all_reduce(rt, op=distributed.reduce_op.SUM) 8 | rt /= distributed.get_world_size()#总进程数 9 | return rt 10 | 11 | def cal_performance(pred, tgt, local_rank,smoothing=True): 12 | # pred(b*tgt_len,vacab_size) 13 | # tgt(b,max_len) 14 | loss = cal_loss(pred, tgt, local_rank, smoothing) 15 | pred = pred.max(1)[1] 16 | tgt = tgt.contiguous().view(-1) 17 | non_pad_mask = tgt.ne(0) 18 | n_correct = pred.eq(tgt) 19 | 20 | # loss = reduce_tensor(loss.data) 21 | # n_correct = reduce_tensor(n_correct) 22 | 23 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 24 | return loss, n_correct 25 | 26 | 27 | def cal_loss(pred, tgt, local_rank, smoothing=True): 28 | tgt = tgt.contiguous().view(-1) 29 | if smoothing: 30 | eps = 0.1 31 | n_class = pred.size(1) 32 | 33 | one_hot = torch.zeros_like(pred).scatter(1, tgt.view(-1, 1), 1) 34 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 35 | log_prb = F.log_softmax(pred, dim=1) 36 | 37 | non_pad_mask = tgt.ne(0) 38 | loss = -(one_hot * log_prb).sum(dim=1) 39 | loss = loss.masked_select(non_pad_mask).mean() 40 | else: 41 | loss = F.cross_entropy(pred, tgt, ignore_index=0, reduction='mean') 42 | return loss -------------------------------------------------------------------------------- /cdistnet/optim/optim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ScheduledOptim(): 5 | '''A simple wrapper class for learning rate scheduling''' 6 | 7 | def __init__(self, optimizer, d_model, n_warmup_steps, n_current_steps=0): 8 | self._optimizer = optimizer 9 | self.n_warmup_steps = n_warmup_steps 10 | self.n_current_steps = n_current_steps 11 | self.init_lr = np.power(d_model, -0.5) 12 | # curr_epoch = step + 1 (4 -- 5th) 13 | self.step2 = 7 14 | self.step2_lr = 0.00001 15 | 16 | def step_and_update_lr(self,epoch = 0): 17 | "Step with the inner optimizer" 18 | self._update_learning_rate(epoch) 19 | self._optimizer.step() 20 | 21 | def zero_grad(self): 22 | "Zero out the gradients by the inner optimizer" 23 | self._optimizer.zero_grad() 24 | 25 | def _get_lr_scale(self): 26 | return np.min([ 27 | np.power(self.n_current_steps, -0.5), 28 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 29 | 30 | def _update_learning_rate(self,epoch): 31 | ''' Learning rate scheduling per step ''' 32 | 33 | self.n_current_steps += 1 34 | lr = self.init_lr * self._get_lr_scale() 35 | if epoch >= self.step2: 36 | lr = self.step2_lr 37 | 38 | for param_group in self._optimizer.param_groups: 39 | param_group['lr'] = lr 40 | 41 | # optimizer = optim.Adam(model.parameters(), lr=0.001) 42 | # lmbda = lambda epoch: 0.9**(epoch // 300) if epoch < 13200 else 10**(-2) 43 | # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lmbda) 44 | 45 | # if self.last_epoch < self.warmup_steps: 46 | # return (self.end_lr - self.start_lr) * float( 47 | # self.last_epoch) / float(self.warmup_steps) + self.start_lr 48 | 49 | class WarmupOptim(): 50 | '''A simple wrapper class for learning rate scheduling''' 51 | 52 | def 
__init__(self, optimizer, d_model, n_warmup_steps, n_current_steps=0,current_epoch=0): 53 | self._optimizer = optimizer 54 | self.n_warmup_steps = n_warmup_steps 55 | self.current_epoch = current_epoch 56 | self.n_current_steps = n_current_steps 57 | self.start_lr = 0.0 58 | # self.init_lr = 0.001 59 | self.step_lr = [0.001,0.0001,0.00001] 60 | self.step = [4,6] #curr_epoch = step + 1 (4 -- 5th) 61 | 62 | def step_and_update_lr(self,epoch): 63 | "Step with the inner optimizer" 64 | self._update_learning_rate(epoch) 65 | self._optimizer.step() 66 | 67 | def zero_grad(self): 68 | "Zero out the gradients by the inner optimizer" 69 | self._optimizer.zero_grad() 70 | 71 | def _get_lr_scale(self,epoch): 72 | if epoch <= self.step[0]: 73 | return self.step_lr[0] 74 | if epoch <= self.step[1]: 75 | return self.step_lr[1] 76 | return self.step_lr[2] 77 | 78 | def _update_learning_rate(self,epoch): 79 | ''' Learning rate scheduling per step ''' 80 | 81 | self.n_current_steps += 1 82 | if self.n_current_steps < self.n_warmup_steps: 83 | lr = (self.step_lr[0] - self.start_lr) * float( 84 | self.n_current_steps) / float(self.n_warmup_steps) + self.start_lr 85 | else: 86 | lr = self._get_lr_scale(epoch) 87 | 88 | for param_group in self._optimizer.param_groups: 89 | param_group['lr'] = lr -------------------------------------------------------------------------------- /cdistnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/cdistnet/utils/__init__.py -------------------------------------------------------------------------------- /cdistnet/utils/dict_36.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 0 6 | 1 7 | 2 8 | 3 9 | 4 10 | 5 11 | 6 12 | 7 13 | 8 14 | 9 15 | a 16 | b 17 | c 18 | d 19 | e 20 | f 21 | g 22 | h 23 | i 24 | j 25 | k 26 | l 27 | m 28 | n 29 | o 30 | p 31 | q 32 | r 33 | s 34 | t 35 | u 36 | v 37 | w 38 | x 39 | y 40 | z -------------------------------------------------------------------------------- /cdistnet/utils/gen_img.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image,ImageDraw,ImageFont 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | def gen_img(str1): 7 | ''' 8 | text2img 9 | :param str1: string:text 10 | :return: ndarray:array 11 | ''' 12 | # str1 = 'fdsakjhfdskjfh' 13 | num=len(str1)+1 14 | print(num) 15 | img = Image.new('RGB', (15*num,32), (0, 0, 0)) 16 | draw = ImageDraw.Draw(img) 17 | fontpath = "simsun.ttc" 18 | font = ImageFont.truetype(fontpath, 32) 19 | 20 | #绘制文字信息 21 | draw.text((0, 0), str1, font = font, fill = (255,255,255)) 22 | img = np.array(img) 23 | img=cv2.resize(img,(96,32)) 24 | # plt.imshow(img) 25 | return img -------------------------------------------------------------------------------- /cdistnet/utils/init.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def init_weights_xavier(module): 4 | for m in module.modules(): 5 | if isinstance(m, nn.Conv2d): 6 | nn.init.xavier_normal_(m.weight.data) 7 | if m.bias is not None: 8 | m.bias.data.zero_() -------------------------------------------------------------------------------- /cdistnet/utils/submit_with_lexicon.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import 
editdistance 4 | import pdb 5 | 6 | ### for icdar2003, svt, iiit5k, they all have lexicon_50 7 | def ic03_lex(lexdir, lex_type, gt_file, submit_file, name_file): 8 | gt = open(gt_file, 'r') 9 | gt_lines = gt.readlines() 10 | gt.close() 11 | 12 | sub = open(submit_file, 'r') 13 | sub_lines = sub.readlines() 14 | sub.close() 15 | 16 | imgname = open(name_file, 'r') 17 | img_lines = imgname.readlines() 18 | imgname.close() 19 | 20 | # for lexicon full 21 | if lex_type== 'full': 22 | sub_file = submit_file[:-4] + '_full.txt' 23 | sub_fout = open(sub_file, 'w') 24 | for i in range(len(gt_lines)):#for each gt 25 | suf, gt, _ = (gt_lines[i].strip()).split('"') 26 | sub = (sub_lines[i].strip()).split('"')[1] 27 | 28 | lex_file = open(os.path.join(lexdir, 'lexicon_Full.txt'), 'r') 29 | lex = lex_file.readlines() 30 | lex_file.close() 31 | 32 | min_dis = 10000 33 | min_word = sub 34 | for word in lex: 35 | word = word.strip() 36 | word = word.lower() 37 | dis = editdistance.eval(sub, word) 38 | if dis < min_dis: 39 | min_word = word 40 | min_dis = dis 41 | sub_fout.write(suf + '"' + str(min_word) + '"\n') 42 | sub_fout.close() 43 | # for lexicon 50 44 | else: 45 | sub_file = submit_file[:-4] + '_50.txt' 46 | sub_fout = open(sub_file, 'w') 47 | for i in range(len(gt_lines)):#for each gt 48 | base_name = img_lines[i].strip().split('.')[0] 49 | suf, gt, _ = (gt_lines[i].strip()).split('"') 50 | sub = (sub_lines[i].strip()).split('"')[1] 51 | 52 | lex_file = open(os.path.join(lexdir, 'lexicon_50', 'lexicon_' + base_name + '_' + gt + '.txt'), 'r') 53 | lex = lex_file.readlines() 54 | lex_file.close() 55 | 56 | min_dis = 10000 57 | min_word = sub 58 | for word in lex: 59 | word = word.strip() 60 | word = word.lower() 61 | dis = editdistance.eval(sub, word) 62 | if dis < min_dis: 63 | min_word = word 64 | min_dis = dis 65 | sub_fout.write(suf + '"' + str(min_word) + '"\n') 66 | sub_fout.close() 67 | 68 | return sub_file 69 | 70 | ### for svt 71 | def svt_lex(lexdir, lex_type, gt_file, submit_file, name_file): 72 | gt = open(gt_file, 'r') 73 | gt_lines = gt.readlines() 74 | gt.close() 75 | 76 | sub = open(submit_file, 'r') 77 | sub_lines = sub.readlines() 78 | sub.close() 79 | 80 | imgname = open(name_file, 'r') 81 | img_lines = imgname.readlines() 82 | imgname.close() 83 | 84 | # for lexicon 50 85 | 86 | sub_file = submit_file[:-4] + '_50.txt' 87 | sub_fout = open(sub_file, 'w') 88 | for i in range(len(gt_lines)):#for each gt 89 | base_name = img_lines[i].strip().split('.')[0] 90 | suf, gt, _ = (gt_lines[i].strip()).split('"') 91 | sub = (sub_lines[i].strip()).split('"')[1] 92 | 93 | lex_file = open(os.path.join(lexdir, 'lexicon_50', 'lexicon_' + base_name + '_' + gt + '.txt'), 'r') 94 | lex = lex_file.readlines() 95 | lex_file.close() 96 | 97 | min_dis = 10000 98 | min_word = sub 99 | for word in lex: 100 | word = word.strip() 101 | word = word.lower() 102 | dis = editdistance.eval(sub, word) 103 | if dis < min_dis: 104 | min_word = word 105 | min_dis = dis 106 | sub_fout.write(suf + '"' + str(min_word) + '"\n') 107 | sub_fout.close() 108 | 109 | return sub_file 110 | 111 | 112 | ### for svt-p 113 | def svt_p_lex(lexdir, lex_type, gt_file, submit_file, name_file): 114 | gt = open(gt_file, 'r') 115 | gt_lines = gt.readlines() 116 | gt.close() 117 | 118 | sub = open(submit_file, 'r') 119 | sub_lines = sub.readlines() 120 | sub.close() 121 | 122 | imgname = open(name_file, 'r') 123 | img_lines = imgname.readlines() 124 | imgname.close() 125 | 126 | # for lexicon 50 127 | 128 | sub_file = submit_file[:-4] + 
'_50.txt' 129 | sub_fout = open(sub_file, 'w') 130 | for i in range(len(gt_lines)):#for each gt 131 | base_name = img_lines[i].strip().split('.')[0] 132 | suf, gt, _ = (gt_lines[i].strip()).split('"') 133 | sub = (sub_lines[i].strip()).split('"')[1] 134 | 135 | lex_file = open(os.path.join(lexdir, 'lexicon_50', 'lexicon_' + base_name + '_' + gt + '.txt'), 'r') 136 | lex = lex_file.readlines() 137 | lex_file.close() 138 | 139 | min_dis = 10000 140 | min_word = sub 141 | for word in lex: 142 | word = word.strip() 143 | word = word.lower() 144 | dis = editdistance.eval(sub, word) 145 | if dis < min_dis: 146 | min_word = word 147 | min_dis = dis 148 | sub_fout.write(suf + '"' + str(min_word) + '"\n') 149 | sub_fout.close() 150 | 151 | return sub_file 152 | 153 | ### for iiit5k 154 | def iiit5k_lex(lexdir, lex_type, gt_file, submit_file, name_file): 155 | gt = open(gt_file, 'r') 156 | gt_lines = gt.readlines() 157 | gt.close() 158 | 159 | sub = open(submit_file, 'r') 160 | sub_lines = sub.readlines() 161 | sub.close() 162 | 163 | imgname = open(name_file, 'r') 164 | img_lines = imgname.readlines() 165 | imgname.close() 166 | 167 | # for lexicon full 168 | if lex_type== '1k': 169 | sub_file = submit_file[:-4] + '_1k.txt' 170 | sub_fout = open(sub_file, 'w') 171 | for i in range(len(gt_lines)):#for each gt 172 | base_name = img_lines[i].strip().split('.')[0] 173 | suf, gt, _ = (gt_lines[i].strip()).split('"') 174 | sub = (sub_lines[i].strip()).split('"')[1] 175 | 176 | lex_file = open(os.path.join(lexdir, 'lexicon_1k', 'lexicon_' + base_name + '_' + gt + '.txt'), 'r') 177 | lex = lex_file.readlines() 178 | lex_file.close() 179 | 180 | min_dis = 10000 181 | min_word = sub 182 | for word in lex: 183 | word = word.strip() 184 | word = word.lower() 185 | dis = editdistance.eval(sub, word) 186 | if dis < min_dis: 187 | min_word = word 188 | min_dis = dis 189 | sub_fout.write(suf + '"' + str(min_word) + '"\n') 190 | sub_fout.close() 191 | # for lexicon 50 192 | else: 193 | sub_file = submit_file[:-4] + '_50.txt' 194 | sub_fout = open(sub_file, 'w') 195 | for i in range(len(gt_lines)):#for each gt 196 | base_name = img_lines[i].strip().split('.')[0] 197 | suf, gt, _ = (gt_lines[i].strip()).split('"') 198 | sub = (sub_lines[i].strip()).split('"')[1] 199 | 200 | lex_file = open(os.path.join(lexdir, 'lexicon_50', 'lexicon_' + base_name + '_' + gt + '.txt'), 'r') 201 | lex = lex_file.readlines() 202 | lex_file.close() 203 | 204 | min_dis = 10000 205 | min_word = sub 206 | for word in lex: 207 | word = word.strip() 208 | word = word.lower() 209 | dis = editdistance.eval(sub, word) 210 | if dis < min_dis: 211 | min_word = word 212 | min_dis = dis 213 | sub_fout.write(suf + '"' + str(min_word) + '"\n') 214 | sub_fout.close() 215 | 216 | return sub_file -------------------------------------------------------------------------------- /cdistnet/utils/tensorboardx.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | import torch 4 | from tensorboardX import SummaryWriter 5 | 6 | class TensorboardLogger(object): 7 | def __init__(self, log_dir, start_iter=0): 8 | self.iteration = start_iter 9 | self.writer = self._get_tensorboard_writer(log_dir) 10 | 11 | @staticmethod 12 | def _get_tensorboard_writer(log_dir): 13 | timestamp = datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H:%M') 14 | tb_logger = SummaryWriter('{}/{}'.format(log_dir, timestamp)) 15 | return tb_logger 16 | 17 | def add_scalar(self, ** kwargs): 18 | if 
self.writer: 19 | for k, v in kwargs.items(): 20 | if isinstance(v, torch.Tensor): 21 | v = v.item() 22 | assert isinstance(v, (float, int)) 23 | self.writer.add_scalar(k, v, self.iteration) 24 | 25 | def add_graph(self,model,data=(torch.randn(4,32),torch.randn(4,1,128,32))): 26 | # tgt(b,max_len) 27 | # image(b,c,h,w) 28 | self.writer.add_graph(model,input_to_model=data) 29 | 30 | def add_histogram(self,k,v): 31 | self.writer.add_histogram(k,v,self.iteration) 32 | 33 | def update_iter(self, iteration): 34 | self.iteration = iteration -------------------------------------------------------------------------------- /cdistnet_env.yaml: -------------------------------------------------------------------------------- 1 | einops==0.3.0 2 | matplotlib==3.4.1 3 | mmcv==1.2.6 4 | notebook==6.2.0 5 | numpy 6 | opencv-python==4.5.1.48 7 | Pillow==8.2.0 8 | tensorboardX==2.1 9 | tensorflow==2.4.1 10 | thop==0.0.31.post2005241907 11 | timm==0.4.5 12 | torch==1.5.0+cu92 13 | torchvision==0.6.0+cu92 14 | tornado==6.1 15 | tqdm==4.56.2 -------------------------------------------------------------------------------- /configs/CDistNet_config.py: -------------------------------------------------------------------------------- 1 | dst_vocab = 'cdistnet/utils/dict_36.txt' 2 | dst_vocab_size = 40 3 | rgb2gray =False 4 | keep_aspect_ratio = False 5 | width = 128 #100 6 | height = 32 #32 7 | max_width = 180 8 | is_lower = True 9 | cnn_num = 2 10 | leakyRelu = False 11 | hidden_units = 512 12 | ff_units = 1024 #ff 13 | scale_embedding = True 14 | attention_dropout_rate = 0.0 15 | residual_dropout_rate = 0.1 16 | num_encoder_blocks = 3 17 | num_decoder_blocks = 3 18 | num_heads = 8 19 | beam_size = 10 20 | n_best = 1 21 | data_aug = True 22 | num_fiducial = 20 #number of fiducial points of TPS-STN 23 | train_method = 'origin' #dist: use distributed train method origin 24 | optim = 'origin' 25 | 26 | # method choice 27 | tps_block = 'TPS' # TPS None 28 | feature_block = 'Resnet45' # Resnet45 Resnet31 MTB 29 | 30 | train = dict( 31 | grads_clip=5, 32 | optimizer='adam_decay', # not used 33 | learning_rate_warmup_steps=10000, 34 | label_smoothing=True, # fixed in code 35 | shared_embedding=False, # not used 36 | device='cuda', 37 | gt_file=['../dataset/MJ/MJ_train/', 38 | '../dataset/MJ/MJ_test/', 39 | '../dataset/MJ/MJ_valid/', 40 | '../dataset/ST'], 41 | num_worker=16, 42 | # model_dir ='model/test', 43 | model_dir='models/reconstruct_CDistNet_3_10', 44 | num_epochs=10, 45 | # gpu_device_ids=[1,2,3,4,5,6,7], 46 | batch_size=1400, # 4gpu 1800 47 | model=None, 48 | # model ='models/new_baseline_sem_pos_pos_vis_3_32*128_tps_resnet45_epoch_6/model_epoch_5.pth', 49 | # current_epoch=6, # epoch start 50 | save_iter=10000, 51 | display_iter=100, 52 | tfboard_iter=100, 53 | eval_iter=3000, 54 | ) 55 | 56 | 57 | val = dict( 58 | model='models/baseline_two_32*100_1d_2cnn-test/model_epoch_1.pth', # abandon 59 | device='cuda', 60 | # is_val_gt=True, 61 | image_dir='datasets/NewVersion/val_data', 62 | gt_file= [ 63 | './dataset/eval/IC13_857', 64 | './dataset/eval/SVT', 65 | './dataset/eval/IIIT5k_3000', 66 | './dataset/eval/IC15_1811', 67 | './dataset/eval/SVTP', 68 | './dataset/eval/CUTE80'], 69 | # gt_file=['datasets/NewVersion/val_data/val_data.txt'], 70 | # gt_file='../dataset/MJ/MJ_valid/', 71 | batch_size=800, # 4gpu 1800 72 | num_worker=16, 73 | ) 74 | 75 | 76 | test = dict( 77 | test_one=False, 78 | device='cuda', 79 | rotate=False, 80 | best_acc_test=True, # test best_acc 81 | eval_all=False, # test all 
model_epoch_9_iter_4080.pth 82 | s_epoch=7, # start_epoch 83 | e_epoch=10, 84 | avg_s=-1, 85 | avg_e=9, 86 | avg_all=False, 87 | is_test_gt=False, 88 | image_dir= None, # if is_test_gt == False, image_dir is not needed 89 | test_list=[ 90 | './dataset/eval/IC13_857', 91 | './dataset/eval/SVT', 92 | './dataset/eval/IIIT5k_3000', 93 | './dataset/eval/IC15_1811', 94 | './dataset/eval/SVTP', 95 | './dataset/eval/CUTE80' 96 | ], 97 | batch_size=128, 98 | num_worker=8, 99 | model_dir='models/reconstruct_CDistNetv3_3_10', # load test model 100 | script_path='utils/Evaluation_TextRecog/script.py', 101 | python_path='/data1/zs/anaconda3/envs/py2/bin/python' #abandon 102 | ) 103 | -------------------------------------------------------------------------------- /configs/debug_config.py: -------------------------------------------------------------------------------- 1 | dst_vocab = '/data6/zhengtianlun/temp/dict_99.txt' # 98 + space 2 | dst_vocab_size = 99 3 | rgb2gray = True 4 | keep_aspect_ratio = False 5 | width = 96 #100 6 | height = 32 #32 7 | max_width = 180 8 | is_lower = False # True: convert text to lowercase during training 9 | cnn_num = 2 10 | leakyRelu = False 11 | hidden_units = 512 12 | ff_units = 1024 13 | scale_embedding = True 14 | attention_dropout_rate = 0.0 15 | residual_dropout_rate = 0.1 16 | num_encoder_blocks = 4 17 | num_decoder_blocks = 4 18 | num_heads = 8 19 | beam_size = 10 20 | n_best = 1 21 | num_fiducial = 20 #number of fiducial points of TPS-STN 22 | use_squ = True #if false: use diag for tgt mask(not ready) 23 | train_method = 'origin' #dist: use distributed train method origin 24 | 25 | # method choice 26 | tps_block = None # TPS None 27 | feature_block = 'origin' # None (not use cnn) origin Resnet 28 | patch_block = 'wh' # None wh_2_4_8 wh w+h+wh+avg vit w+h w+h+wh wh_fusion 29 | custom_encoder = 'trans_blstm' # None swin-trans pvt text2img-msa(not ready) 30 | custom_decoder = None 31 | transformer = 'transformer' # transformer patch4_trans 32 | 33 | 34 | train = dict( 35 | grads_clip=5, 36 | optimizer='adam_decay', # not used 37 | learning_rate_warmup_steps=10000, 38 | label_smoothing=True, # fixed in code 39 | shared_embedding=False, # not used 40 | device='cuda', 41 | # image_dir='/home/zhengsheng/datasets/TextRecog/mnt/ramdisk/max/90kDICT32px', 42 | # gt_file='/home/zhengsheng/datasets/TextRecog/mnt/ramdisk/max/90kDICT32px/annotation_train_two_Synth_shuf_clean.txt', 43 | image_dir='/data6/zhengtianlun/temp/icdar2015_test/images', 44 | gt_file='/data6/zhengtianlun/temp/icdar2015_test/gt.txt', 45 | # hdf5='datasets/train.hdf5', # train_two.hdf5 train_keep_aspect_ratio.hdf5 train_two_keep_aspect_ratio.hdf5 46 | num_worker=16, 47 | model_dir='models/baseline_debug', # directory where checkpoints are saved 48 | num_epochs=4, 49 | # gpu_device_ids=[0], 50 | batch_size=250, # 4gpu 1800 51 | model=None, # path of the checkpoint to load, None means do not load e.g. '/home/zhengsheng/github/NRTR/models/model_epoch_14.pth', 52 | current_epoch=15, # epoch to resume training from, set according to the loaded checkpoint e.g.
15 53 | save_iter=2000, 54 | display_iter=100, 55 | tfboard_iter=100, 56 | eval_iter=10, 57 | ) 58 | 59 | 60 | val = dict( 61 | model='models/baseline_two_32*100_1d_2cnn-test/model_epoch_1.pth', # checkpoint to load, not used during training 62 | device='cuda', 63 | image_dir='/data6/zhengtianlun/temp/icdar2015_test/images', 64 | gt_file='/data6/zhengtianlun/temp/icdar2015_test/gt.txt', 65 | # hdf5='datasets/val.hdf5', 66 | batch_size=1800, # 4gpu 1800 67 | num_worker=16, 68 | ) 69 | 70 | 71 | test = dict( 72 | device='cuda', 73 | rotate=False, # rotate images by 90 degrees at test time 74 | eval_all=False, # test all checkpoints, including those saved within an epoch, e.g. model_epoch_9_iter_4080.pth 75 | s_epoch=15, # start testing from epoch s_epoch; if s_epoch = -1: no testing 76 | e_epoch=15, # stop testing at epoch e_epoch 77 | avg_s=-1, # average checkpoints from epoch avg_s to avg_e; if avg_s = -1: no averaging 78 | avg_e=9, 79 | avg_all=False, # if True, averaging also includes checkpoints saved within an epoch, e.g. model_epoch_9_iter_4080.pth 80 | test_list=[ 81 | 'ICDAR2003_860', 82 | 'ICDAR2003_867', 83 | 'ICDAR2013_857', 84 | 'ICDAR2013_1015', 85 | 'ICDAR2015_1811', 86 | 'ICDAR2015_2077', 87 | 'IIIT5K', 88 | 'SVT', 89 | 'SVT-P', 90 | 'CUTE80' 91 | ], 92 | image_dir='datasets/NewVersion', 93 | batch_size=8, 94 | num_worker=1, 95 | model_dir='models/baseline_20_epoch_trans_blstm_4*4', # directory of checkpoints to load for testing 96 | script_path='/data6/zhengtianlun/temp/Evaluation_TextRecog/script.py', 97 | python_path='/data1/zs/anaconda3/envs/py2/bin/python' 98 | ) 99 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import codecs 4 | import subprocess 5 | import csv 6 | import glob 7 | # os.environ['CUDA_VISIBLE_DEVICES']="6" 8 | from tqdm import tqdm 9 | from mmcv import Config 10 | import torch 11 | import torch.nn as nn 12 | import torch.distributed as dist 13 | 14 | from cdistnet.data.data import make_data_loader_test, make_lmdb_data_loader_test 15 | # from cdistnet.hdf5loader import make_data_loader 16 | from cdistnet.model.translator import Translator 17 | from cdistnet.utils.submit_with_lexicon import ic03_lex, iiit5k_lex, svt_lex, svt_p_lex 18 | from cdistnet.model.model import build_CDistNet 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Train CDistNet') 23 | parser.add_argument('--config', help='train config file path') 24 | parser.add_argument('--local_rank', default=-1, type=int,help='node rank for distributed training') 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def load_vocab(vocab=None, vocab_size=None): 30 | """ 31 | Load vocab from disk. The first four items in the vocab should be <PAD>, <UNK>, <S>, </S> 32 | """ 33 | # print('Load set vocabularies as %s.'
% vocab) 34 | vocab = [' ' if len(line.split()) == 0 else line.split()[0] for line in codecs.open(vocab, 'r', 'utf-8')] 35 | vocab = vocab[:vocab_size] 36 | assert len(vocab) == vocab_size 37 | word2idx = {word: idx for idx, word in enumerate(vocab)} 38 | idx2word = {idx: word for idx, word in enumerate(vocab)} 39 | return word2idx, idx2word 40 | 41 | 42 | def get_alphabet(dict_path): 43 | with open(dict_path, "r") as f: 44 | data = f.readlines() 45 | data = list(map(lambda x: x.strip(), data)) 46 | data = data[4:] 47 | return data 48 | 49 | 50 | def get_pred_gt_name(translator, idx2word, b_image, b_gt, b_name, num, dict_path, rotate,rgb2gray,is_test_gt=True): 51 | # rgb2gray=False 52 | gt_list, name_list, pred_list = [], [], [] 53 | alphabet = get_alphabet(dict_path) # not used 54 | if rotate: 55 | batch_hyp, batch_scores = translator.translate_batch( 56 | images=b_image.view(-1, b_image.shape[-2], b_image.shape[-1]).unsqueeze(dim=1) 57 | ) 58 | batch_scores = torch.cat(batch_scores, dim=0).view(-1, 3) 59 | _, idx = torch.max(batch_scores, 1) 60 | idx = torch.arange(0, idx.shape[0], dtype=torch.long) * 3 + idx.cpu() 61 | batch_hyp_ = [] 62 | for id, v in enumerate(batch_hyp): 63 | if id in idx: 64 | batch_hyp_.append(v) 65 | batch_hyp = batch_hyp_ 66 | else: 67 | if rgb2gray == False: 68 | batch_hyp, batch_scores = translator.translate_batch(images=b_image[:, :3, :, :]) 69 | else: 70 | batch_hyp, batch_scores = translator.translate_batch(images=b_image[:, 0:1, :, :]) 71 | for idx, seqs in enumerate(batch_hyp): 72 | for seq in seqs: 73 | seq = [x for x in seq if x != 3] 74 | pred = [idx2word[x] for x in seq] 75 | pred = ''.join(pred) 76 | flag = False 77 | if is_test_gt==False: 78 | num += 1 79 | pred_list.append('word_{}.png'.format(num) + ', "' + pred + '"\n') 80 | gt_list.append('word_{}.png'.format(num) + ', "' + b_gt[idx] + '"\n') 81 | name_list.append(b_name[idx] + '\n') 82 | else: 83 | num += 1 84 | pred_list.append('{}'.format(b_name[idx]) + ', "' + pred + '"\n') 85 | gt_list.append('{}'.format(b_name[idx]) + ', "' + b_gt[idx] + '"\n') 86 | name_list.append(b_name[idx] + '\n') 87 | return gt_list, name_list, pred_list, num 88 | 89 | 90 | def write_to_file(file_name, datas): 91 | with open(file_name, "w") as f: 92 | f.writelines(datas) 93 | 94 | 95 | def eval_and_save(script_path, gt_file, submit_file, python_path): 96 | cmd = "%s %s -g=%s -s=%s" % (python_path, script_path, gt_file, submit_file) 97 | print("cmd:{}".format(cmd)) 98 | p = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True) 99 | cmd_out = p.stdout.read().decode('utf-8') 100 | # Calculated!{"crwup": 0.9443155452436195, "tedupL": 0.9704312114989733, "tedL": 0.9704312114989733, "totalWords": 862, "crwN": 814.0, "crwupN": 814.0, "ted": 144.0, "tedup": 144.0, "detWords": 862, "crw": 0.9443155452436195} 101 | print("cmd_out: {}".format(cmd_out)) 102 | crwup = cmd_out[cmd_out.index('{') + 1:cmd_out.index('}')].split(', ')[0] 103 | res = crwup.split(': ')[1] 104 | res = float(res) * 100 105 | res = '%.2f' % float(res) 106 | return str(res) 107 | 108 | 109 | def start_eval(script_path, data_name, gt_file, pred_file, name_file, lexdir, python_path): 110 | submit_file_list = [] 111 | res = [eval_and_save(script_path, gt_file, pred_file, python_path)] 112 | if data_name == 'icdar2003': 113 | submit_file_list.append(ic03_lex(os.path.join(lexdir, data_name), '50', gt_file, pred_file, name_file)) 114 | submit_file_list.append(ic03_lex(os.path.join(lexdir, data_name), 'full', gt_file, pred_file, name_file)) 115 | elif data_name 
== 'svt': 116 | submit_file_list.append(svt_lex(os.path.join(lexdir, data_name), '50', gt_file, pred_file, name_file)) 117 | elif data_name == 'svt-p': 118 | submit_file_list.append(svt_p_lex(os.path.join(lexdir, data_name), '50', gt_file, pred_file, name_file)) 119 | elif data_name == 'iiit5k': 120 | submit_file_list.append(iiit5k_lex(os.path.join(lexdir, data_name), '50', gt_file, pred_file, name_file)) 121 | submit_file_list.append(iiit5k_lex(os.path.join(lexdir, data_name), '1k', gt_file, pred_file, name_file)) 122 | for submit_file in submit_file_list: 123 | res.append(eval_and_save(script_path, gt_file, submit_file, python_path)) 124 | return res 125 | 126 | def start_eval_simple(submit_list,gt_list,name_list): 127 | i = 0 128 | total = 0 129 | num = len(gt_list) 130 | err_list = [] 131 | for pred in submit_list: 132 | if pred.lower() == gt_list[i].lower(): 133 | total+=1 134 | # else: 135 | # err_list.append("{} image is diff:{} ---- {}\n".format(name_list[i],pred,gt_list[i].lower())) 136 | i +=1 137 | # print(err_list) 138 | # with open(err_dir, "w") as f: 139 | # f.writelines(err_list) 140 | return total / num *100.0 141 | 142 | def eval(cfg, args,model_path): 143 | # init dist_train 144 | if cfg.train_method=='dist': 145 | dist.init_process_group(backend='nccl') 146 | torch.cuda.set_device(args.local_rank) 147 | 148 | model = build_CDistNet(cfg) 149 | model.load_state_dict(torch.load(model_path)) 150 | if cfg.train_method=='dist': 151 | model.cuda(args.local_rank) 152 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 153 | else: 154 | device = torch.device(cfg.test.device) 155 | model.to(device) 156 | model.eval() 157 | 158 | translator = Translator(cfg, model=model) 159 | word2idx, idx2word = load_vocab(cfg.dst_vocab, cfg.dst_vocab_size) 160 | lexdir = cfg.test.image_dir 161 | result_line = [] 162 | for i,data_name in enumerate(cfg.test.test_list): 163 | print("dataset name: {}".format(data_name)) 164 | if cfg.test.is_test_gt ==True: 165 | test_dataloader = make_data_loader_test(cfg, lexdir[i], gt_file=data_name) 166 | else: 167 | test_dataloader = make_lmdb_data_loader_test(cfg, [data_name]) 168 | 169 | gt_list, name_list, pred_list = [], [], [] 170 | num = 0 171 | 172 | #start eval 173 | for iteration, batch in enumerate(tqdm(test_dataloader)): 174 | b_image, b_gt, b_name = batch[0], batch[1], batch[2] 175 | gt_list_, name_list_, pred_list_, num = get_pred_gt_name( 176 | translator, idx2word, b_image, b_gt, b_name, num, cfg.dst_vocab, cfg.test.rotate,cfg.rgb2gray,cfg.test.is_test_gt 177 | ) 178 | gt_list += gt_list_ 179 | name_list += name_list_ 180 | pred_list += pred_list_ 181 | # print("gt:{} \n pred:{}".format(gt_list,pred_list)) 182 | gt_file = os.path.join(cfg.test.model_dir, 'gt.txt') 183 | pred_file = os.path.join(cfg.test.model_dir, 'submit.txt') 184 | name_file = os.path.join(cfg.test.model_dir, 'name.txt') 185 | write_to_file(gt_file, gt_list) 186 | write_to_file(pred_file, pred_list) 187 | write_to_file(name_file, name_list) 188 | res_simple = start_eval_simple(pred_list,gt_list,name_list) 189 | print("res_simple_acc:{}".format(res_simple)) 190 | result_line += res_simple 191 | # if cfg.test.is_test_gt == False: 192 | # res = start_eval(cfg.test.script_path, data_name, gt_file, pred_file, name_file, lexdir, cfg.test.python_path) 193 | # result_line += res 194 | # print("result_line:{}".format(result_line)) 195 | result_line.insert(0, model_path.split('/')[-1]) 196 | print(os.path.join(cfg.test.model_dir, 'result.csv')) 197 | 
with open(os.path.join(cfg.test.model_dir, 'result.csv'), 'a') as f: 198 | writer = csv.writer(f) 199 | writer.writerow(result_line) 200 | 201 | 202 | def average(model, models): 203 | """Average models into model""" 204 | # with torch.no_grad(): 205 | # for ps in zip(*[m.parameters() for m in [model] + models]): 206 | # ps[0].copy_(torch.sum(torch.stack(ps[1:]), dim=0) / len(ps[1:])) 207 | 208 | with torch.no_grad(): 209 | for key in model.state_dict().keys(): 210 | v = [] 211 | for m in models: 212 | v.append(m.state_dict()[key]) 213 | v = torch.sum(torch.stack(v), dim=0) / len(v) 214 | model.state_dict()[key].copy_(v) 215 | 216 | 217 | def main(): 218 | args = parse_args() 219 | cfg = Config.fromfile(args.config) 220 | headers = cfg.test.test_list 221 | result_path = os.path.join(cfg.test.model_dir, 'result.csv') 222 | if not os.path.exists(result_path): 223 | with open(result_path, 'w') as f: 224 | writer = csv.writer(f) 225 | writer.writerow(headers) 226 | if cfg.test.best_acc_test: 227 | path2 = glob.glob(cfg.test.model_dir + '/epoch9_*.pth') 228 | path = glob.glob(cfg.test.model_dir + '/*_best_acc.pth') 229 | for model_path in path2: 230 | print("model: {}".format(model_path)) 231 | # eval(cfg, args,os.path.join(cfg.test.model_dir, model_path)) 232 | eval(cfg, args, model_path) 233 | for model_path in path: 234 | print("model: {}".format(model_path)) 235 | # eval(cfg, args,os.path.join(cfg.test.model_dir, model_path)) 236 | eval(cfg, args, model_path) 237 | # return 238 | 239 | # eval all 240 | if cfg.test.eval_all: 241 | paths = glob.glob(cfg.test.model_dir + "/*.pth") 242 | for model_path in paths: 243 | print("model: {}".format(model_path)) 244 | # eval(cfg, args,os.path.join(cfg.test.model_dir, model_path)) 245 | eval(cfg, args, model_path) 246 | return 247 | else: 248 | model_path_patten = cfg.test.model_dir + '/model_epoch_{}.pth' 249 | s, e = cfg.test.s_epoch, cfg.test.e_epoch 250 | if e < s: 251 | s, e = e, s 252 | if s != -1: 253 | for i in range(s, e + 1): 254 | model_path = model_path_patten.format(i) 255 | print("model: {}".format(model_path)) 256 | eval(cfg, args,model_path) 257 | 258 | # model average 259 | avg_s, avg_e = cfg.test.avg_s, cfg.test.avg_e 260 | if avg_e < avg_s: 261 | avg_s, avg_e = avg_e, avg_s 262 | if avg_s == -1: 263 | return 264 | models = [] 265 | if cfg.test.avg_all: 266 | for i in range(avg_s, avg_e + 1): 267 | model_paths = glob.glob(cfg.test.model_dir + '/model_epoch_{}*.pth'.format(i)) 268 | for model_path in model_paths: 269 | print("model: {}".format(model_path)) 270 | model = build_CDistNet(cfg) 271 | model.load_state_dict(torch.load(model_path)) 272 | models.append(model) 273 | else: 274 | for i in range(avg_s, avg_e + 1): 275 | model_path = model_path_patten.format(i) 276 | print("model: {}".format(model_path)) 277 | model = build_CDistNet(cfg) 278 | model.load_state_dict(torch.load(model_path)) 279 | models.append(model) 280 | model = build_CDistNet(cfg) 281 | # model = models[0] 282 | average(model, models) 283 | if cfg.test.avg_all: 284 | model_path = os.path.join(cfg.test.model_dir, 'model_epoch_avg({}-{}-all).pth'.format(avg_s, avg_e)) 285 | else: 286 | model_path = os.path.join(cfg.test.model_dir, 'model_epoch_avg({}-{}).pth'.format(avg_s, avg_e)) 287 | torch.save(model.state_dict(), model_path) 288 | eval(cfg, args,model_path) 289 | 290 | 291 | if __name__ == '__main__': 292 | main() 293 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | einops==0.3.0 2 | matplotlib==3.4.1 3 | mmcv==1.2.6 4 | notebook==6.2.0 5 | numpy 6 | opencv-python==4.5.1.48 7 | Pillow==8.2.0 8 | tensorboardX==2.1 9 | tensorflow==2.4.1 10 | thop==0.0.31.post2005241907 11 | timm==0.4.5 12 | torch==1.5.0+cu92 13 | torchvision==0.6.0+cu92 14 | tornado==6.1 15 | tqdm==4.56.2 16 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | 4 | import cv2 5 | import torch 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from mmcv import Config 9 | import numpy as np 10 | import os 11 | os.environ["CUDA_VISIBLE_DEVICES"]="1" 12 | 13 | # from cdistnet.hdf5loader import make_data_loader 14 | from cdistnet.model.translator import Translator 15 | from cdistnet.model.model import build_CDistNet 16 | # from cdistnet.data.data import make_data_loader 17 | 18 | 19 | # test 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Train CDistNet') 23 | parser.add_argument('--i_path', type=str, default='1.jpg', 24 | help='Input image path') 25 | parser.add_argument('--model_path', type=str, default='models/new_baseline_dssnetv3_3_32*128_tps_resnet45_epoch_6/epoch9_best_acc.pth', 26 | help='Input model path') 27 | parser.add_argument('--config', type=str, default='configs/CDistNet_config.py', 28 | help='train config file path') 29 | parser.add_argument('--use-cuda', action='store_true', default=False, 30 | help='Use NVIDIA GPU acceleration') 31 | parser.add_argument('--test_one', default=True, 32 | help='test one image') 33 | parser.add_argument('--use_origin', default=True, 34 | help='use_origin_process') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def load_vocab(vocab=None, vocab_size=None): 40 | """ 41 | Load vocab from disk. The fisrt four items in the vocab should be , , , 42 | """ 43 | # print('Load set vocabularies as %s.' % vocab) 44 | vocab = [' ' if len(line.split()) == 0 else line.split()[0] for line in codecs.open(vocab, 'r', 'utf-8')] 45 | vocab = vocab[:vocab_size] 46 | assert len(vocab) == vocab_size 47 | word2idx = {word: idx for idx, word in enumerate(vocab)} 48 | idx2word = {idx: word for idx, word in enumerate(vocab)} 49 | return word2idx, idx2word 50 | 51 | 52 | def preprocess_image(image_path): 53 | img = cv2.imread(image_path, 1) 54 | assert img is not None 55 | img = np.float32(img) 56 | # # Opencv loads as BGR: 57 | img = img[:, :, ::-1] 58 | grayscale = transforms.Grayscale(num_output_channels=1) 59 | preprocessing = transforms.Compose([ 60 | transforms.ToTensor(), 61 | # normalize, 62 | transforms.ToPILImage(), 63 | grayscale, 64 | transforms.ToTensor(), 65 | ]) 66 | return preprocessing(img.copy()).unsqueeze(0) 67 | 68 | 69 | def origin_process_img(cfg, image_path): 70 | # self.data=[(img_path,text),...] 71 | if cfg.rgb2gray: 72 | image = Image.open(image_path).convert('L') 73 | else: 74 | image = Image.open(image_path).convert('RGB') 75 | assert image is not None 76 | image = image.resize((cfg.width, cfg.height), Image.ANTIALIAS) 77 | image = np.array(image) 78 | if cfg.rgb2gray: 79 | image = np.expand_dims(image, -1) 80 | image = np.expand_dims(image, -1) 81 | print(image.shape) 82 | image = np.expand_dims(image, -1) 83 | image = image.transpose((2, 3, 0, 1)) 84 | image = image.astype(np.float32) / 128. - 1. 
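    # The division and shift above map uint8 pixel values from [0, 255] to roughly
    # [-1, 1): 0 -> -1.0, 128 -> 0.0, 255 -> ~0.992.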
85 | image = torch.from_numpy(image) 86 | # text = self.data[idx][1] 87 | # text = [self.word2idx.get(ch, 1) for ch in text] 88 | # text.insert(0, 2) 89 | # text.append(3) 90 | # target = np.array(text) 91 | return image 92 | 93 | 94 | def test(cfg): 95 | model = build_CDistNet(cfg) 96 | model.load_state_dict(torch.load( 97 | '/media/zs/zs/zs/code/NRTR/models/baseline_hdf5_100_32_two_local_MultiHeadAttention/model_epoch_avg.pth')) 98 | device = torch.device(cfg.test.device) 99 | model.to(device) 100 | model.eval() 101 | cfg.n_best = 5 102 | # vision more res 103 | translator = Translator(cfg, model) 104 | val_dataloader = make_data_loader(cfg, is_train=False) 105 | # word2idx, idx2word = load_vocab('datasets/en_vocab', 40) 106 | word2idx, idx2word = load_vocab(cfg.dst_vocab, cfg.dst_vocab_size) 107 | cnt = 1 108 | with open('pred.txt', 'w') as f: 109 | for batch in tqdm(val_dataloader): 110 | all_hyp, all_scores = translator.translate_batch(batch[0]) 111 | for idx_seqs in all_hyp: 112 | for idx_seq in idx_seqs: 113 | idx_seq = [x for x in idx_seq if x != 3] 114 | pred_line = '{}.png, "'.format(cnt) + ''.join([idx2word[idx] for idx in idx_seq]) + '"' 115 | f.write(pred_line + '\n') 116 | cnt += 1 117 | 118 | def get_parameter_number(net): 119 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 120 | return 'Trainable: {} M'.format(trainable_num/1000000) 121 | 122 | def test_one(cfg, args): 123 | # model_path = 'models/baseline_20_epoch_wh_44/model_epoch_10.pth' 124 | # prepare model 125 | model = build_CDistNet(cfg) 126 | # en = get_parameter_number(model.transformer.encoder) 127 | # de = get_parameter_number(model.transformer.decoder) 128 | # print('encoder:{}\ndecoder:{}\n'.format(en,de)) 129 | model_path = 'models/new_baseline_dssnetv3_3_32*128_tps_resnet45_epoch_6/epoch9_best_acc.pth' 130 | model.load_state_dict(torch.load(model_path)) 131 | device = torch.device(cfg.test.device) 132 | model.to(device) 133 | model.eval() 134 | translator = Translator(cfg, model) 135 | word2idx, idx2word = load_vocab(cfg.dst_vocab, cfg.dst_vocab_size) 136 | 137 | # if args['use_origin'] is True: 138 | img = origin_process_img(cfg, 'test/1.jpg') 139 | # else: 140 | # img = preprocess_image(args['img_path']) 141 | 142 | cnt = 0 143 | res = [] 144 | all_hyp, all_scores = translator.translate_batch(img) 145 | # print(all_hyp, all_scores) 146 | for idx_seqs in all_hyp: 147 | for idx_seq in idx_seqs: 148 | idx_seq = [x for x in idx_seq if x != 3] 149 | pred_line = 'Results{}:"'.format(cnt) + ''.join([idx2word[idx] for idx in idx_seq]) + '"' 150 | res.append('Vocab Prob:{}\nTotal Score:{}\n{}\n\n'\ 151 | .format(all_hyp[0][cnt],all_scores[0][cnt],pred_line)) 152 | cnt = cnt + 1 153 | print(res) 154 | return res 155 | 156 | def test_demo(args): 157 | print(args['config_path']) 158 | print(type(args['config_path'])) 159 | cfg = Config.fromfile(args['config_path']) 160 | return test_one(cfg, args) 161 | 162 | def main(): 163 | args = parse_args() 164 | print(args.config) 165 | print(type(args.config)) 166 | cfg = Config.fromfile(args.config) 167 | if args.test_one is True: 168 | test_one(cfg, args) 169 | else: 170 | test(cfg) 171 | 172 | 173 | if __name__ == '__main__': 174 | # test_demo() 175 | main() 176 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import logging 5 | # 
os.environ['CUDA_VISIBLE_DEVICES']="7" 6 | from mmcv import Config 7 | from thop import profile 8 | import torch 9 | from torch import optim 10 | import torch.nn as nn 11 | import torch.distributed as dist 12 | 13 | from cdistnet.model.model import build_CDistNet 14 | from cdistnet.data.data import make_data_loader, MyConcatDataset 15 | # from cdistnet.data.hdf5loader import make_data_loader 16 | from cdistnet.engine.trainer import do_train 17 | from cdistnet.optim.optim import ScheduledOptim, WarmupOptim 18 | from cdistnet.utils.tensorboardx import TensorboardLogger 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Train CDistNet') 22 | parser.add_argument('--config', type=str, default = 'configs/config.py',help='train config file path') 23 | parser.add_argument('--local_rank', default=-1, type=int,help='node rank for distributed training') 24 | args = parser.parse_args() 25 | cfg = Config.fromfile(args.config) 26 | # assert not os.path.exists(cfg.train.model_dir), "{} already exists".format(cfg.train.model_dir) 27 | if not os.path.exists(cfg.train.model_dir): 28 | os.makedirs(cfg.train.model_dir) 29 | shutil.copy(args.config, cfg.train.model_dir) 30 | return cfg,args 31 | 32 | def get_parameter_number(net): 33 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 34 | return 'Trainable: {} M'.format(trainable_num/1000000) 35 | 36 | def get_flop_param(net): 37 | image=torch.randn(1, 3, 32, 96) 38 | tgt = torch.rand(1,180).long() 39 | flops, params = profile(net, inputs=(image, tgt)) 40 | return 'Param: {} M, \n FLOPS: {} G'.format(params/1000000,flops/1000000000) 41 | 42 | def getlogger(mode_dir): 43 | logger = logging.getLogger('CDistNet') 44 | logger.setLevel(level=logging.INFO) 45 | handler = logging.FileHandler(os.path.join(mode_dir, 'log.txt' )) 46 | handler.setLevel(level=logging.INFO) 47 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 48 | handler.setFormatter(formatter) 49 | console = logging.StreamHandler() 50 | console.setLevel(level=logging.INFO) 51 | console.setFormatter(formatter) 52 | 53 | logger.addHandler(handler) 54 | logger.addHandler(console) 55 | return logger 56 | def train(cfg,args): 57 | # init dist_train 58 | if cfg.train_method=='dist': 59 | dist.init_process_group(backend='nccl') 60 | torch.cuda.set_device(args.local_rank) 61 | 62 | model = build_CDistNet(cfg) 63 | train_dataloader = make_data_loader(cfg, is_train=True) 64 | val_dataloader = [] 65 | for val_gt_file in cfg.val.gt_file: 66 | val_dataloader.append(make_data_loader(cfg, is_train=False,val_gt_file=val_gt_file)) 67 | n_current_steps = 0 68 | current_epoch = 0 69 | if cfg.train.model: 70 | model.load_state_dict(torch.load(cfg.train.model)) 71 | current_epoch = cfg.train.current_epoch 72 | n_current_steps = current_epoch * len(train_dataloader) 73 | 74 | parse_nums = get_parameter_number(model) 75 | # parse_nums2 = get_flop_param(model) 76 | if cfg.train_method=='dist': 77 | model.cuda(args.local_rank) 78 | # 同步bn 可以验证准确性 79 | # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model,torch.distributed.new_group(ranks=[0])) 80 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],find_unused_parameters=True) 81 | else: 82 | model = nn.DataParallel(model) 83 | # model = nn.DataParallel(model,device_ids=[cfg.train.gpu_device_ids]) 84 | device = torch.device(cfg.train.device) 85 | model = model.to(device) 86 | # print("device_count :{}".format(torch.cuda.device_count())) 87 | 
logger = getlogger(cfg.train.model_dir) 88 | if cfg.optim == 'warmup': 89 | logger.info("use warmup\n") 90 | optimizer = WarmupOptim( 91 | optim.Adam( 92 | filter(lambda x: x.requires_grad, model.parameters()), 93 | betas=(0.9, 0.98), 94 | eps=1e-09, 95 | ), 96 | cfg.hidden_units, cfg.train.learning_rate_warmup_steps, n_current_steps,current_epoch) 97 | else: 98 | optimizer = ScheduledOptim( 99 | optim.Adam( 100 | filter(lambda x: x.requires_grad, model.parameters()), 101 | betas=(0.9, 0.98), 102 | eps=1e-09, 103 | ), 104 | cfg.hidden_units, cfg.train.learning_rate_warmup_steps, n_current_steps) 105 | # optimizer = optim.Adam(model.parameters(), lr=cfg.train.learning_rate) 106 | 107 | 108 | meter = TensorboardLogger(cfg.train.model_dir) 109 | logger.info("model parameter:-------\n{}".format(parse_nums)) 110 | logger.info("model struct:-------\n{}".format(model)) 111 | # logger.info("model compute:-------\n{}".format(parse_nums2)): 112 | # logger.info("step1:pos: {}, feat: {}, sem: {}\n".format(cfg.step1[0], cfg.step1[1], cfg.step1[2])) 113 | # logger.info("step2:feat_sem: {},pos_feat: {},pos_sem: {}\n".format(cfg.step2[0], cfg.step2[1], cfg.step2[2])) 114 | do_train( 115 | model=model, 116 | train_dataloader=train_dataloader, 117 | val_dataloader=val_dataloader, 118 | optimizer=optimizer, 119 | device = args.local_rank if cfg.train_method=='dist' else device, 120 | num_epochs=cfg.train.num_epochs, 121 | current_epoch=current_epoch, 122 | logger=logger, 123 | meter=meter, 124 | save_iter=cfg.train.save_iter, 125 | display_iter=cfg.train.display_iter, 126 | tfboard_iter=cfg.train.tfboard_iter, 127 | eval_iter=cfg.train.eval_iter, 128 | model_dir=cfg.train.model_dir, 129 | label_smoothing=cfg.train.label_smoothing, 130 | grads_clip=cfg.train.grads_clip, 131 | cfg=cfg, 132 | ) 133 | 134 | 135 | def main(): 136 | cfg,args = parse_args() 137 | train(cfg,args) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /utils/Evaluation_TextRecog/constrain_select.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import editdistance 4 | 5 | datatype = sys.argv[1] 6 | lexdir = '../Dataset/TextRecog' 7 | lexdir = os.path.join(lexdir, datatype) 8 | 9 | 10 | ### for icdar2003, svt, iiit5k, they all have lexicon_50 11 | gt = open(sys.argv[2], 'r') 12 | gt_lines = gt.readlines() 13 | gt.close() 14 | 15 | sub = open(sys.argv[3], 'r') 16 | sub_lines = sub.readlines() 17 | sub.close() 18 | 19 | imgname = open(sys.argv[4], 'r') 20 | img_lines = imgname.readlines() 21 | imgname.close() 22 | 23 | sub_50 = open(((sys.argv[3])[:-4] + '_50.txt'), 'w') 24 | 25 | for i in range(len(gt_lines)):#for each gt 26 | img = img_lines[i].strip() 27 | gt = (gt_lines[i].strip()).split('\"')[1] 28 | sub = (sub_lines[i].strip()).split('\"')[1] 29 | 30 | lex_file = open(os.path.join(lexdir, 'lexicon_50', 'lexicon_' + img + '_' + gt + '.txt'), 'r') 31 | lex = lex_file.readlines() 32 | lex_file.close() 33 | 34 | min_dis = 10000 35 | min_word = sub 36 | for word in lex: 37 | word = word.strip() 38 | word = word.lower() 39 | dis = editdistance.eval(sub, word) 40 | if min_dis > dis: 41 | min_word = word 42 | min_dis = dis 43 | sub_50.write('word_' + str(i+1) + '.png, \"' + str(min_word) + '\"\n') 44 | sub_50.close() 45 | 46 | 47 | 48 | ### for icdar2003, it has lexicon_Full 49 | if datatype == 'icdar2003': 50 | gt = open(sys.argv[2], 'r') 51 | gt_lines = gt.readlines() 52 | 
gt.close()
53 | 
54 | sub = open(sys.argv[3], 'r')
55 | sub_lines = sub.readlines()
56 | sub.close()
57 | 
58 | imgname = open(sys.argv[4], 'r')
59 | img_lines = imgname.readlines()
60 | imgname.close()
61 | 
62 | sub_Full = open(((sys.argv[3])[:-4] + '_Full.txt'), 'w')
63 | for i in range(len(gt_lines)):#for each gt
64 | img = img_lines[i].strip()
65 | gt = (gt_lines[i].strip()).split('\"')[1]
66 | sub = (sub_lines[i].strip()).split('\"')[1]
67 | 
68 | lex_file = open(os.path.join(lexdir, 'lexicon_Full.txt'), 'r')
69 | lex = lex_file.readlines()
70 | lex_file.close()
71 | 
72 | min_dis = 10000
73 | min_word = sub
74 | for word in lex:
75 | word = word.strip()
76 | word = word.lower()
77 | dis = editdistance.eval(sub, word)
78 | if min_dis > dis:
79 | min_word = word
80 | min_dis = dis
81 | sub_Full.write('word_' + str(i+1) + '.png, \"' + str(min_word) + '\"\n')
82 | sub_Full.close()
83 | 
84 | elif datatype == 'iiit5k':## for iiit5k, it has lexicon_1k
85 | gt = open(sys.argv[2], 'r')
86 | gt_lines = gt.readlines()
87 | gt.close()
88 | 
89 | sub = open(sys.argv[3], 'r')
90 | sub_lines = sub.readlines()
91 | sub.close()
92 | 
93 | imgname = open(sys.argv[4], 'r')
94 | img_lines = imgname.readlines()
95 | imgname.close()
96 | 
97 | sub_1k = open(((sys.argv[3])[:-4] + '_1k.txt'), 'w')
98 | for i in range(len(gt_lines)):#for each gt
99 | img = img_lines[i].strip()
100 | gt = (gt_lines[i].strip()).split('\"')[1]
101 | sub = (sub_lines[i].strip()).split('\"')[1]
102 | 
103 | lex_file = open(os.path.join(lexdir, 'lexicon_1k', 'lexicon_' + img + '_' + gt + '.txt'), 'r')
104 | lex = lex_file.readlines()
105 | lex_file.close()
106 | 
107 | min_dis = 10000
108 | min_word = sub
109 | for word in lex:
110 | word = word.strip()
111 | dis = editdistance.eval(sub, word)
112 | if min_dis > dis:
113 | min_word = word
114 | min_dis = dis
115 | sub_1k.write('word_' + str(i+1) + '.png, \"' + str(min_word) + '\"\n')
116 | sub_1k.close()
117 | 
--------------------------------------------------------------------------------
/utils/Evaluation_TextRecog/readme.txt:
--------------------------------------------------------------------------------
1 | INSTRUCTIONS FOR THE STANDALONE SCRIPTS
2 | Requirements:
3 | - Python version 2.7.
4 | - Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions.
5 | 
6 | Procedure:
7 | Download the ZIP file for the requested script and unzip it to a directory.
8 | 
9 | Open a terminal in the directory and run the command:
10 | python script.py -g=gt.txt -s=submit.txt
11 | 
12 | If you have already installed all the required modules, then you will see the method’s results or an error message if the submitted file is not correct.
13 | 
14 | Parameters:
15 | -g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same ZIP file, named 'gt.zip', 'gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task.
16 | -s: Path of your method's results file.
17 | 
18 | Optional parameters:
19 | -o: Path to a directory where to copy the file ‘results.zip’ that contains per-sample results.
20 | -p: JSON string parameters to override the script default parameters. The parameters that can be overridden are inside the function 'default_evaluation_params' located at the beginning of the evaluation script.
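For the text recognition tasks evaluated here, each line of the Ground Truth and submission files pairs a sample name with a quoted transcription, for example (matching the SAMPLE_NAME_2_ID pattern in script.py):
word_1.png, "transcription"
word_2.png, "another word"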
21 | 
22 | Example: python script.py -g=gt.txt -s=submit.zip -o=./ -p='{\"IOU_CONSTRAINT\":0.8}'
23 | 
--------------------------------------------------------------------------------
/utils/Evaluation_TextRecog/readme_sff:
--------------------------------------------------------------------------------
1 | Installing editdistance offline:
2 | 1. Download the package from the official page: https://pypi.python.org/pypi/editdistance/0.2
3 | 2. pip install editdistance-0.2.tar.gz
4 | 
--------------------------------------------------------------------------------
/utils/Evaluation_TextRecog/rrc_evaluation_funcs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | #encoding: UTF-8
3 | import json
4 | import sys;sys.path.append('./')
5 | import zipfile
6 | import re
7 | import sys
8 | import os
9 | import codecs
10 | import importlib
11 | from StringIO import StringIO
12 | 
13 | def print_help():
14 | sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0])
15 | sys.exit(2)
16 | 
17 | 
18 | def load_zip_file_keys(file,fileNameRegExp=''):
19 | """
20 | Returns an array with the entries of the ZIP file that match with the regular expression.
21 | The keys are the file names, or the capturing group defined in fileNameRegExp.
22 | """
23 | try:
24 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
25 | except :
26 | raise Exception('Error loading the ZIP archive.')
27 | 
28 | pairs = []
29 | 
30 | for name in archive.namelist():
31 | addFile = True
32 | keyName = name
33 | if fileNameRegExp!="":
34 | m = re.match(fileNameRegExp,name)
35 | if m == None:
36 | addFile = False
37 | else:
38 | if len(m.groups())>0:
39 | keyName = m.group(1)
40 | 
41 | if addFile:
42 | pairs.append( keyName )
43 | 
44 | return pairs
45 | 
46 | 
47 | def load_zip_file(file,fileNameRegExp='',allEntries=False):
48 | """
49 | Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
50 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 51 | allEntries validates that all entries in the ZIP file pass the fileNameRegExp 52 | """ 53 | try: 54 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 55 | except : 56 | raise Exception('Error loading the ZIP archive') 57 | 58 | pairs = [] 59 | for name in archive.namelist(): 60 | addFile = True 61 | keyName = name 62 | if fileNameRegExp!="": 63 | m = re.match(fileNameRegExp,name) 64 | if m == None: 65 | addFile = False 66 | else: 67 | if len(m.groups())>0: 68 | keyName = m.group(1) 69 | 70 | if addFile: 71 | pairs.append( [ keyName , archive.read(name)] ) 72 | else: 73 | if allEntries: 74 | raise Exception('ZIP entry not valid: %s' %name) 75 | 76 | return dict(pairs) 77 | 78 | def decode_utf8(raw): 79 | """ 80 | Returns a Unicode object on success, or None on failure 81 | """ 82 | try: 83 | raw = codecs.decode(raw,'utf-8', 'replace') 84 | #extracts BOM if exists 85 | raw = raw.encode('utf8') 86 | if raw.startswith(codecs.BOM_UTF8): 87 | raw = raw.replace(codecs.BOM_UTF8, '', 1) 88 | return raw.decode('utf-8') 89 | except: 90 | return None 91 | 92 | def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 93 | """ 94 | This function validates that all lines of the file calling the Line validation function for each line 95 | """ 96 | utf8File = decode_utf8(file_contents) 97 | if (utf8File is None) : 98 | raise Exception("The file %s is not UTF-8" %fileName) 99 | 100 | lines = utf8File.split( "\r\n" if CRLF else "\n" ) 101 | for line in lines: 102 | line = line.replace("\r","").replace("\n","") 103 | if(line != ""): 104 | try: 105 | validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 106 | except Exception as e: 107 | raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) 108 | 109 | 110 | 111 | def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): 112 | """ 113 | Validate the format of the line. If the line is not valid an exception will be raised. 114 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 115 | Posible values are: 116 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 117 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 118 | """ 119 | get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 120 | 121 | 122 | def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 123 | """ 124 | Validate the format of the line. If the line is not valid an exception will be raised. 125 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 126 | Posible values are: 127 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 128 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 129 | Returns values from a textline. 
Points , [Confidences], [Transcriptions]
130 | """
131 | confidence = 0.0
132 | transcription = "";
133 | points = []
134 | 
135 | numPoints = 4;
136 | 
137 | if LTRB:
138 | 
139 | numPoints = 4;
140 | 
141 | if withTranscription and withConfidence:
142 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
143 | if m == None :
144 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
145 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
146 | elif withConfidence:
147 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
148 | if m == None :
149 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
150 | elif withTranscription:
151 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
152 | if m == None :
153 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
154 | else:
155 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
156 | if m == None :
157 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
158 | 
159 | xmin = int(m.group(1))
160 | ymin = int(m.group(2))
161 | xmax = int(m.group(3))
162 | ymax = int(m.group(4))
163 | if (xmax < xmin):
164 | raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
165 | if (ymax < ymin):
166 | raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))
167 | 
168 | points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
169 | 
170 | if (imWidth>0 and imHeight>0):
171 | validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
172 | validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
173 | 
174 | else:
175 | 
176 | numPoints = 8;
177 | 
178 | if withTranscription and withConfidence:
179 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
180 | if m == None :
181 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
182 | elif withConfidence:
183 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
184 | if m == None :
185 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
186 | elif withTranscription:
187 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
188 | if m == None :
189 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
190 | else:
191 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
192 | if m == None :
193 | raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") 194 | 195 | points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 196 | 197 | validate_clockwise_points(points) 198 | 199 | if (imWidth>0 and imHeight>0): 200 | validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); 201 | validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); 202 | validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); 203 | validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); 204 | 205 | 206 | if withConfidence: 207 | try: 208 | confidence = float(m.group(numPoints+1)) 209 | except ValueError: 210 | raise Exception("Confidence value must be a float") 211 | 212 | if withTranscription: 213 | posTranscription = numPoints + (2 if withConfidence else 1) 214 | transcription = m.group(posTranscription) 215 | m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) 216 | if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters 217 | transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") 218 | 219 | return points,confidence,transcription 220 | 221 | 222 | def validate_point_inside_bounds(x,y,imWidth,imHeight): 223 | if(x<0 or x>imWidth): 224 | raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) 225 | if(y<0 or y>imHeight): 226 | raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) 227 | 228 | def validate_clockwise_points(points): 229 | """ 230 | Validates that the points that the 4 points that dlimite a polygon are in clockwise order. 231 | """ 232 | 233 | if len(points) != 8: 234 | raise Exception("Points list not valid." + str(len(points))) 235 | 236 | point = [ 237 | [int(points[0]) , int(points[1])], 238 | [int(points[2]) , int(points[3])], 239 | [int(points[4]) , int(points[5])], 240 | [int(points[6]) , int(points[7])] 241 | ] 242 | edge = [ 243 | ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), 244 | ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), 245 | ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), 246 | ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) 247 | ] 248 | 249 | summatory = edge[0] + edge[1] + edge[2] + edge[3]; 250 | if summatory>0: 251 | raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") 252 | 253 | def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): 254 | """ 255 | Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats: 256 | xmin,ymin,xmax,ymax,[confidence],[transcription] 257 | x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] 258 | """ 259 | pointsList = [] 260 | transcriptionsList = [] 261 | confidencesList = [] 262 | 263 | lines = content.split( "\r\n" if CRLF else "\n" ) 264 | for line in lines: 265 | line = line.replace("\r","").replace("\n","") 266 | if(line != "") : 267 | points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); 268 | pointsList.append(points) 269 | transcriptionsList.append(transcription) 270 | confidencesList.append(confidence) 271 | 272 | if withConfidence and len(confidencesList)>0 and sort_by_confidences: 273 | import numpy as np 274 | sorted_ind = np.argsort(-np.array(confidencesList)) 275 | confidencesList = [confidencesList[i] for i in sorted_ind] 276 | pointsList = [pointsList[i] for i in sorted_ind] 277 | transcriptionsList = [transcriptionsList[i] for i in sorted_ind] 278 | 279 | return pointsList,confidencesList,transcriptionsList 280 | 281 | def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): 282 | """ 283 | This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 284 | Params: 285 | p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 286 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 287 | validate_data_fn: points to a method that validates the corrct format of the submission 288 | evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results 289 | """ 290 | 291 | if (p == None): 292 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 293 | if(len(sys.argv)<3): 294 | print_help() 295 | 296 | evalParams = default_evaluation_params_fn() 297 | if 'p' in p.keys(): 298 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 299 | 300 | resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} 301 | try: 302 | validate_data_fn(p['g'], p['s'], evalParams) 303 | evalData = evaluate_method_fn(p['g'], p['s'], evalParams) 304 | resDict.update(evalData) 305 | 306 | except Exception, e: 307 | resDict['Message']= str(e) 308 | resDict['calculated']=False 309 | 310 | if 'o' in p: 311 | if not os.path.exists(p['o']): 312 | os.makedirs(p['o']) 313 | 314 | resultsOutputname = p['o'] + '/results.zip' 315 | outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) 316 | 317 | del resDict['per_sample'] 318 | if 'output_items' in resDict.keys(): 319 | del resDict['output_items'] 320 | 321 | outZip.writestr('method.json',json.dumps(resDict)) 322 | 323 | if not resDict['calculated']: 324 | if show_result: 325 | sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') 326 | if 'o' in p: 327 | outZip.close() 328 | return resDict 329 | 330 | if 'o' in p: 331 | if per_sample == True: 332 | for k,v in evalData['per_sample'].iteritems(): 333 | outZip.writestr( k + '.json',json.dumps(v)) 334 | 335 | if 'output_items' in evalData.keys(): 336 | for k, v in evalData['output_items'].iteritems(): 337 | outZip.writestr( k,v) 338 | 339 | outZip.close() 340 | 341 | if show_result: 342 | sys.stdout.write("Calculated!") 343 | sys.stdout.write(json.dumps(resDict['method'])) 344 | 345 | return resDict 346 | 347 | 348 | def 
main_validation(default_evaluation_params_fn,validate_data_fn): 349 | """ 350 | This process validates a method 351 | Params: 352 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 353 | validate_data_fn: points to a method that validates the corrct format of the submission 354 | """ 355 | try: 356 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 357 | evalParams = default_evaluation_params_fn() 358 | if 'p' in p.keys(): 359 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 360 | 361 | validate_data_fn(p['g'], p['s'], evalParams) 362 | print 'SUCCESS' 363 | sys.exit(0) 364 | except Exception as e: 365 | print str(e) 366 | sys.exit(101) -------------------------------------------------------------------------------- /utils/Evaluation_TextRecog/rrc_evaluation_funcs.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/utils/Evaluation_TextRecog/rrc_evaluation_funcs.pyc -------------------------------------------------------------------------------- /utils/Evaluation_TextRecog/script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import re 4 | from StringIO import StringIO 5 | import rrc_evaluation_funcs 6 | import importlib 7 | import pdb 8 | def evaluation_imports(): 9 | """ 10 | evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. 11 | """ 12 | return { 13 | 'xlsxwriter':'xlsxwriter', 14 | 'editdistance':'editdistance' 15 | } 16 | 17 | def default_evaluation_params(): 18 | """ 19 | default_evaluation_params: Default parameters to use for the validation and evaluation. 20 | """ 21 | return { 22 | 'SAMPLE_NAME_2_ID':'(?:word_)?([0-9]+).png', 23 | 'CRLF':False, 24 | 'DOUBLE_QUOTES':True 25 | } 26 | 27 | def validate_data(gtFilePath, submFilePath,evaluationParams): 28 | """ 29 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents). 30 | Validates also that there are no missing files in the folder. 
31 | If some error detected, the method raises the error 32 | """ 33 | 34 | gtFile = rrc_evaluation_funcs.decode_utf8(open(gtFilePath,'rb').read()) 35 | if (gtFile is None) : 36 | raise Exception("The GT file is not UTF-8") 37 | 38 | gtLines = gtFile.split( "\r\n" if evaluationParams['CRLF'] else "\n" ) 39 | ids = {} 40 | for line in gtLines: 41 | line = line.replace("\r","").replace("\n","") 42 | if(line != ""): 43 | if (evaluationParams['DOUBLE_QUOTES']): 44 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?\"(.*)\"\s*\t?$',line) 45 | else: 46 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?(.*)$',line) 47 | 48 | if m == None : 49 | if (evaluationParams['DOUBLE_QUOTES']): 50 | raise Exception(("Line in GT not valid.Found: %s should be: %s"%(line,evaluationParams['SAMPLE_NAME_2_ID'] + ',transcription' )).encode('utf-8', 'replace')) 51 | else: 52 | raise Exception(("Line in GT not valid.Found: %s should be: %s"%(line,evaluationParams['SAMPLE_NAME_2_ID'] + ',"transcription"' )).encode('utf-8', 'replace')) 53 | ids[m.group(1)] = {'gt':m.group(2),'det':''} 54 | 55 | submFile = rrc_evaluation_funcs.decode_utf8(open(submFilePath,'rb').read()) 56 | if (submFile is None) : 57 | raise Exception("The Det file is not UTF-8") 58 | 59 | submLines = submFile.split("\r\n" if evaluationParams['CRLF'] else "\n") 60 | for line in submLines: 61 | line = line.replace("\r","").replace("\n","") 62 | if(line != ""): 63 | if (evaluationParams['DOUBLE_QUOTES']): 64 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?\"(.*)\"\s*\t?$',line) 65 | else: 66 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?(.*)$',line) 67 | 68 | if m == None : 69 | if (evaluationParams['DOUBLE_QUOTES']): 70 | raise Exception(("Line in results not valid.Found: %s should be: %s"%(line,evaluationParams['SAMPLE_NAME_2_ID'] + ',transcription' )).encode('utf-8', 'replace')) 71 | else: 72 | raise Exception(("Line in results not valid.Found: %s should be: %s"%(line,evaluationParams['SAMPLE_NAME_2_ID'] + ',"transcription"' )).encode('utf-8', 'replace')) 73 | try: 74 | ids[m.group(1)]['det'] = m.group(2) 75 | except Exception as e: 76 | raise Exception(("Line in results not valid. Line: %s Sample item not valid: %s" %(line,m.group(1))).encode('utf-8', 'replace')) 77 | 78 | def evaluate_method(gtFilePath, submFilePath,evaluationParams): 79 | """ 80 | Method evaluate_method: evaluate method and returns the results 81 | Results. Dictionary with the following values: 82 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } 83 | - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } 84 | """ 85 | 86 | for module,alias in evaluation_imports().iteritems(): 87 | globals()[alias] = importlib.import_module(module) 88 | 89 | gtFile = rrc_evaluation_funcs.decode_utf8(open(gtFilePath,'rb').read()) 90 | gtLines = gtFile.split("\r\n" if evaluationParams['CRLF'] else "\n")#'CRLF':False, 91 | ids = {} 92 | for line in gtLines: 93 | line = line.replace("\r","").replace("\n","") 94 | if(line != ""): 95 | if (evaluationParams['DOUBLE_QUOTES']):#'DOUBLE_QUOTES':True 96 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?\"(.+)\"$',line) 97 | ids[m.group(1)] = {"gt" : m.group(2).replace("\\\\", "\\").replace("\\\"", "\""),"det":""} 98 | else: 99 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?(.+)$',line) 100 | ids[m.group(1)] = {"gt" :m.group(2),"det":""} 101 | 102 | 103 | totalDistance = 0.0 104 | totalLength = 0.0 105 | totalDistanceUpper = 0.0 106 | totalLengthUpper = 0.0 107 | numWords = 0 108 | correctWords = 0.0 109 | correctWordsUpper = 0.0 110 | 111 | perSampleMetrics = {} 112 | 113 | submFile = rrc_evaluation_funcs.decode_utf8(open(submFilePath,'rb').read()) 114 | if (submFile is None) : 115 | raise Exception("The file is not UTF-8") 116 | 117 | xls_output = StringIO() 118 | workbook = xlsxwriter.Workbook(xls_output) 119 | worksheet = workbook.add_worksheet() 120 | worksheet.write(1, 1 , "sample") 121 | worksheet.write(1, 2 , "gt") 122 | worksheet.write(1, 3 , "E.D.") 123 | worksheet.write(1, 4 , "normalized") 124 | worksheet.write(1, 5 , "E.D. upper") 125 | worksheet.write(1, 6 , "normalized upper") 126 | 127 | submLines = submFile.split("\r\n" if evaluationParams['CRLF'] else "\n") 128 | for line in submLines: 129 | line = line.replace("\r","").replace("\n","") 130 | if(line != ""): 131 | 132 | numWords = numWords + 1 133 | 134 | if (evaluationParams['DOUBLE_QUOTES']): 135 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?\"(.*)\"\s*\t?$',line) 136 | detected = m.group(2).replace("\\\\", "\\").replace("\\\"", "\"") 137 | else: 138 | m = re.match(r'^' + evaluationParams['SAMPLE_NAME_2_ID'] + ',\s?(.*)$',line) 139 | detected = m.group(2) 140 | 141 | ids[m.group(1)]['det'] = detected 142 | 143 | row = 1 144 | for k,v in ids.iteritems(): 145 | 146 | gt = v['gt'] 147 | detected = v['det'] 148 | 149 | if gt == detected : 150 | correctWords = correctWords + 1 151 | 152 | if gt.upper() == detected.upper() : 153 | correctWordsUpper = correctWordsUpper + 1 154 | 155 | ''' 156 | distance = editdistance.eval(gt, detected) 157 | length = float(distance) / len (gt ) 158 | 159 | distance_up = editdistance.eval(gt.upper(), detected.upper()) 160 | length_up = float(distance_up) / len (gt ) 161 | 162 | totalDistance += distance 163 | totalLength += length 164 | 165 | totalDistanceUpper += distance_up 166 | totalLengthUpper += length_up 167 | ''' 168 | 169 | distance = editdistance.eval(gt, detected) 170 | length = len (gt ) 171 | 172 | distance_up = editdistance.eval(gt.upper(), detected.upper()) 173 | length_up = len (gt ) 174 | 175 | totalDistance += distance 176 | totalLength += length 177 | 178 | totalDistanceUpper += distance_up 179 | totalLengthUpper += length_up 180 | 181 | 182 | perSampleMetrics[k] = { 183 | 'gt':gt, 184 | 'det':detected, 185 | 'edist':distance, 186 | 'norm':length , 187 | 'edistUp':distance_up, 188 | 'normUp':length_up 189 | } 190 | row = row + 1 191 | worksheet.write(row, 1, k) 192 | worksheet.write(row, 2, gt) 193 | 
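# Columns written for each sample row: 1 = sample id, 2 = gt, 3 = detection, 4 = edit distance,
# 5 = len(gt), 6 = uppercased edit distance, 7 = len(gt) again. Note that the header row written
# further up only labels columns 1-6, so those headers sit one column to the left of these values.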
worksheet.write(row, 3, detected) 194 | worksheet.write(row, 4, distance) 195 | worksheet.write(row, 5, length) 196 | worksheet.write(row, 6, distance_up) 197 | worksheet.write(row, 7, length_up) 198 | 199 | methodMetrics = { 200 | 'totalWords':len(ids), 201 | 'detWords':numWords, 202 | 'crwN':correctWords, 203 | 'crwupN':correctWordsUpper, 204 | 'ted':totalDistance, 205 | 'tedL': 1.0 -(float(totalDistance) / totalLength), 206 | 'crw':0 if numWords==0 else correctWords/numWords, 207 | 'crwN':correctWords, 208 | 'tedup':totalDistanceUpper, 209 | 'tedupL': 1.0 - (float(totalDistanceUpper) / totalLengthUpper), 210 | 'crwup':0 if numWords==0 else correctWordsUpper/numWords, 211 | 'crwupN':correctWordsUpper 212 | } 213 | 214 | workbook.close() 215 | output_items = {'samples.xlsx':xls_output.getvalue()} 216 | xls_output.close() 217 | 218 | resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics,'output_items':output_items} 219 | return resDict; 220 | 221 | if __name__=='__main__': 222 | rrc_evaluation_funcs.main_evaluation(None,default_evaluation_params,validate_data,evaluate_method) 223 | -------------------------------------------------------------------------------- /utils/fig2_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/utils/fig2_00.png -------------------------------------------------------------------------------- /utils/fig5_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simplify23/CDistNet/21a5f13a665ed6afd5aeb6f552dc49dd61405a2d/utils/fig5_00.png -------------------------------------------------------------------------------- /utils/label_proc.py: -------------------------------------------------------------------------------- 1 | import os 2 | def process_label(add_path,gt_file): 3 | gt_list = [] 4 | with open(gt_file, 'r', encoding='UTF-8') as f: 5 | all = f.readlines() 6 | for each in all: 7 | each = add_path+each 8 | each = each.strip()+'\n' 9 | gt_list.append(each) 10 | print(gt_list) 11 | with open(gt_file, 'w', encoding='UTF-8') as f: 12 | f.writelines(gt_list) 13 | 14 | def write_txt(gt_list,gt_file): 15 | with open(gt_file, 'w', encoding='UTF-8') as f: 16 | f.writelines(gt_list) 17 | 18 | def strip_label(gt_file): 19 | gt_list = [] 20 | with open(gt_file, 'r', encoding='UTF-8') as f: 21 | all = f.readlines() 22 | for each in all: 23 | each = each.strip() 24 | gt_list.append(each+'\n') 25 | with open(gt_file, 'w', encoding='UTF-8') as f: 26 | f.writelines(gt_list) 27 | 28 | def dict_label(gt_file): 29 | val_label = '../train_data/ppdataset/train/labelval.txt' 30 | dict_path = '../ppocr/utils/ppocr_keys_v2.txt' 31 | # val_label = 'labelval.txt' 32 | gt_list = [] 33 | max_len = 0 34 | dict_char = {} 35 | set_char = set() 36 | set_val = set() 37 | with open(gt_file, 'r', encoding='UTF-8') as f: 38 | all= f.readlines() 39 | for each in all: 40 | origin_label = each 41 | each = each.strip().split('\t') 42 | text = each[1] 43 | for i in text: 44 | set_char.add(i+'\n') 45 | value = dict_char.setdefault(i, 0) 46 | if (value < 60 and value % 20 == 0 ) or value % 500 == 0: 47 | set_val.add(origin_label) 48 | value +=1 49 | dict_char.update({i:value}) 50 | max_len = len(text) if max_len= self.dst_w: 51 | break 52 | 53 | j = 0 54 | while 1: 55 | if self.dst_h <= j < self.dst_h + self.grid_size - 1: 56 | j = self.dst_h - 1 57 | elif j >= 
self.dst_h: 58 | break 59 | 60 | sw = 0 61 | swp = np.zeros(2, dtype=np.float32) 62 | swq = np.zeros(2, dtype=np.float32) 63 | new_pt = np.zeros(2, dtype=np.float32) 64 | cur_pt = np.array([i, j], dtype=np.float32) 65 | 66 | k = 0 67 | for k in range(self.pt_count): 68 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 69 | break 70 | 71 | w[k] = 1. / ( 72 | (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) + 73 | (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1])) 74 | 75 | sw += w[k] 76 | swp = swp + w[k] * np.array(self.dst_pts[k]) 77 | swq = swq + w[k] * np.array(self.src_pts[k]) 78 | 79 | if k == self.pt_count - 1: 80 | pstar = 1 / sw * swp 81 | qstar = 1 / sw * swq 82 | 83 | miu_s = 0 84 | for k in range(self.pt_count): 85 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 86 | continue 87 | pt_i = self.dst_pts[k] - pstar 88 | miu_s += w[k] * np.sum(pt_i * pt_i) 89 | 90 | cur_pt -= pstar 91 | cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) 92 | 93 | for k in range(self.pt_count): 94 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 95 | continue 96 | 97 | pt_i = self.dst_pts[k] - pstar 98 | pt_j = np.array([-pt_i[1], pt_i[0]]) 99 | 100 | tmp_pt = np.zeros(2, dtype=np.float32) 101 | tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \ 102 | np.sum(pt_j * cur_pt) * self.src_pts[k][1] 103 | tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \ 104 | np.sum(pt_j * cur_pt_j) * self.src_pts[k][1] 105 | tmp_pt *= (w[k] / miu_s) 106 | new_pt += tmp_pt 107 | 108 | new_pt += qstar 109 | else: 110 | new_pt = self.src_pts[k] 111 | 112 | self.rdx[j, i] = new_pt[0] - i 113 | self.rdy[j, i] = new_pt[1] - j 114 | 115 | j += self.grid_size 116 | i += self.grid_size 117 | 118 | def gen_img(self): 119 | src_h, src_w = self.src.shape[:2] 120 | dst = np.zeros_like(self.src, dtype=np.float32) 121 | 122 | for i in np.arange(0, self.dst_h, self.grid_size): 123 | for j in np.arange(0, self.dst_w, self.grid_size): 124 | ni = i + self.grid_size 125 | nj = j + self.grid_size 126 | w = h = self.grid_size 127 | if ni >= self.dst_h: 128 | ni = self.dst_h - 1 129 | h = ni - i + 1 130 | if nj >= self.dst_w: 131 | nj = self.dst_w - 1 132 | w = nj - j + 1 133 | 134 | di = np.reshape(np.arange(h), (-1, 1)) 135 | dj = np.reshape(np.arange(w), (1, -1)) 136 | delta_x = self.__bilinear_interp( 137 | di / h, dj / w, self.rdx[i, j], self.rdx[i, nj], 138 | self.rdx[ni, j], self.rdx[ni, nj]) 139 | delta_y = self.__bilinear_interp( 140 | di / h, dj / w, self.rdy[i, j], self.rdy[i, nj], 141 | self.rdy[ni, j], self.rdy[ni, nj]) 142 | nx = j + dj + delta_x * self.trans_ratio 143 | ny = i + di + delta_y * self.trans_ratio 144 | nx = np.clip(nx, 0, src_w - 1) 145 | ny = np.clip(ny, 0, src_h - 1) 146 | nxi = np.array(np.floor(nx), dtype=np.int32) 147 | nyi = np.array(np.floor(ny), dtype=np.int32) 148 | nxi1 = np.array(np.ceil(nx), dtype=np.int32) 149 | nyi1 = np.array(np.ceil(ny), dtype=np.int32) 150 | 151 | if len(self.src.shape) == 3: 152 | x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) 153 | y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) 154 | else: 155 | x = ny - nyi 156 | y = nx - nxi 157 | dst[i:i + h, j:j + w] = self.__bilinear_interp( 158 | x, y, self.src[nyi, nxi], self.src[nyi, nxi1], 159 | self.src[nyi1, nxi], self.src[nyi1, nxi1]) 160 | 161 | dst = np.clip(dst, 0, 255) 162 | dst = np.array(dst, dtype=np.uint8) 163 | 164 | return dst --------------------------------------------------------------------------------
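A minimal usage sketch for the label helpers in utils/label_proc.py above. The paths below are hypothetical, and the import assumes the functions are called from a script sitting next to label_proc.py:

from label_proc import process_label, strip_label

# prepend the image directory to every line of a gt file (the file is rewritten in place)
process_label('train_data/images/', 'train_data/gt.txt')

# strip stray whitespace / line endings from another label file (also rewritten in place)
strip_label('train_data/val_gt.txt')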