├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── configs ├── abinet_finetune.yaml ├── abinet_pretrain_vision_model.yaml ├── pretrain_language_model.yaml ├── semimtr_finetune.yaml ├── semimtr_pretrain_vision_model.yaml └── template.yaml ├── data ├── DATA.md └── charset_36.txt ├── demo.py ├── figures ├── .DS_Store ├── abinet_model_architecture.svg ├── semimtr_consistency_regularization.svg └── semimtr_vision_pretraining.svg ├── main.py ├── notebook_demo.ipynb ├── requirements.txt ├── semimtr ├── __init__.py ├── callbacks │ ├── __init__.py │ └── callbacks.py ├── dataset │ ├── __init__.py │ ├── augmentation_pipelines.py │ ├── dataset.py │ ├── dataset_consistency_regularization.py │ ├── dataset_selfsupervised.py │ └── weighted_sampler.py ├── losses │ ├── __init__.py │ ├── consistency_regularization_loss.py │ ├── losses.py │ └── seqclr_loss.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── backbone.py │ ├── model.py │ ├── model_abinet.py │ ├── model_abinet_iter.py │ ├── model_alignment.py │ ├── model_fusion_consistency_regularization.py │ ├── model_fusion_teacher_student_ema.py │ ├── model_language.py │ ├── model_seqclr_vision.py │ ├── model_vision.py │ ├── projections.py │ ├── resnet.py │ ├── seqclr_proj.py │ └── transformer.py └── utils │ ├── __init__.py │ ├── test.py │ ├── transforms.py │ └── utils.py ├── setup.py └── tools ├── create_lmdb_dataset.py ├── crop_by_word_bb_syn90k.py ├── prepare_wikitext103.ipynb └── remove_labels_from_lmdb.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. 
You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Semi-Supervised Learning for Text Recognition 2 | 3 | The official code implementation of SemiMTR [Paper](https://arxiv.org/pdf/2205.03873) 4 | | [Pretrained Models](#Pretrained-Models) | [SeqCLR Paper](https://arxiv.org/pdf/2012.10873) 5 | | [Citation](#citation) | [Demo](#demo). 6 | 7 | **[Aviad Aberdam](https://sites.google.com/view/aviad-aberdam/home), 8 | [Roy Ganz](https://il.linkedin.com/in/roy-ganz-270592), 9 | [Shai Mazor](https://il.linkedin.com/in/shai-mazor-529771b), 10 | [Ron Litman](https://scholar.google.com/citations?hl=iw&user=69GY5dEAAAAJ)** 11 | 12 | We introduce a multimodal semi-supervised learning algorithm for text recognition, which is customized for modern 13 | vision-language multimodal architectures. To this end, we present a unified one-stage pretraining method for the vision 14 | model, which suits scene text recognition. In addition, we offer a sequential, character-level, consistency 15 | regularization in which each modality teaches itself. 
Extensive experiments demonstrate state-of-the-art performance on 16 | multiple scene text recognition benchmarks. 17 | 18 | ### Figures 19 | 20 |
![semimtr vision model pretraining](figures/semimtr_vision_pretraining.svg)

*Figure 1: SemiMTR vision model pretraining: Contrastive learning*

![semimtr fine-tuning](figures/semimtr_consistency_regularization.svg)

*Figure 2: SemiMTR model fine-tuning: Consistency regularization*
# Getting Started

Run Demo with Pretrained Model: Open In Colab (Colab notebook badge; see `notebook_demo.ipynb`)
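To run the demo locally instead of in Colab, a minimal invocation of `demo.py` looks like the sketch below. The checkpoint and input paths are placeholders: download one of the models listed under [Pretrained Models](#Pretrained-Models) and point `--input` at a folder of cropped word images (see the argument parser in `demo.py` for the full list of options).

```
# Example only: adjust the checkpoint and image paths to your setup
CUDA_VISIBLE_DEVICES=0 python demo.py \
    --config configs/semimtr_finetune.yaml \
    --checkpoint workdir/semimtr_real_l_and_u_and_textocr.pth \
    --input /path/to/cropped_word_images \
    --cuda 0 \
    --model_eval alignment
```

Note that `--cuda 0` selects the first visible GPU; without it, `demo.py` defaults to CPU inference.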

47 | 48 | ## Dependencies 49 | 50 | - Inference and demo requires PyTorch >= 1.7.1 51 | - For training and evaluation, install the dependencies 52 | 53 | ``` 54 | pip install -r requirements.txt 55 | ``` 56 | 57 | ## Pretrained Models 58 | 59 | Download pretrained models: 60 | 61 | - [SemiMTR Real-L + Real-U](https://awscv-public-data.s3.us-west-2.amazonaws.com/semimtr/semimtr_real_l_and_u.pth) 62 | - [SemiMTR Real-L + Real-U + Synth](https://awscv-public-data.s3.us-west-2.amazonaws.com/semimtr/semimtr_real_l_and_u_and_synth.pth) 63 | - [SemiMTR Real-L + Real-U + TextOCR](https://awscv-public-data.s3.us-west-2.amazonaws.com/semimtr/semimtr_real_l_and_u_and_textocr.pth) 64 | 65 | Pretrained vision models: 66 | 67 | - [SemiMTR Vision Model Real-L + Real-U](https://awscv-public-data.s3.us-west-2.amazonaws.com/semimtr/semimtr_vision_model_real_l_and_u.pth) 68 | 69 | Pretrained language model: 70 | 71 | - [ABINet Language Model](https://awscv-public-data.s3.us-west-2.amazonaws.com/semimtr/abinet_language_model.pth) 72 | 73 | 74 | For fine-tuning SemiMTR without vision and language pretraining, locate the above models in a `workdir` directory, as follows: 75 | 76 | workdir 77 | ├── semimtr_vision_model_real_l_and_u.pth 78 | ├── abinet_language_model.pth 79 | └── semimtr_real_l_and_u.pth 80 | 81 | ### SemiMTR Models Accuracy 82 | 83 | |Training Data|IIIT|SVT|IC13|IC15|SVTP|CUTE|Avg.|COCO|RCTW|Uber|ArT|LSVT|MLT19|ReCTS|Avg.| 84 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 85 | |Synth (ABINet)|96.4|93.2|95.1|82.1|89.0|89.2|91.2|63.1|59.7|39.6|68.3|59.5|85.0|86.7|52.0| 86 | |Real-L+U|97.0|95.8|96.1|84.7|90.7|94.1|92.8|72.2|76.1|58.5|71.6|77.1|90.4|92.4|65.4| 87 | |Real-L+U+Synth|97.4|96.8|96.5|84.7|92.9|95.1|93.3|73.0|75.7|58.6|72.4|77.5|90.4|93.1|65.8| 88 | |Real-L+U+TextOCR|97.3|97.7|96.9|86.0|92.2|94.4|93.7|73.8|77.7|58.6|73.5|78.3|91.3|93.3|66.1| 89 | 90 | 91 | ## Datasets 92 | 93 | - Download preprocessed lmdb dataset for training and 94 | evaluation. [Link](https://github.com/ku21fan/STR-Fewer-Labels/blob/main/data.md#download-preprocessed-lmdb-dataset-for-traininig-and-evaluation) 95 | - For training the language model, download WikiText103. [Link](https://github.com/FangShancheng/ABINet#datasets) 96 | - The final structure of `data` directory can be found in [`DATA.md`](data/DATA.md). 97 | 98 | ## Training 99 | 100 | 1. Pretrain vision model 101 | ``` 102 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/semimtr_pretrain_vision_model.yaml 103 | ``` 104 | 2. Pretrain language model 105 | ``` 106 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/pretrain_language_model.yaml 107 | ``` 108 | 3. Train SemiMTR 109 | ``` 110 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/semimtr_finetune.yaml 111 | ``` 112 | 113 | Note: 114 | 115 | - You can set the `checkpoint` path for vision and language models separately for specific pretrained model, or set 116 | to `None` to train from scratch 117 | 118 | ### Training ABINet 119 | 120 | 1. Pre-train vision model 121 | ``` 122 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/abinet_pretrain_vision_model.yaml 123 | ``` 124 | 2. Pre-train language model 125 | ``` 126 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/pretrain_language_model.yaml 127 | ``` 128 | 3. 
Train ABINet 129 | ``` 130 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/abinet_finetune.yaml 131 | ``` 132 | 133 | ## Evaluation 134 | 135 | ``` 136 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/semimtr_finetune.yaml --run_only_test 137 | ``` 138 | 139 | ## Arguments: 140 | 141 | - `--checkpoint /path/to/checkpoint` set the path of evaluation model 142 | - `--test_root /path/to/dataset` set the path of evaluation dataset 143 | - `--model_eval [alignment|vision]` which sub-model to evaluate 144 | 145 | ## Citation 146 | 147 | If you find our method useful for your research, please cite 148 | 149 | ``` 150 | @article{aberdam2022multimodal, 151 | title={Multimodal Semi-Supervised Learning for Text Recognition}, 152 | author={Aberdam, Aviad and Ganz, Roy and Mazor, Shai and Litman, Ron}, 153 | journal={arXiv preprint arXiv:2205.03873}, 154 | year={2022} 155 | } 156 | 157 | @inproceedings{aberdam2021sequence, 158 | title={Sequence-to-sequence contrastive learning for text recognition}, 159 | author={Aberdam, Aviad and Litman, Ron and Tsiper, Shahar and Anschel, Oron and Slossberg, Ron and Mazor, Shai and Manmatha, R and Perona, Pietro}, 160 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 161 | pages={15302--15312}, 162 | year={2021} 163 | } 164 | ``` 165 | 166 | ## Acknowledgements 167 | 168 | This implementation is based on the repository [ABINet](https://github.com/FangShancheng/ABINet). 169 | 170 | ## Security 171 | 172 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 173 | 174 | ## License 175 | 176 | This project is licensed under the Apache-2.0 License. 177 | 178 | ## Contact 179 | 180 | Feel free to contact us if there is any question: [Aviad Aberdam](mailto:aaberdam@amazon.com?subject=[GitHub-SemiMTR]) 181 | -------------------------------------------------------------------------------- /configs/abinet_finetune.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: train-abinet 3 | phase: train 4 | stage: train-supervised 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: [ 11 | 'data/training/label/real', 12 | ], 13 | batch_size: 384 14 | } 15 | valid: { 16 | roots: [ 17 | 'data/validation', 18 | ], 19 | batch_size: 384 20 | } 21 | test: { 22 | roots: [ 23 | 'data/evaluation/benchmark', 24 | 'data/evaluation/addition', 25 | ], 26 | batch_size: 384 27 | } 28 | data_aug: True 29 | multiscales: False 30 | num_workers: 14 31 | 32 | training: 33 | epochs: 50 34 | show_iters: 50 35 | eval_iters: 3000 36 | save_iters: 3000 37 | 38 | optimizer: 39 | type: Adam 40 | true_wd: False 41 | wd: 0.0 42 | bn_wd: False 43 | clip_grad: 20 44 | lr: 0.0001 45 | scheduler: { 46 | periods: [ 35, 10, 5 ], 47 | gamma: 0.1, 48 | } 49 | 50 | model: 51 | name: 'semimtr.modules.model_abinet_iter.ABINetIterModel' 52 | iter_size: 3 53 | ensemble: '' 54 | use_vision: False 55 | vision: { 56 | checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth, 57 | loss_weight: 1., 58 | attention: 'position', 59 | backbone: 'transformer', 60 | backbone_ln: 3, 61 | } 62 | language: { 63 | checkpoint: workdir/abinet_language_model.pth, 64 | num_layers: 4, 65 | loss_weight: 1., 66 | detach: True, 67 | use_self_attn: False 68 | } 69 | alignment: { 70 | loss_weight: 1., 71 | } 72 | -------------------------------------------------------------------------------- /configs/abinet_pretrain_vision_model.yaml: 
-------------------------------------------------------------------------------- 1 | global: 2 | name: pretrain-vision-model 3 | phase: train 4 | stage: pretrain-vision 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: [ 11 | 'data/training/label/real', 12 | ], 13 | batch_size: 384 14 | } 15 | valid: { 16 | roots: [ 17 | 'data/validation', 18 | ], 19 | batch_size: 384 20 | } 21 | test: { 22 | roots: [ 23 | 'data/evaluation/benchmark', 24 | 'data/evaluation/addition', 25 | ], 26 | batch_size: 384 27 | } 28 | data_aug: True 29 | multiscales: False 30 | num_workers: 14 31 | 32 | training: 33 | epochs: 150 34 | show_iters: 50 35 | eval_iters: 3000 36 | save_iters: 3000 37 | 38 | optimizer: 39 | type: Adam 40 | true_wd: False 41 | wd: 0.0 42 | bn_wd: False 43 | clip_grad: 20 44 | lr: 0.0001 45 | scheduler: { 46 | periods: [ 100, 40, 10 ], 47 | gamma: 0.1, 48 | } 49 | 50 | model: 51 | name: 'semimtr.modules.model_vision.BaseVision' 52 | checkpoint: ~ 53 | vision: { 54 | loss_weight: 1., 55 | attention: 'position', 56 | backbone: 'transformer', 57 | backbone_ln: 3, 58 | } 59 | -------------------------------------------------------------------------------- /configs/pretrain_language_model.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: pretrain-language-model 3 | phase: train 4 | stage: pretrain-language 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: [ 11 | 'data/WikiText-103.csv', 12 | ], 13 | batch_size: 4096 14 | } 15 | valid: { 16 | roots: [ 17 | 'data/WikiText-103_eval_d1.csv', 18 | ], 19 | batch_size: 4096 20 | } 21 | 22 | training: 23 | epochs: 80 24 | show_iters: 50 25 | eval_iters: 6000 26 | save_iters: 3000 27 | 28 | optimizer: 29 | type: Adam 30 | true_wd: False 31 | wd: 0.0 32 | bn_wd: False 33 | clip_grad: 20 34 | lr: 0.0001 35 | scheduler: { 36 | periods: [ 70, 10 ], 37 | gamma: 0.1, 38 | } 39 | 40 | model: 41 | name: 'semimtr.modules.model_language.BCNLanguage' 42 | language: { 43 | num_layers: 4, 44 | loss_weight: 1., 45 | use_self_attn: False 46 | } 47 | -------------------------------------------------------------------------------- /configs/semimtr_finetune.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: consistency-regularization 3 | phase: train 4 | stage: train-semi-supervised 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | scheme: consistency_regularization 10 | type: ST 11 | train: { 12 | roots: [ 13 | 'data/training/label/real', 14 | 'data/training/unlabel', 15 | ], 16 | batch_size: 232, 17 | weights: [ 0.3, 0.7 ] 18 | } 19 | valid: { 20 | roots: [ 21 | 'data/validation', 22 | ], 23 | batch_size: 232 24 | } 25 | test: { 26 | roots: [ 27 | 'data/evaluation/benchmark', 28 | 'data/evaluation/addition', 29 | ], 30 | batch_size: 232 31 | } 32 | data_aug: True 33 | multiscales: False 34 | num_workers: 14 35 | augmentation_severity: 1 36 | 37 | training: 38 | epochs: 5 39 | show_iters: 50 40 | eval_iters: 1000 41 | save_iters: 3000 42 | 43 | optimizer: 44 | type: Adam 45 | true_wd: False 46 | wd: 0.0001 47 | bn_wd: False 48 | clip_grad: 20 49 | lr: 0.0001 50 | scheduler: { 51 | periods: [ 3, 1, 1 ], 52 | gamma: 0.1, 53 | } 54 | 55 | model: 56 | name: 'semimtr.modules.model_fusion_consistency_regularization.ConsistencyRegularizationFusionModel' 57 | iter_size: 3 58 | vision: { 59 | checkpoint: workdir/semimtr_vision_model_real_l_and_u.pth, 60 | loss_weight: 1., 61 | attention: 'position', 62 | 
backbone: 'transformer', 63 | backbone_ln: 3, 64 | checkpoint_submodule: vision, 65 | } 66 | language: { 67 | checkpoint: workdir/abinet_language_model.pth, 68 | num_layers: 4, 69 | loss_weight: 1., 70 | detach: True, 71 | use_self_attn: False 72 | } 73 | alignment: { 74 | loss_weight: 1., 75 | } 76 | consistency: { 77 | loss_weight: 1., 78 | supervised_flag: True, 79 | all_to_all: True, 80 | # teacher_layer: vision, # alignment | language | vision (doesn't matter if all_to_all is True) 81 | # student_layer: all, # all | alignment | language | vision (doesn't matter if all_to_all is True) 82 | teacher_one_hot: True, 83 | kl_div: False, 84 | teacher_stop_gradients: True, 85 | use_threshold: False, 86 | ema: False, 87 | ema_decay: 0.9999 88 | } 89 | -------------------------------------------------------------------------------- /configs/semimtr_pretrain_vision_model.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: seqclr-pretrain-vision-model 3 | phase: train 4 | stage: pretrain-vision 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | scheme: selfsupervised 10 | type: ST 11 | train: { 12 | roots: [ 13 | 'data/training/label/real', 14 | 'data/training/unlabel', 15 | ], 16 | weights: [ 0.3, 0.7 ], 17 | batch_size: 304, 18 | } 19 | valid: { 20 | roots: [ 21 | 'data/validation', 22 | ], 23 | batch_size: 304 24 | } 25 | test: { 26 | roots: [ 27 | 'data/evaluation/benchmark', 28 | 'data/evaluation/addition', 29 | ], 30 | batch_size: 304 31 | } 32 | data_aug: True 33 | multiscales: False 34 | num_workers: 14 35 | augmentation_severity: 1 36 | 37 | training: 38 | epochs: 25 39 | show_iters: 50 40 | eval_iters: 3000 41 | save_iters: 3000 42 | 43 | optimizer: 44 | type: Adam 45 | true_wd: False 46 | wd: 0.0001 47 | bn_wd: False 48 | clip_grad: 20 49 | lr: 0.0001 50 | scheduler: { 51 | periods: [ 17, 5, 3 ], 52 | gamma: 0.1, 53 | } 54 | 55 | model: 56 | name: 'semimtr.modules.model_seqclr_vision.SeqCLRModel' 57 | checkpoint: ~ 58 | vision: { 59 | loss_weight: 1., 60 | attention: 'position', 61 | backbone: 'transformer', 62 | backbone_ln: 3, 63 | } 64 | proj: { 65 | layer: backbone_feature, # 'feature'|'backbone_feature' 66 | scheme: null, # null|'bilstm'|'linear_per_column' 67 | # hidden: 256, 68 | # output: 256, 69 | } 70 | contrastive: { 71 | loss_weight: 1., 72 | supervised_flag: True, 73 | } 74 | instance_mapping: { 75 | frame_to_instance: False, 76 | fixed: instances, # instances|frames 77 | w: 5, 78 | } 79 | -------------------------------------------------------------------------------- /configs/template.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: exp 3 | phase: train 4 | stage: pretrain-vision 5 | workdir: workdir 6 | seed: ~ 7 | debug: False 8 | 9 | dataset: 10 | scheme: 'supervised' 11 | type: 'ST' 12 | train: { 13 | roots: [ 14 | 'data/training/label/real', 15 | ], 16 | batch_size: 128, 17 | weights: ~, 18 | } 19 | valid: { 20 | roots: [ 21 | 'data/validation', 22 | ], 23 | batch_size: 128 24 | } 25 | test: { 26 | roots: [ 27 | 'data/evaluation/benchmark', 28 | 'data/evaluation/addition', 29 | ], 30 | batch_size: 128 31 | } 32 | portion: 1.0 33 | charset_path: data/charset_36.txt 34 | num_workers: 4 35 | max_length: 25 36 | image_height: 32 37 | image_width: 128 38 | case_sensitive: False 39 | eval_case_sensitive: False 40 | data_aug: True 41 | multiscales: False 42 | pin_memory: True 43 | smooth_label: False 44 | smooth_factor: 0.1 45 | use_sm: False 46 | 
filter_single_punctuation: False 47 | 48 | training: 49 | epochs: 6 50 | show_iters: 50 51 | eval_iters: 3000 52 | save_iters: 20000 53 | start_iters: 0 54 | stats_iters: 1000 55 | hist_iters: 10000000 56 | 57 | optimizer: 58 | type: Adam 59 | true_wd: False 60 | wd: 0.0 61 | bn_wd: False 62 | clip_grad: 20 63 | lr: 0.0001 64 | scheduler: { 65 | periods: [ 3, 1, 1 ], 66 | gamma: 0.1, 67 | } 68 | 69 | model: 70 | name: 'semimtr.modules.model_abinet.ABINetModel' 71 | checkpoint: ~ 72 | strict: True 73 | -------------------------------------------------------------------------------- /data/DATA.md: -------------------------------------------------------------------------------- 1 | # DATA Structure 2 | 3 | - Training and evaluation require download preprocessed lmdb. [Link](https://github.com/ku21fan/STR-Fewer-Labels/blob/main/data.md#download-preprocessed-lmdb-dataset-for-traininig-and-evaluation) 4 | - Pretraining the language model requires the WikiText103. [Link](https://github.com/FangShancheng/ABINet#datasets) 5 | - The final structure of `data` directory is: 6 | 7 | ``` 8 | data 9 | ├── charset_36.txt 10 | ├── training 11 | │ ├── label 12 | │ │ ├── real 13 | │ │ │ ├── 1.SVT 14 | │ │ │ ├── 2.IIIT 15 | │ │ │ ├── 3.IC13 16 | │ │ │ ├── 4.IC15 17 | │ │ │ ├── 5.COCO 18 | │ │ │ ├── 6.RCTW17 19 | │ │ │ ├── 7.Uber 20 | │ │ │ ├── 8.ArT 21 | │ │ │ ├── 9.LSVT 22 | │ │ │ ├── 10.MLT19 23 | │ │ │ └── 11.ReCTS 24 | │ │ └── synth (for synthetic data, follow guideline at https://github.com/ku21fan/STR-Fewer-Labels/blob/main/data.md) 25 | │ │ ├── MJ 26 | │ │ │ ├── MJ_train 27 | │ │ │ ├── MJ_valid 28 | │ │ │ └── MJ_test 29 | │ │ ├── ST 30 | │ │ ├── ST_spe 31 | │ │ └── SA 32 | │ └── unlabel 33 | │ ├── U1.Book32 34 | │ ├── U2.TextVQA 35 | │ └── U3.STVQA 36 | ├── validation 37 | │ ├── 1.SVT 38 | │ ├── 2.IIIT 39 | │ ├── 3.IC13 40 | │ ├── 4.IC15 41 | │ ├── 5.COCO 42 | │ ├── 6.RCTW17 43 | │ ├── 7.Uber 44 | │ ├── 8.ArT 45 | │ ├── 9.LSVT 46 | │ ├── 10.MLT19 47 | │ └── 11.ReCTS 48 | ├── evaluation 49 | │ ├── benchmark 50 | │ │ ├── SVT 51 | │ │ ├── IIIT5k_3000 52 | │ │ ├── IC13_1015 53 | │ │ ├── IC15_2077 54 | │ │ ├── SVTP 55 | │ │ └── CUTE80 56 | │ └── addition 57 | │ ├── 5.COCO 58 | │ ├── 6.RCTW17 59 | │ ├── 7.Uber 60 | │ ├── 8.ArT 61 | │ ├── 9.LSVT 62 | │ ├── 10.MLT19 63 | │ └── 11.ReCTS 64 | ├── WikiText-103.csv (for training LM) 65 | └── WikiText-103_eval_d1.csv (for training LM) 66 | ``` 67 | -------------------------------------------------------------------------------- /data/charset_36.txt: -------------------------------------------------------------------------------- 1 | 0 a 2 | 1 b 3 | 2 c 4 | 3 d 5 | 4 e 6 | 5 f 7 | 6 g 8 | 7 h 9 | 8 i 10 | 9 j 11 | 10 k 12 | 11 l 13 | 12 m 14 | 13 n 15 | 14 o 16 | 15 p 17 | 16 q 18 | 17 r 19 | 18 s 20 | 19 t 21 | 20 u 22 | 21 v 23 | 22 w 24 | 23 x 25 | 24 y 26 | 25 z 27 | 26 1 28 | 27 2 29 | 28 3 30 | 29 4 31 | 30 5 32 | 31 6 33 | 32 7 34 | 33 8 35 | 34 9 36 | 35 0 -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | 6 | import PIL 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import tqdm 12 | from semimtr.utils.utils import Config, Logger, CharsetMapper 13 | from torchvision import transforms 14 | 15 | 16 | def get_model(config): 17 | import importlib 18 | names = config.model_name.split('.') 19 | module_name, class_name = 
'.'.join(names[:-1]), names[-1] 20 | cls = getattr(importlib.import_module(module_name), class_name) 21 | model = cls(config) 22 | logging.info(model) 23 | model = model.eval() 24 | return model 25 | 26 | 27 | def preprocess(img, width, height): 28 | img = cv2.resize(np.array(img), (width, height)) 29 | img = transforms.ToTensor()(img).unsqueeze(0) 30 | mean = torch.tensor([0.485, 0.456, 0.406]) 31 | std = torch.tensor([0.229, 0.224, 0.225]) 32 | return (img - mean[..., None, None]) / std[..., None, None] 33 | 34 | 35 | def postprocess(raw_output, charset, model_eval): 36 | def _extract_output_list(last_output): 37 | if isinstance(last_output, (tuple, list)): 38 | return last_output 39 | elif isinstance(last_output, dict) and 'supervised_outputs_view0' in last_output: 40 | return last_output['supervised_outputs_view0'] 41 | elif isinstance(last_output, dict) and 'teacher_outputs' in last_output: 42 | return last_output['teacher_outputs'] 43 | else: 44 | return 45 | 46 | def _get_output(last_output, model_eval): 47 | output_list = _extract_output_list(last_output) 48 | if output_list is not None: 49 | if isinstance(output_list, (tuple, list)): 50 | for res in output_list: 51 | if res['name'] == model_eval: output = res 52 | else: 53 | output = output_list 54 | else: 55 | output = last_output 56 | return output 57 | 58 | def _decode(logit): 59 | """ Greed decode """ 60 | out = F.softmax(logit, dim=2) 61 | pt_text, pt_scores, pt_lengths = [], [], [] 62 | for o in out: 63 | text = charset.get_text(o.argmax(dim=1), padding=False, trim=False) 64 | text = text.split(charset.null_char)[0] # end at end-token 65 | pt_text.append(text) 66 | pt_scores.append(o.max(dim=1)[0]) 67 | pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token 68 | return pt_text, pt_scores, pt_lengths 69 | 70 | output = _get_output(raw_output, model_eval) 71 | logits, pt_lengths = output['logits'], output['pt_lengths'] 72 | pt_text, pt_scores, pt_lengths_ = _decode(logits) 73 | 74 | return pt_text, pt_scores, pt_lengths_ 75 | 76 | 77 | def load(model, file, device=None, strict=True): 78 | if device is None: 79 | device = 'cpu' 80 | elif isinstance(device, int): 81 | device = torch.device('cuda', device) 82 | assert os.path.isfile(file) 83 | state = torch.load(file, map_location=device) 84 | if set(state.keys()) == {'model', 'opt'}: 85 | state = state['model'] 86 | model.load_state_dict(state, strict=strict) 87 | return model 88 | 89 | 90 | def main(): 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('--config', type=str, default='configs/semimtr_finetune.yaml', 93 | help='path to config file') 94 | parser.add_argument('--input', type=str, default='figs/test') 95 | parser.add_argument('--cuda', type=int, default=-1) 96 | parser.add_argument('--checkpoint', type=str, 97 | default='workdir/consistency-regularization/best-consistency-regularization.pth') 98 | parser.add_argument('--model_eval', type=str, default='alignment', 99 | choices=['alignment', 'vision', 'language']) 100 | args = parser.parse_args() 101 | config = Config(args.config) 102 | if args.checkpoint is not None: config.model_checkpoint = args.checkpoint 103 | if args.model_eval is not None: config.model_eval = args.model_eval 104 | config.global_phase = 'test' 105 | config.model_vision_checkpoint, config.model_language_checkpoint = None, None 106 | device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}' 107 | 108 | Logger.init(config.global_workdir, config.global_name, config.global_phase) 109 | Logger.enable_file() 110 | 
logging.info(config) 111 | 112 | logging.info('Construct model.') 113 | model = get_model(config).to(device) 114 | model = load(model, config.model_checkpoint, device=device) 115 | charset = CharsetMapper(filename=config.dataset_charset_path, 116 | max_length=config.dataset_max_length + 1) 117 | 118 | if os.path.isdir(args.input): 119 | paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)] 120 | else: 121 | paths = glob.glob(os.path.expanduser(args.input)) 122 | assert paths, "The input path(s) was not found" 123 | pt_outputs = {} 124 | paths = sorted(paths) 125 | for path in tqdm.tqdm(paths): 126 | img = PIL.Image.open(path).convert('RGB') 127 | img = preprocess(img, config.dataset_image_width, config.dataset_image_height) 128 | img = img.to(device) 129 | res = model(img, forward_only_teacher=True) 130 | pt_text, _, __ = postprocess(res, charset, config.model_eval) 131 | pt_outputs[path] = pt_text[0] 132 | logging.info(f'SemiMTR Prediction of the path: {path} is: {pt_text[0]}') 133 | return pt_outputs 134 | 135 | 136 | if __name__ == '__main__': 137 | pt_outputs = main() 138 | logging.info('Finished!') 139 | for k, v in pt_outputs.items(): 140 | print(k, v) 141 | -------------------------------------------------------------------------------- /figures/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/semimtr-text-recognition/043d65b3caac416a65ccd10ecd965ce3bbaa62ad/figures/.DS_Store -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | from torch.backends import cudnn 5 | from fastai.vision import * 6 | from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase 7 | 8 | from semimtr.callbacks.callbacks import IterationCallback, TextAccuracy, TopKTextAccuracy, EMA 9 | from semimtr.dataset.dataset import ImageDataset, TextDataset, collate_fn_filter_none 10 | from semimtr.dataset.dataset_selfsupervised import ImageDatasetSelfSupervised 11 | from semimtr.dataset.dataset_consistency_regularization import ImageDatasetConsistencyRegularization 12 | from semimtr.dataset.weighted_sampler import WeightedDatasetRandomSampler 13 | from semimtr.losses.losses import MultiCELosses 14 | from semimtr.losses.seqclr_loss import SeqCLRLoss 15 | from semimtr.losses.consistency_regularization_loss import ConsistencyRegularizationLoss 16 | from semimtr.utils.utils import Config, Logger, MyDataParallel, \ 17 | MyConcatDataset, if_none 18 | from semimtr.utils.test import test_on_each_ds 19 | 20 | 21 | def _set_random_seed(seed): 22 | cudnn.deterministic = True 23 | if seed is not None: 24 | random.seed(seed) 25 | torch.manual_seed(seed) 26 | logging.warning('You have chosen to seed training. 
' 27 | 'This will slow down your training!') 28 | 29 | 30 | def _get_training_phases(config, n): 31 | lr = np.array(config.optimizer_lr) 32 | periods = config.optimizer_scheduler_periods 33 | sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))] 34 | phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i]) 35 | for i in range(len(periods))] 36 | return phases 37 | 38 | 39 | def _get_dataset(ds_type, paths, is_training, config, **kwargs): 40 | kwargs.update({ 41 | 'img_h': config.dataset_image_height, 42 | 'img_w': config.dataset_image_width, 43 | 'max_length': config.dataset_max_length, 44 | 'case_sensitive': config.dataset_case_sensitive, 45 | 'charset_path': config.dataset_charset_path, 46 | 'data_aug': config.dataset_data_aug, 47 | 'deteriorate_ratio': config.dataset_deteriorate_ratio, 48 | 'multiscales': config.dataset_multiscales, 49 | 'data_portion': config.dataset_portion, 50 | 'filter_single_punctuation': config.dataset_filter_single_punctuation, 51 | }) 52 | datasets = [] 53 | for p in paths: 54 | subfolders = [f.path for f in os.scandir(p) if f.is_dir()] 55 | if subfolders: # Concat all subfolders 56 | datasets.append(_get_dataset(ds_type, subfolders, is_training, config, **kwargs)) 57 | else: 58 | datasets.append(ds_type(path=p, is_training=is_training, **kwargs)) 59 | if len(datasets) > 1: 60 | return MyConcatDataset(datasets) 61 | else: 62 | return datasets[0] 63 | 64 | 65 | def _get_language_databaunch(config): 66 | kwargs = { 67 | 'max_length': config.dataset_max_length, 68 | 'case_sensitive': config.dataset_case_sensitive, 69 | 'charset_path': config.dataset_charset_path, 70 | 'smooth_label': config.dataset_smooth_label, 71 | 'smooth_factor': config.dataset_smooth_factor, 72 | 'use_sm': config.dataset_use_sm, 73 | } 74 | train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs) 75 | valid_ds = TextDataset(config.dataset_valid_roots[0], is_training=False, **kwargs) 76 | data = DataBunch.create( 77 | path=train_ds.path, 78 | train_ds=train_ds, 79 | valid_ds=valid_ds, 80 | bs=config.dataset_train_batch_size, 81 | val_bs=config.dataset_test_batch_size, 82 | num_workers=config.dataset_num_workers, 83 | pin_memory=config.dataset_pin_memory) 84 | logging.info(f'{len(data.train_ds)} training items found.') 85 | if not data.empty_val: 86 | logging.info(f'{len(data.valid_ds)} valid items found.') 87 | return data 88 | 89 | 90 | def _get_databaunch(config): 91 | bunch_kwargs = {} 92 | ds_kwargs = {} 93 | bunch_kwargs['collate_fn'] = collate_fn_filter_none 94 | if config.dataset_scheme == 'supervised': 95 | dataset_class = ImageDataset 96 | elif config.dataset_scheme == 'selfsupervised': 97 | dataset_class = ImageDatasetSelfSupervised 98 | if config.dataset_augmentation_severity is not None: 99 | ds_kwargs['augmentation_severity'] = config.dataset_augmentation_severity 100 | ds_kwargs['supervised_flag'] = if_none(config.model_contrastive_supervised_flag, False) 101 | elif config.dataset_scheme == 'consistency_regularization': 102 | dataset_class = ImageDatasetConsistencyRegularization 103 | if config.dataset_augmentation_severity is not None: 104 | ds_kwargs['augmentation_severity'] = config.dataset_augmentation_severity 105 | ds_kwargs['supervised_flag'] = if_none(config.model_consistency_regularization_supervised_flag, True) 106 | else: 107 | raise NotImplementedError(f'dataset_scheme={config.dataset_scheme} is not supported') 108 | train_ds = _get_dataset(dataset_class, config.dataset_train_roots, True, config, 
**ds_kwargs) 109 | valid_ds = _get_dataset(dataset_class, config.dataset_valid_roots, False, config, **ds_kwargs) 110 | if config.dataset_test_roots is not None: 111 | test_ds = _get_dataset(dataset_class, config.dataset_test_roots, False, config, **ds_kwargs) 112 | bunch_kwargs['test_ds'] = test_ds 113 | data = ImageDataBunch.create( 114 | train_ds=train_ds, 115 | valid_ds=valid_ds, 116 | bs=config.dataset_train_batch_size, 117 | val_bs=config.dataset_test_batch_size, 118 | num_workers=config.dataset_num_workers, 119 | pin_memory=config.dataset_pin_memory, 120 | **bunch_kwargs, 121 | ).normalize(imagenet_stats) 122 | ar_tfm = lambda x: ((x[0], x[1]), x[1]) # auto-regression only for dtd 123 | data.add_tfm(ar_tfm) 124 | if config.dataset_train_weights is not None: 125 | weighted_sampler = WeightedDatasetRandomSampler(dataset_weights=config.dataset_train_weights, 126 | dataset_sizes=[len(ds) for ds in train_ds.datasets]) 127 | data.train_dl = data.train_dl.new(shuffle=False, sampler=weighted_sampler) 128 | 129 | logging.info(f'{len(data.train_ds)} training items found.') 130 | if not data.empty_val: 131 | logging.info(f'{len(data.valid_ds)} valid items found.') 132 | if data.test_dl: 133 | logging.info(f'{len(data.test_ds)} test items found.') 134 | 135 | return data 136 | 137 | 138 | def _get_model(config): 139 | import importlib 140 | names = config.model_name.split('.') 141 | module_name, class_name = '.'.join(names[:-1]), names[-1] 142 | cls = getattr(importlib.import_module(module_name), class_name) 143 | model = cls(config) 144 | # logging.info(model) 145 | return model 146 | 147 | 148 | def _get_learner(config, data, model): 149 | if config.global_stage == 'pretrain-language': 150 | metrics = [TopKTextAccuracy( 151 | k=if_none(config.model_k, 5), 152 | charset_path=config.dataset_charset_path, 153 | max_length=config.dataset_max_length + 1, 154 | case_sensitive=config.dataset_eval_case_sensitive, 155 | model_eval=config.model_eval)] 156 | elif config.dataset_scheme == 'selfsupervised' and not config.model_contrastive_supervised_flag: 157 | metrics = None 158 | else: 159 | metrics = [TextAccuracy( 160 | charset_path=config.dataset_charset_path, 161 | max_length=config.dataset_max_length + 1, 162 | case_sensitive=config.dataset_eval_case_sensitive, 163 | model_eval=config.model_eval)] 164 | opt_type = getattr(torch.optim, config.optimizer_type) 165 | if config.dataset_scheme == 'supervised': 166 | loss_func = MultiCELosses() 167 | elif config.dataset_scheme == 'selfsupervised': 168 | loss_func = SeqCLRLoss(supervised_flag=config.model_contrastive_supervised_flag) 169 | elif config.dataset_scheme == 'consistency_regularization': 170 | loss_func = ConsistencyRegularizationLoss( 171 | supervised_flag=config.model_consistency_supervised_flag, 172 | all_teacher_layers_to_all_student_layers=config.model_consistency_all_to_all, 173 | teacher_layer=config.model_consistency_teacher_layer, 174 | student_layer=config.model_consistency_student_layer, 175 | teacher_one_hot_labels=config.model_consistency_teacher_one_hot, 176 | consistency_kl_div=config.model_consistency_kl_div, 177 | teacher_stop_gradients=config.model_consistency_teacher_stop_gradients, 178 | use_threshold=config.model_consistency_use_threshold, 179 | ) 180 | else: 181 | raise NotImplementedError(f'dataset_scheme={config.dataset_scheme} is not supported') 182 | learner = Learner(data, model, silent=True, model_dir='.', 183 | true_wd=config.optimizer_true_wd, 184 | wd=config.optimizer_wd, 185 | bn_wd=config.optimizer_bn_wd, 186 | 
path=config.global_workdir, 187 | metrics=metrics, 188 | opt_func=partial(opt_type, **config.optimizer_args or dict()), 189 | loss_func=loss_func) 190 | 191 | phases = _get_training_phases(config, len(learner.data.train_dl)) 192 | learner.callback_fns += [ 193 | partial(GeneralScheduler, phases=phases), 194 | partial(GradientClipping, clip=config.optimizer_clip_grad), 195 | partial(IterationCallback, name=config.global_name, 196 | show_iters=config.training_show_iters, 197 | eval_iters=config.training_eval_iters, 198 | save_iters=config.training_save_iters, 199 | start_iters=config.training_start_iters, 200 | stats_iters=config.training_stats_iters, 201 | hist_iters=config.training_hist_iters, 202 | debug=config.global_debug)] 203 | 204 | if config.model_consistency_ema: 205 | learner.callback_fns += [partial(EMA)] 206 | 207 | if torch.cuda.device_count() > 1: 208 | logging.info(f'Use {torch.cuda.device_count()} GPUs.') 209 | learner.model = MyDataParallel(learner.model) 210 | 211 | if config.model_checkpoint: 212 | with open(config.model_checkpoint, 'rb') as f: 213 | buffer = io.BytesIO(f.read()) 214 | learner.load(buffer, strict=config.model_strict) 215 | logging.info(f'Read model from {config.model_checkpoint}') 216 | elif config.global_phase == 'test': 217 | learner.load(f'best-{config.global_name}', strict=config.model_strict) 218 | logging.info(f'Read model from best-{config.global_name}') 219 | 220 | return learner 221 | 222 | 223 | def _parse_arguments(): 224 | parser = argparse.ArgumentParser() 225 | parser.add_argument('-c', '--config', type=str, required=True, 226 | help='path to config file') 227 | parser.add_argument('-b', '--batch_size', type=int, default=None, 228 | help='batch size') 229 | parser.add_argument('--run_only_test', action='store_true', default=None, 230 | help='flag to run only test and skip training') 231 | parser.add_argument('--test_root', type=str, default=None, 232 | help='path to test datasets') 233 | parser.add_argument('--checkpoint', type=str, default=None, 234 | help='path to model checkpoint') 235 | parser.add_argument('--vision_checkpoint', type=str, default=None, 236 | help='path to vision model pretrained') 237 | parser.add_argument('--debug', action='store_true', default=None, 238 | help='flag for running on debug without saving model checkpoints') 239 | parser.add_argument('--model_eval', type=str, default=None, 240 | choices=['alignment', 'vision', 'language'], 241 | help='model decoder that outputs predictions') 242 | parser.add_argument('--workdir', type=str, default=None, 243 | help='path to workdir folder') 244 | parser.add_argument('--subworkdir', type=str, default=None, 245 | help='optional prefix to workdir path') 246 | parser.add_argument('--epochs', type=int, default=None, 247 | help='number of training epochs') 248 | parser.add_argument('--eval_iters', type=int, default=None, 249 | help='evaluate performance on validation set every this number iterations') 250 | args = parser.parse_args() 251 | config = Config(args.config) 252 | if args.batch_size is not None: 253 | config.dataset_train_batch_size = args.batch_size 254 | config.dataset_valid_batch_size = args.batch_size 255 | config.dataset_test_batch_size = args.batch_size 256 | if args.run_only_test is not None: 257 | config.global_phase = 'Test' if args.run_only_test else 'Train' 258 | if args.test_root is not None: 259 | config.dataset_test_roots = [args.test_root] 260 | args_to_config_dict = { 261 | 'checkpoint': 'model_checkpoint', 262 | 'vision_checkpoint': 
'model_vision_checkpoint', 263 | 'debug': 'global_debug', 264 | 'model_eval': 'model_eval', 265 | 'workdir': 'global_workdir', 266 | 'subworkdir': 'global_subworkdir', 267 | 'epochs': 'training_epochs', 268 | 'eval_iters': 'training_eval_iters', 269 | } 270 | for args_attr, config_attr in args_to_config_dict.items(): 271 | if getattr(args, args_attr) is not None: 272 | setattr(config, config_attr, getattr(args, args_attr)) 273 | return config 274 | 275 | 276 | def main(): 277 | config = _parse_arguments() 278 | Logger.init(config.global_workdir, config.global_name, config.global_phase) 279 | Logger.enable_file() 280 | _set_random_seed(config.global_seed) 281 | logging.info(config) 282 | 283 | logging.info('Construct dataset.') 284 | if config.global_stage == 'pretrain-language': 285 | data = _get_language_databaunch(config) 286 | else: 287 | data = _get_databaunch(config) 288 | 289 | logging.info('Construct model.') 290 | model = _get_model(config) 291 | 292 | logging.info('Construct learner.') 293 | learner = _get_learner(config, data, model) 294 | 295 | if config.global_phase == 'train': 296 | logging.info('Start training.') 297 | learner.fit(epochs=config.training_epochs, 298 | lr=config.optimizer_lr) 299 | logging.info('Finish training.') 300 | 301 | logging.info('Start testing') 302 | test_on_each_ds(learner) 303 | 304 | 305 | if __name__ == '__main__': 306 | main() 307 | -------------------------------------------------------------------------------- /notebook_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": false 7 | }, 8 | "source": [ 9 | "# SemiMTR Demo" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "collapsed": false, 17 | "pycharm": { 18 | "name": "#%%\n" 19 | } 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "from matplotlib import pyplot as plt\n", 24 | "from PIL import Image\n", 25 | "import urllib.request\n", 26 | "!pip install -U PyYAML" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "collapsed": false, 34 | "pycharm": { 35 | "name": "#%%\n" 36 | } 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "#@ Install SemiMTR code\n", 41 | "!git clone 'https://github.com/amazon-research/semimtr-text-recognition'\n", 42 | "%cd semimtr-text-recognition" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "collapsed": false, 50 | "pycharm": { 51 | "name": "#%%\n" 52 | } 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "#@ Download a pretrained model\n", 57 | "!mkdir workdir\n", 58 | "!wget 'https://awscv-public-data.s3.us-west-2.amazonaws.com/semimtr/semimtr_real_l_and_u_and_textocr.pth' -O 'workdir/semimtr_model.pth'" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "collapsed": false, 66 | "pycharm": { 67 | "name": "#%%\n" 68 | } 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "#@ Choose an image\n", 73 | "image_id = 1 #@param {type:\"slider\", min:1, max:10, step:1}\n", 74 | "image_url = f\"https://raw.githubusercontent.com/ku21fan/STR-Fewer-Labels/main/demo_image/{image_id}.png\"\n", 75 | "\n", 76 | "#@markdown ---\n", 77 | "#@markdown Or provide a url path to cropped text image (Optional).\n", 78 | "#@markdown ### Enter a file path:\n", 79 | "file_path = \"\" #@param {type:\"string\"}\n", 80 | "if file_path:\n", 81 | " image_url = 
file_path" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": { 88 | "collapsed": false, 89 | "pycharm": { 90 | "name": "#%%\n" 91 | } 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "file_name, headers = urllib.request.urlretrieve(image_url)\n", 96 | "output = !CUDA_VISIBLE_DEVICES=0 python3 demo.py --config 'configs/semimtr_finetune.yaml' --input $file_name --checkpoint 'workdir/semimtr_model.pth'" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "collapsed": false, 104 | "pycharm": { 105 | "name": "#%%\n" 106 | } 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "init_ind = [i for i, f in enumerate(output.fields()) if \"Finished!\" in f][0]\n", 111 | "for f in output.fields()[init_ind + 1:]:\n", 112 | " im = Image.open(f[0])\n", 113 | " plt.imshow(im)\n", 114 | " plt.title(f[1])\n", 115 | " plt.axis('off')\n", 116 | " plt.show()" 117 | ] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "Python 3", 123 | "language": "python", 124 | "name": "python3" 125 | }, 126 | "language_info": { 127 | "codemirror_mode": { 128 | "name": "ipython", 129 | "version": 2 130 | }, 131 | "file_extension": ".py", 132 | "mimetype": "text/x-python", 133 | "name": "python", 134 | "nbconvert_exporter": "python", 135 | "pygments_lexer": "ipython2", 136 | "version": "2.7.6" 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 0 141 | } 142 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.7.1 2 | torchvision 3 | fastai >= 1.0.60, <2.0 4 | LMDB 5 | Pillow 6 | opencv-python 7 | tensorboardX 8 | editdistance 9 | imgaug -------------------------------------------------------------------------------- /semimtr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/semimtr-text-recognition/043d65b3caac416a65ccd10ecd965ce3bbaa62ad/semimtr/__init__.py -------------------------------------------------------------------------------- /semimtr/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/semimtr-text-recognition/043d65b3caac416a65ccd10ecd965ce3bbaa62ad/semimtr/callbacks/__init__.py -------------------------------------------------------------------------------- /semimtr/callbacks/callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import editdistance as ed 3 | from semimtr.utils.utils import CharsetMapper, Timer, blend_mask 4 | from fastai.callbacks.tensorboard import LearnerTensorboardWriter 5 | from fastai.vision import * 6 | from torch.nn.parallel import DistributedDataParallel 7 | 8 | from semimtr.utils.utils import if_none 9 | 10 | 11 | class IterationCallback(LearnerTensorboardWriter): 12 | "A `TrackerCallback` that monitor in each iteration." 
13 | 14 | def __init__(self, learn: Learner, name: str = 'model', checpoint_keep_num=5, 15 | show_iters: int = 50, eval_iters: int = 1000, save_iters: int = 20000, 16 | start_iters: int = 0, stats_iters=20000, hist_iters=20000, debug=False): 17 | super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters, 18 | stats_iters=stats_iters, hist_iters=hist_iters) 19 | self.name, self.bestname = Path(name).name, f'best-{Path(name).name}' 20 | self.show_iters = show_iters 21 | self.eval_iters = eval_iters 22 | self.save_iters = save_iters 23 | self.start_iters = start_iters 24 | self.checpoint_keep_num = checpoint_keep_num 25 | self.metrics_root = 'metrics/' # rewrite 26 | self.timer = Timer() 27 | self.host = True 28 | self.debug = debug 29 | 30 | def _write_metrics(self, iteration: int, names: List[str], last_metrics: MetricsList) -> None: 31 | "Writes training metrics to Tensorboard." 32 | for i, name in enumerate(names): 33 | if last_metrics is None or len(last_metrics) < i + 1: return 34 | scalar_value = last_metrics[i] 35 | self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration) 36 | 37 | def _write_sub_loss(self, iteration: int, last_losses: dict) -> None: 38 | "Writes sub loss to Tensorboard." 39 | for name, loss in last_losses.items(): 40 | scalar_value = to_np(loss) 41 | tag = self.metrics_root + name 42 | self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration) 43 | 44 | def _save(self, name): 45 | if self.debug: return 46 | if isinstance(self.learn.model, DistributedDataParallel): 47 | tmp = self.learn.model 48 | self.learn.model = self.learn.model.module 49 | self.learn.save(name) 50 | self.learn.model = tmp 51 | else: 52 | self.learn.save(name) 53 | 54 | def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False): 55 | "Validate on `dl` with potential `callbacks` and `metrics`." 
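        # Descriptive note: this helper re-implements a minimal validation pass so it
        # can be called mid-epoch from _eval_model, instead of relying on fastai's
        # end-of-epoch validation. It wraps the metrics in a temporary CallbackHandler,
        # runs `validate` over the validation dataloader, and returns either the last
        # metrics or, when `keeped_items` is requested, the items accumulated by the callbacks.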
56 | dl = if_none(dl, self.learn.data.valid_dl) 57 | metrics = if_none(metrics, self.learn.metrics) 58 | cb_handler = CallbackHandler(if_none(callbacks, []), metrics) 59 | cb_handler.on_train_begin(1, None, metrics); 60 | cb_handler.on_epoch_begin() 61 | if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[])) 62 | val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler) 63 | cb_handler.on_epoch_end(val_metrics) 64 | if keeped_items: 65 | return cb_handler.state_dict['keeped_items'] 66 | else: 67 | return cb_handler.state_dict['last_metrics'] 68 | 69 | def jump_to_epoch_iter(self, epoch: int, iteration: int) -> None: 70 | try: 71 | self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False) 72 | logging.info(f'Loaded {self.name}_{epoch}_{iteration}') 73 | except: 74 | logging.info(f'Model {self.name}_{epoch}_{iteration} not found.') 75 | 76 | def on_train_begin(self, n_epochs, **kwargs): 77 | # TODO: can not write graph here 78 | # super().on_train_begin(**kwargs) 79 | self.best = -float('inf') 80 | self.timer.tic() 81 | if self.host: 82 | checkpoint_path = self.learn.path / 'checkpoint.yaml' 83 | if checkpoint_path.exists(): 84 | os.remove(checkpoint_path) 85 | open(checkpoint_path, 'w').close() 86 | return {'skip_validate': True, 'iteration': self.start_iters} # disable default validate 87 | 88 | def on_batch_begin(self, **kwargs: Any) -> None: 89 | self.timer.toc_data() 90 | super().on_batch_begin(**kwargs) 91 | 92 | def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs): 93 | super().on_batch_end(last_loss, iteration, train, **kwargs) 94 | if iteration == 0: return 95 | 96 | if iteration % self.loss_iters == 0: 97 | last_losses = self.learn.loss_func.last_losses 98 | self._write_sub_loss(iteration=iteration, last_losses=last_losses) 99 | self.tbwriter.add_scalar(tag=self.metrics_root + 'lr', 100 | scalar_value=self.opt.lr, global_step=iteration) 101 | 102 | if iteration % self.show_iters == 0: 103 | log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \ 104 | f'smooth loss = {smooth_loss:6.4f} ' 105 | logging.info(log_str) 106 | # log_str = f'data time = {self.timer.data_diff:.4f}s, runing time = {self.timer.running_diff:.4f}s' 107 | # logging.info(log_str) 108 | 109 | if iteration % self.eval_iters == 0: 110 | self._eval_model(iteration, epoch) 111 | 112 | if iteration % self.save_iters == 0 and self.host: 113 | logging.info(f'Save model {self.name}_{epoch}_{iteration}') 114 | filename = f'{self.name}_{epoch}_{iteration}' 115 | self._save(filename) 116 | 117 | checkpoint_path = self.learn.path / 'checkpoint.yaml' 118 | if not checkpoint_path.exists(): 119 | open(checkpoint_path, 'w').close() 120 | with open(checkpoint_path, 'r') as file: 121 | checkpoints = yaml.safe_load(file) or dict() 122 | checkpoints['all_checkpoints'] = ( 123 | checkpoints.get('all_checkpoints') or list()) 124 | checkpoints['all_checkpoints'].insert(0, filename) 125 | if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num: 126 | removed_checkpoint = checkpoints['all_checkpoints'].pop() 127 | removed_checkpoint = self.learn.path / self.learn.model_dir / f'{removed_checkpoint}.pth' 128 | os.remove(removed_checkpoint) 129 | checkpoints['current_checkpoint'] = filename 130 | with open(checkpoint_path, 'w') as file: 131 | yaml.dump(checkpoints, file) 132 | 133 | self.timer.toc_running() 134 | 135 | def _eval_model(self, iteration=None, epoch=None): 136 | if iteration is None or epoch is None: 137 | msg_start = f'last iteration' 
138 | else: 139 | msg_start = f'epoch {epoch} iter {iteration}' 140 | # 1. Record time 141 | log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \ 142 | f'average running time = {self.timer.average_running_time():.4f}s' 143 | logging.info(log_str) 144 | 145 | # 2. Call validate 146 | last_metrics = self._validate() 147 | self.learn.model.train() 148 | names = self._metrics_to_logging(last_metrics, msg_start) 149 | if len(last_metrics) > 1: 150 | current_eval_loss = last_metrics[2] 151 | else: # only eval loss 152 | current_eval_loss = last_metrics[0] 153 | 154 | if iteration is not None: 155 | self._write_metrics(iteration, names, last_metrics) 156 | 157 | # 3. Save best model 158 | if current_eval_loss is not None and current_eval_loss > self.best: 159 | logging.info(f'Better model found at {msg_start} with accuracy value: {current_eval_loss:6.4f}.') 160 | self.best = current_eval_loss 161 | self._save(f'{self.bestname}') 162 | 163 | @staticmethod 164 | def _metrics_to_logging(metrics, msg_start, dl_len=None): 165 | log_str = f'{msg_start}: ' 166 | if dl_len is not None: 167 | log_str += f'dataset size = {dl_len} ' 168 | log_str += f'eval loss = {metrics[0]:6.3f}, ' 169 | names = ['eval_loss'] 170 | if len(metrics) > 1: 171 | log_str += f'ccr = {metrics[1]:6.3f}, cwr = {metrics[2]:6.3f}, ' \ 172 | f'ted = {metrics[3]:6.3f}, ned = {metrics[4]:6.0f}, ' \ 173 | f'ted/w = {metrics[5]:6.3f}, ' 174 | names += ['ccr', 'cwr', 'ted', 'ned', 'ted/w'] 175 | logging.info(log_str) 176 | return names 177 | 178 | def on_train_end(self, **kwargs): 179 | logging.info('Train ended') 180 | self._eval_model() 181 | self.learn.load(f'{self.bestname}', purge=False) 182 | logging.info(f'Loading best model from {self.learn.path}/{self.learn.model_dir}/{self.bestname}.pth') 183 | 184 | def on_epoch_end(self, last_metrics: MetricsList, iteration: int, **kwargs) -> None: 185 | self._write_embedding(iteration=iteration) 186 | 187 | 188 | class EMA(LearnerCallback): 189 | def on_step_end(self, **kwargs): 190 | if isinstance(self.learn.model, nn.DataParallel): 191 | self.learn.model.module.update_teacher() 192 | else: 193 | self.learn.model.update_teacher() 194 | 195 | 196 | class TextAccuracy(Callback): 197 | _names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w'] 198 | 199 | def __init__(self, charset_path, max_length, case_sensitive, model_eval): 200 | self.charset_path = charset_path 201 | self.max_length = max_length 202 | self.case_sensitive = case_sensitive 203 | self.charset = CharsetMapper(charset_path, self.max_length) 204 | self.names = self._names 205 | 206 | self.model_eval = model_eval or 'alignment' 207 | assert self.model_eval in ['vision', 'language', 'alignment'] 208 | 209 | def on_epoch_begin(self, **kwargs): 210 | self.total_num_char = 0. 211 | self.total_num_word = 0. 212 | self.correct_num_char = 0. 213 | self.correct_num_word = 0. 214 | self.total_ed = 0. 215 | self.total_ned = 0. 
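    # Descriptive note: the counters reset above feed the metrics returned in
    # on_epoch_end: ccr = correct chars / total chars, cwr = fully correct words /
    # total words, ted = total edit distance, ned = summed per-word normalized edit
    # distance, ted/w = ted / total words. Illustrative example (made-up values):
    # ground truth 'hello' vs prediction 'helo' adds 1 to ted, 0.2 (= 1/5) to ned,
    # and counts the word as incorrect for cwr.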
216 | 217 | @staticmethod 218 | def _extract_output_list(last_output): 219 | if isinstance(last_output, (tuple, list)): 220 | return last_output 221 | elif isinstance(last_output, dict) and 'supervised_outputs_view0' in last_output: 222 | return last_output['supervised_outputs_view0'] 223 | elif isinstance(last_output, dict) and 'teacher_outputs' in last_output: 224 | return last_output['teacher_outputs'] 225 | else: 226 | return 227 | 228 | def _get_output(self, last_output): 229 | output_list = self._extract_output_list(last_output) 230 | if output_list is not None: 231 | if isinstance(output_list, (tuple, list)): 232 | for res in output_list: 233 | if res['name'] == self.model_eval: output = res 234 | else: 235 | output = output_list 236 | else: 237 | output = last_output 238 | return output 239 | 240 | def _update_output(self, last_output, items): 241 | output_list = self._extract_output_list(last_output) 242 | if output_list is not None: 243 | if isinstance(output_list, (tuple, list)): 244 | for res in output_list: 245 | if res['name'] == self.model_eval: res.update(items) 246 | else: 247 | output_list.update(items) 248 | else: 249 | last_output.update(items) 250 | return last_output 251 | 252 | def on_batch_end(self, last_output, last_target, **kwargs): 253 | output = self._get_output(last_output) 254 | logits, pt_lengths = output['logits'], output['pt_lengths'] 255 | pt_text, pt_scores, pt_lengths_ = self.decode(logits) 256 | if not (pt_lengths == pt_lengths_).all(): 257 | for pt_lengths_i, pt_lengths_i_, pt_text_i in zip(pt_lengths, pt_lengths_, pt_text): 258 | if pt_lengths_i != pt_lengths_i_: 259 | logging.warning(f'{pt_lengths_i} != {pt_lengths_i_} for {pt_text_i}') 260 | last_output = self._update_output(last_output, {'pt_text': pt_text, 'pt_scores': pt_scores}) 261 | 262 | pt_text = [self.charset.trim(t) for t in pt_text] 263 | label = last_target['label'] 264 | if label.dim() == 3: label = label.argmax(dim=-1) # one-hot label 265 | gt_text = [self.charset.get_text(l, trim=True) for l in label] 266 | 267 | for i in range(len(gt_text)): 268 | if not self.case_sensitive: 269 | gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower() 270 | distance = ed.eval(gt_text[i], pt_text[i]) 271 | self.total_ed += distance 272 | self.total_ned += float(distance) / max(len(gt_text[i]), 1) 273 | 274 | if gt_text[i] == pt_text[i]: 275 | self.correct_num_word += 1 276 | self.total_num_word += 1 277 | 278 | for j in range(min(len(gt_text[i]), len(pt_text[i]))): 279 | if gt_text[i][j] == pt_text[i][j]: 280 | self.correct_num_char += 1 281 | self.total_num_char += len(gt_text[i]) 282 | 283 | return {'last_output': last_output} 284 | 285 | def on_epoch_end(self, last_metrics, **kwargs): 286 | mets = [self.correct_num_char / self.total_num_char, 287 | self.correct_num_word / self.total_num_word, 288 | self.total_ed, 289 | self.total_ned, 290 | self.total_ed / self.total_num_word] 291 | return add_metrics(last_metrics, mets) 292 | 293 | def decode(self, logit): 294 | """ Greed decode """ 295 | # TODO: test running time and decode on GPU 296 | out = F.softmax(logit, dim=2) 297 | pt_text, pt_scores, pt_lengths = [], [], [] 298 | for o in out: 299 | text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False) 300 | text = text.split(self.charset.null_char)[0] # end at end-token 301 | pt_text.append(text) 302 | pt_scores.append(o.max(dim=1)[0]) 303 | pt_lengths.append(min(len(text) + 1, self.max_length)) # one for end-token 304 | pt_scores = torch.stack(pt_scores) 305 | pt_lengths = 
pt_scores.new_tensor(pt_lengths, dtype=torch.long) 306 | return pt_text, pt_scores, pt_lengths 307 | 308 | 309 | class TopKTextAccuracy(TextAccuracy): 310 | _names = ['ccr', 'cwr'] 311 | 312 | def __init__(self, k, charset_path, max_length, case_sensitive, model_eval): 313 | self.k = k 314 | self.charset_path = charset_path 315 | self.max_length = max_length 316 | self.case_sensitive = case_sensitive 317 | self.charset = CharsetMapper(charset_path, self.max_length) 318 | self.names = self._names 319 | 320 | def on_epoch_begin(self, **kwargs): 321 | self.total_num_char = 0. 322 | self.total_num_word = 0. 323 | self.correct_num_char = 0. 324 | self.correct_num_word = 0. 325 | 326 | def on_batch_end(self, last_output, last_target, **kwargs): 327 | logits, pt_lengths = last_output['logits'], last_output['pt_lengths'] 328 | gt_labels, gt_lengths = last_target['label'], last_target['length'] 329 | 330 | for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths): 331 | word_flag = True 332 | for i in range(length): 333 | char_logit = logit[i].topk(self.k)[1] 334 | char_label = label[i].argmax(-1) 335 | if char_label in char_logit: 336 | self.correct_num_char += 1 337 | else: 338 | word_flag = False 339 | self.total_num_char += 1 340 | if pt_length == length and word_flag: 341 | self.correct_num_word += 1 342 | self.total_num_word += 1 343 | 344 | def on_epoch_end(self, last_metrics, **kwargs): 345 | mets = [self.correct_num_char / self.total_num_char, 346 | self.correct_num_word / self.total_num_word, 347 | 0., 0., 0.] 348 | return add_metrics(last_metrics, mets) 349 | -------------------------------------------------------------------------------- /semimtr/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/semimtr-text-recognition/043d65b3caac416a65ccd10ecd965ce3bbaa62ad/semimtr/dataset/__init__.py -------------------------------------------------------------------------------- /semimtr/dataset/augmentation_pipelines.py: -------------------------------------------------------------------------------- 1 | from imgaug import augmenters as iaa 2 | 3 | 4 | def get_augmentation_pipeline(augmentation_severity=1): 5 | """ 6 | Defining the augmentation pipeline for SemiMTR pre-training and fine-tuning. 
7 | :param augmentation_severity: 8 | 0 - ABINet augmentation pipeline 9 | 1 - SemiMTR augmentation pipeline 10 | 2 - SeqCLR augmentation pipeline 11 | :return: augmentation_pipeline 12 | """ 13 | if augmentation_severity == 1: 14 | augmentations = iaa.Sequential([ 15 | iaa.Invert(0.5), 16 | iaa.OneOf([ 17 | iaa.ChannelShuffle(0.35), 18 | iaa.Grayscale(alpha=(0.0, 1.0)), 19 | iaa.KMeansColorQuantization(), 20 | iaa.HistogramEqualization(), 21 | iaa.Dropout(p=(0, 0.2), per_channel=0.5), 22 | iaa.GammaContrast((0.5, 2.0)), 23 | iaa.MultiplyBrightness((0.5, 1.5)), 24 | iaa.AddToHueAndSaturation((-50, 50), per_channel=True), 25 | iaa.ChangeColorTemperature((1100, 10000)) 26 | ]), 27 | iaa.OneOf([ 28 | iaa.Sharpen(alpha=(0.0, 0.5), lightness=(0.0, 0.5)), 29 | iaa.OneOf([ 30 | iaa.GaussianBlur((0.5, 1.5)), 31 | iaa.AverageBlur(k=(2, 6)), 32 | iaa.MedianBlur(k=(3, 7)), 33 | iaa.MotionBlur(k=5) 34 | ]) 35 | ]), 36 | iaa.OneOf([ 37 | iaa.Emboss(alpha=(0.0, 1.0), strength=(0.5, 1.5)), 38 | iaa.AdditiveGaussianNoise(scale=(0, 0.2 * 255)), 39 | iaa.ImpulseNoise(0.1), 40 | iaa.MultiplyElementwise((0.5, 1.5)) 41 | ]) 42 | ]) 43 | elif augmentation_severity == 2: 44 | optional_augmentations_list = [ 45 | iaa.LinearContrast((0.5, 1.0)), 46 | iaa.GaussianBlur((0.5, 1.5)), 47 | iaa.Crop(percent=((0, 0.4), (0, 0), (0, 0.4), (0, 0.0)), keep_size=True), 48 | iaa.Crop(percent=((0, 0.0), (0, 0.02), (0, 0), (0, 0.02)), keep_size=True), 49 | iaa.Sharpen(alpha=(0.0, 0.5), lightness=(0.0, 0.5)), 50 | # iaa.PiecewiseAffine(scale=(0.02, 0.03), mode='edge'), # In SeqCLR but replaced with a faster aug 51 | iaa.ElasticTransformation(alpha=(0, 0.8), sigma=0.25), 52 | iaa.PerspectiveTransform(scale=(0.01, 0.02)), 53 | ] 54 | augmentations = iaa.SomeOf((1, None), optional_augmentations_list, random_order=True) 55 | else: 56 | raise NotImplementedError(f'augmentation_severity={augmentation_severity} is not supported') 57 | 58 | return augmentations 59 | -------------------------------------------------------------------------------- /semimtr/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import math 4 | import warnings 5 | import PIL 6 | from pathlib import Path 7 | from typing import Union 8 | import numpy as np 9 | import re 10 | import cv2 11 | import lmdb 12 | import six 13 | import pandas as pd 14 | import torch 15 | from torchvision import transforms 16 | from torch.utils.data.dataloader import default_collate 17 | from torch.utils.data import Dataset 18 | 19 | from semimtr.utils.transforms import CVColorJitter, CVDeterioration, CVGeometry 20 | from semimtr.utils.utils import CharsetMapper, onehot 21 | 22 | 23 | class ImageDataset(Dataset): 24 | "`ImageDataset` read data from LMDB database." 25 | 26 | def __init__(self, 27 | path: Union[Path, str], 28 | is_training: bool = True, 29 | img_h: int = 32, 30 | img_w: int = 100, 31 | max_length: int = 25, 32 | check_length: bool = True, 33 | filter_single_punctuation: bool = False, 34 | case_sensitive: bool = False, 35 | charset_path: str = 'data/charset_36.txt', 36 | convert_mode: str = 'RGB', 37 | data_aug: bool = True, 38 | multiscales: bool = True, 39 | one_hot_y: bool = True, 40 | data_portion: float = 1.0, 41 | **kwargs): 42 | self.path, self.name = Path(path), Path(path).name 43 | assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory." 
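        # Descriptive note: the rest of __init__ (below) opens the LMDB environment
        # read-only, reads 'num-samples' to determine the dataset length, optionally
        # keeps only a random `data_portion` fraction of the indices (presumably for
        # limited-label experiments), and builds the CVGeometry / CVDeterioration /
        # CVColorJitter augmentations applied when `is_training` and `data_aug` are set.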
44 | self.convert_mode, self.check_length = convert_mode, check_length 45 | self.img_h, self.img_w = img_h, img_w 46 | self.max_length, self.one_hot_y = max_length, one_hot_y 47 | self.case_sensitive, self.is_training = case_sensitive, is_training 48 | self.filter_single_punctuation = filter_single_punctuation 49 | self.data_aug, self.multiscales = data_aug, multiscales 50 | self.charset = CharsetMapper(charset_path, max_length=max_length + 1) 51 | self.charset_string = ''.join([*self.charset.char_to_label]) 52 | self.charset_string = re.sub('-', r'\-', self.charset_string) # escaping the hyphen for later use in regex 53 | self.c = self.charset.num_classes 54 | 55 | self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False) 56 | assert self.env, f'Cannot open LMDB dataset from {path}.' 57 | with self.env.begin(write=False) as txn: 58 | dataset_length = int(txn.get('num-samples'.encode())) 59 | self.use_portion = self.is_training and not data_portion == 1.0 60 | if not self.use_portion: 61 | self.length = dataset_length 62 | else: 63 | self.length = int(data_portion * dataset_length) 64 | self.optional_ind = np.random.permutation(dataset_length)[:self.length] 65 | 66 | if self.is_training and self.data_aug: 67 | self.augment_tfs = transforms.Compose([ 68 | CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5), 69 | CVDeterioration(var=20, degrees=6, factor=4, p=0.25), 70 | CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25) 71 | ]) 72 | self.totensor = transforms.ToTensor() 73 | 74 | def __len__(self): 75 | return self.length 76 | 77 | def _next_image(self): 78 | if not self.is_training: 79 | return 80 | next_index = random.randint(0, len(self) - 1) 81 | if self.use_portion: 82 | next_index = self.optional_ind[next_index] 83 | return self.get(next_index) 84 | 85 | def _check_image(self, x, pixels=6): 86 | if x.size[0] <= pixels or x.size[1] <= pixels: 87 | return False 88 | else: 89 | return True 90 | 91 | def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT): 92 | def _resize_ratio(img, ratio, fix_h=True): 93 | if ratio * self.img_w < self.img_h: 94 | if fix_h: 95 | trg_h = self.img_h 96 | else: 97 | trg_h = int(ratio * self.img_w) 98 | trg_w = self.img_w 99 | else: 100 | trg_h, trg_w = self.img_h, int(self.img_h / ratio) 101 | img = cv2.resize(img, (trg_w, trg_h)) 102 | pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2 103 | top, bottom = math.ceil(pad_h), math.floor(pad_h) 104 | left, right = math.ceil(pad_w), math.floor(pad_w) 105 | img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType) 106 | return img 107 | 108 | if self.is_training: 109 | if random.random() < 0.5: 110 | base, maxh, maxw = self.img_h, self.img_h, self.img_w 111 | h, w = random.randint(base, maxh), random.randint(base, maxw) 112 | return _resize_ratio(img, h / w) 113 | else: 114 | return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio 115 | else: 116 | return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio 117 | 118 | def resize(self, img): 119 | if self.multiscales: 120 | return self.resize_multiscales(img, cv2.BORDER_REPLICATE) 121 | else: 122 | return cv2.resize(img, (self.img_w, self.img_h)) 123 | 124 | def get(self, idx): 125 | with self.env.begin(write=False) as txn: 126 | image_key, label_key = f'image-{idx + 1:09d}', f'label-{idx + 1:09d}' 127 | exception_flag = False 128 | try: 129 | raw_label = str(txn.get(label_key.encode()), 
'utf-8') # label 130 | if not self.case_sensitive: raw_label = raw_label.lower() 131 | label = re.sub(f'[^{self.charset_string}]', '', raw_label) 132 | # label = re.sub('[^0-9a-zA-Z]+', '', raw_label) 133 | len_issue = 0 < self.max_length < len(label) or len(label) <= 0 134 | single_punctuation = len(label) == 1 and label in '!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ ' 135 | if (self.check_length and len_issue) or (self.filter_single_punctuation and single_punctuation): 136 | return self._next_image() 137 | label = label[:self.max_length] 138 | 139 | imgbuf = txn.get(image_key.encode()) # image 140 | buf = six.BytesIO() 141 | buf.write(imgbuf) 142 | buf.seek(0) 143 | with warnings.catch_warnings(): 144 | warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin 145 | image = PIL.Image.open(buf).convert(self.convert_mode) 146 | except: 147 | import traceback 148 | traceback.print_exc() 149 | exception_flag = True 150 | if "label" in locals(): 151 | logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {len(label)}') 152 | else: 153 | logging.info(f'Corrupted image is found: {self.name}, {idx}') 154 | return self._next_image() 155 | if exception_flag or not self._check_image(image): 156 | return self._next_image() if self.is_training else None 157 | return {'image': image, 'label': label, 'idx': idx} 158 | 159 | def _process_training(self, image): 160 | if self.data_aug: image = self.augment_tfs(image) 161 | image = self.totensor(self.resize(np.array(image))) 162 | return image 163 | 164 | def _process_test(self, image): 165 | return self.totensor(self.resize(np.array(image))) 166 | 167 | def __getitem__(self, idx): 168 | if self.use_portion: 169 | idx = self.optional_ind[idx] 170 | datum = self.get(idx) 171 | if datum is None: 172 | return 173 | image, text, idx_new = datum['image'], datum['label'], datum['idx'] 174 | 175 | if self.is_training: 176 | image = self._process_training(image) 177 | else: 178 | image = self._process_test(image) 179 | y = self._label_postprocessing(text) 180 | return image, y 181 | 182 | def _label_postprocessing(self, text): 183 | length = torch.tensor(len(text) + 1).to(dtype=torch.long) # one for end token 184 | label = self.charset.get_labels(text, case_sensitive=self.case_sensitive) 185 | label = torch.tensor(label).to(dtype=torch.long) 186 | if self.one_hot_y: label = onehot(label, self.charset.num_classes) 187 | return {'label': label, 'length': length} 188 | 189 | 190 | class TextDataset(Dataset): 191 | def __init__(self, 192 | path: Union[Path, str], 193 | delimiter: str = '\t', 194 | max_length: int = 25, 195 | charset_path: str = 'data/charset_36.txt', 196 | case_sensitive=False, 197 | one_hot_x=True, 198 | one_hot_y=True, 199 | is_training=True, 200 | smooth_label=False, 201 | smooth_factor=0.2, 202 | use_sm=False, 203 | **kwargs): 204 | self.path = Path(path) 205 | self.case_sensitive, self.use_sm = case_sensitive, use_sm 206 | self.smooth_factor, self.smooth_label = smooth_factor, smooth_label 207 | self.charset = CharsetMapper(charset_path, max_length=max_length + 1) 208 | # convert the charset to string for regex filtering 209 | self.charset_string = ''.join([*self.charset.char_to_label]) 210 | self.charset_string = re.sub('-', r'\-', self.charset_string) # escaping the hyphen for later use in regex 211 | self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training 212 | if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset) 213 | 214 | dtype = {'inp': str, 'gt': str} 
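        # Assumption about the file format (inferred from inp_col/gt_col and
        # __getitem__ below, likely produced by tools/prepare_wikitext103.ipynb):
        # each row holds two delimiter-separated text columns, an input string
        # ('inp', column 0) and its ground-truth string ('gt', column 1).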
215 | self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False) 216 | self.inp_col, self.gt_col = 0, 1 217 | 218 | def __len__(self): 219 | return len(self.df) 220 | 221 | def __getitem__(self, idx): 222 | text_x = self.df.iloc[idx, self.inp_col] 223 | if not self.case_sensitive: text_x = text_x.lower() 224 | text_x = re.sub(f'[^{self.charset_string}]', '', text_x) 225 | if self.is_training and self.use_sm: text_x = self.sm(text_x) 226 | 227 | length_x = torch.tensor(len(text_x) + 1).to(dtype=torch.long) # one for end token 228 | label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive) 229 | label_x = torch.tensor(label_x) 230 | if self.one_hot_x: 231 | label_x = onehot(label_x, self.charset.num_classes) 232 | if self.is_training and self.smooth_label: 233 | label_x = torch.stack([self.prob_smooth_label(l) for l in label_x]) 234 | x = {'label': label_x, 'length': length_x} 235 | 236 | text_y = self.df.iloc[idx, self.gt_col] 237 | if not self.case_sensitive: text_y = text_y.lower() 238 | text_y = re.sub(f'[^{self.charset_string}]', '', text_y) 239 | length_y = torch.tensor(len(text_y) + 1).to(dtype=torch.long) # one for end token 240 | label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive) 241 | label_y = torch.tensor(label_y) 242 | if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes) 243 | y = {'label': label_y, 'length': length_y} 244 | return x, y 245 | 246 | def prob_smooth_label(self, one_hot): 247 | one_hot = one_hot.float() 248 | delta = torch.rand([]) * self.smooth_factor 249 | num_classes = len(one_hot) 250 | noise = torch.rand(num_classes) 251 | noise = noise / noise.sum() * delta 252 | one_hot = one_hot * (1 - delta) + noise 253 | return one_hot 254 | 255 | 256 | class SpellingMutation(object): 257 | def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None): 258 | """ 259 | Args: 260 | pn0: the prob of not modifying characters is (pn0) 261 | pn1: the prob of modifying one characters is (pn1 - pn0) 262 | pn2: the prob of modifying two characters is (pn2 - pn1), 263 | and three (1 - pn2) 264 | pt0: the prob of replacing operation is pt0. 
265 | pt1: the prob of inserting operation is (pt1 - pt0), 266 | and deleting operation is (1 - pt1) 267 | """ 268 | super().__init__() 269 | self.pn0, self.pn1, self.pn2 = pn0, pn1, pn2 270 | self.pt0, self.pt1 = pt0, pt1 271 | self.charset = charset 272 | logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' + 273 | f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}') 274 | 275 | def is_digit(self, text, ratio=0.5): 276 | length = max(len(text), 1) 277 | digit_num = sum([t in self.charset.digits for t in text]) 278 | if digit_num / length < ratio: return False 279 | return True 280 | 281 | def is_unk_char(self, char): 282 | # return char == self.charset.unk_char 283 | return (char not in self.charset.digits) and (char not in self.charset.alphabets) 284 | 285 | def get_num_to_modify(self, length): 286 | prob = random.random() 287 | if prob < self.pn0: 288 | num_to_modify = 0 289 | elif prob < self.pn1: 290 | num_to_modify = 1 291 | elif prob < self.pn2: 292 | num_to_modify = 2 293 | else: 294 | num_to_modify = 3 295 | 296 | if length <= 1: 297 | num_to_modify = 0 298 | elif length >= 2 and length <= 4: 299 | num_to_modify = min(num_to_modify, 1) 300 | else: 301 | num_to_modify = min(num_to_modify, length // 2) # smaller than length // 2 302 | return num_to_modify 303 | 304 | def __call__(self, text, debug=False): 305 | if self.is_digit(text): return text 306 | length = len(text) 307 | num_to_modify = self.get_num_to_modify(length) 308 | if num_to_modify <= 0: return text 309 | 310 | chars = [] 311 | index = np.arange(0, length) 312 | random.shuffle(index) 313 | index = index[: num_to_modify] 314 | if debug: self.index = index 315 | for i, t in enumerate(text): 316 | if i not in index: 317 | chars.append(t) 318 | elif self.is_unk_char(t): 319 | chars.append(t) 320 | else: 321 | prob = random.random() 322 | if prob < self.pt0: # replace 323 | chars.append(random.choice(self.charset.alphabets)) 324 | elif prob < self.pt1: # insert 325 | chars.append(random.choice(self.charset.alphabets)) 326 | chars.append(t) 327 | else: # delete 328 | continue 329 | new_text = ''.join(chars[: self.charset.max_length - 1]) 330 | return new_text if len(new_text) >= 1 else text 331 | 332 | 333 | def collate_fn_filter_none(batch): 334 | batch = list(filter(lambda x: x is not None, batch)) 335 | return default_collate(batch) 336 | -------------------------------------------------------------------------------- /semimtr/dataset/dataset_consistency_regularization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import transforms 4 | 5 | from semimtr.utils.transforms import ImageToPIL, ImageToArray 6 | from semimtr.dataset.dataset import ImageDataset 7 | from semimtr.dataset.augmentation_pipelines import get_augmentation_pipeline 8 | 9 | 10 | class ImageDatasetConsistencyRegularization(ImageDataset): 11 | """ 12 | Image Dataset for Self Supervised training that outputs pairs of images 13 | """ 14 | 15 | def __init__(self, augmentation_severity: int = 1, supervised_flag=False, **kwargs): 16 | super().__init__(**kwargs) 17 | self.supervised_flag = supervised_flag 18 | regular_aug = self.augment_tfs.transforms if hasattr(self, 'augment_tfs') else [] 19 | self.augment_tfs_teacher = transforms.Compose([ImageToPIL()] + regular_aug + [ImageToArray()]) 20 | if self.data_aug: 21 | if augmentation_severity == 0 or (not self.is_training and supervised_flag): 22 | self.augment_tfs_student = self.augment_tfs_teacher 23 | else: 24 
| self.augment_tfs_student = get_augmentation_pipeline(augmentation_severity).augment_image 25 | 26 | def _process_training(self, image): 27 | image = np.array(image) 28 | image_views = [] 29 | for tfs in (self.augment_tfs_teacher, self.augment_tfs_student): 30 | if self.data_aug: 31 | image_view = tfs(image) 32 | else: 33 | image_view = image 34 | image_views.append(self.totensor(self.resize(image_view))) 35 | return np.stack(image_views, axis=0) 36 | 37 | def _process_test(self, image): 38 | return self._process_training(image) 39 | 40 | def _label_postprocessing(self, text): 41 | y = super()._label_postprocessing(text) 42 | if text.lower() == 'unlabeleddata': 43 | y['length'] = torch.tensor(0).to(dtype=torch.long) # don't calculate cross entropy on this image 44 | return y 45 | -------------------------------------------------------------------------------- /semimtr/dataset/dataset_selfsupervised.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import transforms 4 | 5 | from semimtr.utils.transforms import ImageToPIL, ImageToArray 6 | from semimtr.dataset.dataset import ImageDataset 7 | from semimtr.dataset.augmentation_pipelines import get_augmentation_pipeline 8 | 9 | 10 | class ImageDatasetSelfSupervised(ImageDataset): 11 | """ 12 | Image Dataset for Self Supervised training that outputs pairs of images 13 | """ 14 | 15 | def __init__(self, augmentation_severity: int = 1, supervised_flag=False, **kwargs): 16 | super().__init__(**kwargs) 17 | self.supervised_flag = supervised_flag 18 | if self.data_aug: 19 | if augmentation_severity == 0 or (not self.is_training and supervised_flag): 20 | regular_aug = self.augment_tfs.transforms if hasattr(self, 'augment_tfs') else [] 21 | self.augment_tfs = transforms.Compose([ImageToPIL()] + regular_aug + [ImageToArray()]) 22 | else: 23 | self.augment_tfs = get_augmentation_pipeline(augmentation_severity).augment_image 24 | 25 | def _process_training(self, image): 26 | image = np.array(image) 27 | image_views = [] 28 | for _ in range(2): 29 | if self.data_aug: 30 | image_view = self.augment_tfs(image) 31 | else: 32 | image_view = image 33 | image_views.append(self.totensor(self.resize(image_view))) 34 | return np.stack(image_views, axis=0) 35 | 36 | def _process_test(self, image): 37 | return self._process_training(image) 38 | 39 | def _label_postprocessing(self, text): 40 | y = super()._label_postprocessing(text) 41 | if text.lower() == 'unlabeleddata': 42 | y['length'] = torch.tensor(0).to(dtype=torch.long) # don't calculate cross entropy on this image 43 | return y 44 | -------------------------------------------------------------------------------- /semimtr/dataset/weighted_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Sequence, Tuple, Optional, List 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data.sampler import Sampler 6 | import torch.distributed as dist 7 | 8 | 9 | class WeightedDatasetRandomSampler(Sampler[int]): 10 | r"""Samples datasets from ``[0,..,len(weights)-1]`` with given probabilities (weights), 11 | and provide a random index for the chosen dataset. 
12 | Args: 13 | dataset_weights (sequence) : a sequence of weights, necessary summing up to one 14 | dataset_sizes (sequence): size of each dataset 15 | Example: 16 | >>> WeightedDatasetRandomSampler([0.2, 0.8], [1, 7]) 17 | [(1, 6),(1, 2),(1, 0),(0, 0),(1, 5),(1, 3),(1, 1),(0, 0),(1, 4)] 18 | """ 19 | 20 | def __init__(self, dataset_weights: Sequence[float], dataset_sizes: List[int], adopt_to_ddp: bool = False) -> None: 21 | try: 22 | np.random.choice(len(dataset_sizes), p=dataset_weights) 23 | except ValueError as e: 24 | raise e 25 | self.dataset_weights = torch.Tensor(dataset_weights) 26 | self.dataset_sizes = dataset_sizes 27 | self.sum_cum = np.cumsum([0] + self.dataset_sizes) 28 | self.num_datasets = len(dataset_sizes) 29 | self.num_samples = int(max([ds_size / ds_weight for ds_size, ds_weight in zip(dataset_sizes, dataset_weights)])) 30 | self.epoch = 0 31 | self.ddp_mode = False 32 | if adopt_to_ddp: 33 | self._distributed_sampler() 34 | 35 | def _distributed_sampler(self): 36 | try: 37 | num_replicas = dist.get_world_size() 38 | rank = dist.get_rank() 39 | except: 40 | return 41 | self.ddp_mode = True 42 | self.num_replicas = num_replicas 43 | self.rank = rank 44 | self.num_samples = self.num_samples // self.num_replicas 45 | 46 | def __iter__(self) -> Iterator[int]: 47 | # deterministically shuffle based on epoch 48 | self.generator = torch.Generator() 49 | self.generator.manual_seed(self.epoch) 50 | 51 | if not self.ddp_mode: 52 | self.perm_lists = [EndlessGeneratePermutedIndices(ds_size, g) for ds_size in self.dataset_sizes] 53 | else: 54 | print(f"Init Sampler with rank of {self.rank} and num_replicas {self.num_replicas}") 55 | self.perm_lists = [ 56 | EndlessGeneratePermutedIndicesNew(torch.arange(ds_size)[self.rank:ds_size:self.num_replicas], 57 | self.epoch) for ds_size in self.dataset_sizes] 58 | return self 59 | 60 | def __next__(self) -> int: 61 | if all([perm_list.finished for perm_list in self.perm_lists]): 62 | raise StopIteration 63 | dataset_idx = torch.multinomial(torch.Tensor(self.dataset_weights), 1, generator=self.generator) 64 | return self.sum_cum[dataset_idx] + next(self.perm_lists[dataset_idx]) 65 | 66 | def __len__(self) -> int: 67 | return self.num_samples 68 | 69 | def set_epoch(self, epoch): 70 | self.epoch = epoch 71 | 72 | 73 | class EndlessGeneratePermutedIndices: 74 | def __init__(self, length: int, generator: torch.Generator = None) -> None: 75 | self.length = length 76 | self.finished = False 77 | self.generator = generator 78 | self._sample_perm() 79 | 80 | def _sample_perm(self) -> None: 81 | self.perm_list = torch.randperm(self.length, generator=self.generator).tolist() 82 | 83 | def __iter__(self): 84 | self.finished = False 85 | self._sample_perm() 86 | 87 | def __next__(self) -> int: 88 | if len(self.perm_list) == 0: 89 | self._sample_perm() 90 | self.finished = True 91 | return self.perm_list.pop() 92 | 93 | 94 | class EndlessGeneratePermutedIndicesDistributed(EndlessGeneratePermutedIndices): 95 | def __init__(self, indices: torch.Tensor, generator: torch.Generator = None) -> None: 96 | self.indices = indices 97 | self.finished = False 98 | self.generator = generator 99 | self._sample_perm() 100 | 101 | def _sample_perm(self) -> None: 102 | self.perm_list = self.indices[torch.randperm(len(self.indices), generator=self.generator)].tolist() 103 | 104 | 105 | class EndlessGeneratePermutedIndicesNew: 106 | def __init__(self, indices: torch.Tensor, epoch: int) -> None: 107 | self.indices = indices 108 | self.epoch = epoch 109 | self.__iter__() 
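    # Descriptive note: like EndlessGeneratePermutedIndices above, this iterator
    # walks a random permutation of its indices and, when the permutation runs out,
    # sets `finished = True` and immediately draws a new one, so __next__ never
    # raises StopIteration by itself. WeightedDatasetRandomSampler stops only once
    # every per-dataset iterator has set its `finished` flag.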
110 | 111 | def _sample_perm(self) -> None: 112 | # torch.randperm(self.length, generator=self.generator).tolist() 113 | self.perm_list = self.indices[torch.randperm(len(self.indices), generator=self.generator)].tolist() 114 | 115 | def __iter__(self): 116 | self.generator = torch.Generator() 117 | self.generator.manual_seed(self.epoch) 118 | self.finished = False 119 | self._sample_perm() 120 | 121 | def __next__(self) -> int: 122 | if len(self.perm_list) == 0: 123 | self.finished = True 124 | self._sample_perm() 125 | return self.perm_list.pop() 126 | -------------------------------------------------------------------------------- /semimtr/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/semimtr-text-recognition/043d65b3caac416a65ccd10ecd965ce3bbaa62ad/semimtr/losses/__init__.py -------------------------------------------------------------------------------- /semimtr/losses/consistency_regularization_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from semimtr.losses.losses import MultiCELosses 6 | 7 | layer_name_to_output_ind = {'alignment': 0, 'language': 1, 'vision': 2} 8 | 9 | 10 | class ConsistencyRegularizationLoss(nn.Module): 11 | def __init__( 12 | self, 13 | record=True, 14 | supervised_flag=False, 15 | all_teacher_layers_to_all_student_layers=False, 16 | teacher_layer='vision', 17 | student_layer='all', 18 | teacher_one_hot_labels=False, 19 | consistency_kl_div=False, 20 | teacher_stop_gradients=True, 21 | use_threshold=False, 22 | threshold_value=0.9 23 | ): 24 | super().__init__() 25 | if not all_teacher_layers_to_all_student_layers: 26 | if teacher_layer not in layer_name_to_output_ind.keys(): 27 | raise NotImplementedError(f'Teacher layer can be one of {list(layer_name_to_output_ind.keys())}') 28 | if student_layer != 'all' and student_layer not in layer_name_to_output_ind.keys(): 29 | raise NotImplementedError(f'Student layer can be \'all\' or ' 30 | f'one of {list(layer_name_to_output_ind.keys())}') 31 | self.teacher_layer_ind = layer_name_to_output_ind[teacher_layer] 32 | self.student_layer_ind = None if student_layer == 'all' else layer_name_to_output_ind[student_layer] 33 | self.record = record 34 | self.supervised_flag = supervised_flag 35 | self.supervised_ce_loss = MultiCELosses() 36 | self.consistency_ce_loss = MultiCELosses(kl_div=consistency_kl_div) 37 | self.all_teacher_layers_to_all_student_layers = all_teacher_layers_to_all_student_layers 38 | self.teacher_one_hot_labels = teacher_one_hot_labels 39 | self.teacher_stop_gradients = teacher_stop_gradients 40 | self.use_threshold = use_threshold 41 | self.threshold_value = threshold_value 42 | 43 | @property 44 | def last_losses(self): 45 | return self.losses 46 | 47 | def forward(self, outputs, *args): 48 | if isinstance(outputs, (tuple, list)): 49 | raise NotImplementedError 50 | self.losses = {} 51 | ce_loss = 0 52 | if self.supervised_flag: 53 | ce_loss_teacher = self.supervised_ce_loss(outputs['teacher_outputs'], *args) 54 | self.losses.update({f'{k}_teacher': v for k, v in self.supervised_ce_loss.last_losses.items()}) 55 | ce_loss_student = self.supervised_ce_loss(outputs['student_outputs'], *args) 56 | self.losses.update({f'{k}_student': v for k, v in self.supervised_ce_loss.last_losses.items()}) 57 | ce_loss += ce_loss_teacher + ce_loss_student 58 | 59 | if not 
self.all_teacher_layers_to_all_student_layers: 60 | teacher_predictions = outputs['teacher_outputs'][self.teacher_layer_ind] 61 | pt_labels_teacher, pt_lengths_teacher, threshold_mask = self.create_teacher_labels(teacher_predictions) 62 | if self.student_layer_ind is not None: 63 | student_predictions = outputs['student_outputs'][self.student_layer_ind] 64 | else: 65 | student_predictions = outputs['student_outputs'] 66 | pt_teacher = {'label': pt_labels_teacher, 'length': pt_lengths_teacher} 67 | ce_loss_student_teacher = self.consistency_ce_loss(student_predictions, pt_teacher, *args[1:], 68 | mask=threshold_mask) 69 | else: 70 | ce_loss_student_teacher = 0 71 | for teacher_predictions, student_predictions in zip(outputs['teacher_outputs'], outputs['student_outputs']): 72 | pt_labels_teacher, pt_lengths_teacher, threshold_mask = self.create_teacher_labels(teacher_predictions) 73 | pt_teacher = {'label': pt_labels_teacher, 'length': pt_lengths_teacher} 74 | ce_loss_student_teacher += self.consistency_ce_loss(student_predictions, pt_teacher, *args[1:], 75 | mask=threshold_mask) 76 | self.losses.update({f'{k}_teacher_student': v for k, v in self.consistency_ce_loss.last_losses.items()}) 77 | ce_loss += outputs['loss_weight'] * ce_loss_student_teacher 78 | return ce_loss 79 | 80 | def create_teacher_labels(self, teacher_predictions): 81 | if isinstance(teacher_predictions, list): 82 | teacher_predictions = teacher_predictions[-1] 83 | pt_lengths_teacher = teacher_predictions['pt_lengths'] 84 | pt_logits_teacher = teacher_predictions['logits'] 85 | if self.teacher_stop_gradients: 86 | pt_logits_teacher = pt_logits_teacher.detach() 87 | pt_labels_teacher = F.softmax(pt_logits_teacher, dim=-1) 88 | max_values, max_indices = torch.max(pt_labels_teacher, dim=-1, keepdim=True) 89 | if self.teacher_one_hot_labels: 90 | pt_labels_teacher = torch.zeros_like(pt_logits_teacher).scatter_(-1, max_indices, 1) 91 | threshold_mask = (max_values.squeeze() > self.threshold_value).float() if self.use_threshold else None 92 | return pt_labels_teacher, pt_lengths_teacher, threshold_mask 93 | -------------------------------------------------------------------------------- /semimtr/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MultiCELosses(nn.Module): 7 | def __init__(self, reduction="batchmean", kl_div=False): 8 | super().__init__() 9 | self.ce = SoftCrossEntropyLoss(reduction=reduction, kl_div=kl_div) 10 | 11 | @property 12 | def last_losses(self): 13 | return self.losses 14 | 15 | @staticmethod 16 | def _flatten(sources, lengths): 17 | return torch.cat([t[:l] for t, l in zip(sources, lengths)]) 18 | 19 | @staticmethod 20 | def _merge_list(all_res): 21 | if not isinstance(all_res, (list, tuple)): 22 | return all_res 23 | 24 | def merge(items): 25 | if isinstance(items[0], torch.Tensor): 26 | return torch.cat(items, dim=0) 27 | else: 28 | return items[0] 29 | 30 | res = dict() 31 | for key in all_res[0].keys(): 32 | items = [r[key] for r in all_res] 33 | res[key] = merge(items) 34 | return res 35 | 36 | def _ce_loss(self, output, gt_labels, gt_lengths, record=True, mask=None): 37 | loss_name = output.get('name') 38 | pt_logits, weight = output['logits'], output['loss_weight'] 39 | 40 | assert pt_logits.shape[0] % gt_labels.shape[0] == 0 41 | iter_size = pt_logits.shape[0] // gt_labels.shape[0] 42 | if iter_size > 1: 43 | gt_labels = gt_labels.repeat(3, 1, 1) 44 | 
gt_lengths = gt_lengths.repeat(3) 45 | flat_gt_labels = self._flatten(gt_labels, gt_lengths) 46 | flat_pt_logits = self._flatten(pt_logits, gt_lengths) 47 | 48 | if mask is not None: 49 | if iter_size > 1: 50 | mask = mask.repeat(3, 1) 51 | mask = self._flatten(mask, gt_lengths) 52 | 53 | loss = self.ce(flat_pt_logits, flat_gt_labels, mask=mask) * weight 54 | if record and loss_name is not None: 55 | self.losses[f'{loss_name}_loss'] = loss 56 | 57 | return loss 58 | 59 | def forward(self, outputs, gt_dict, record=False, mask=None): 60 | self.losses = {} 61 | gt_labels, gt_lengths = gt_dict['label'], gt_dict['length'] 62 | if isinstance(outputs, (tuple, list)): 63 | outputs = [self._merge_list(o) for o in outputs] 64 | return sum([self._ce_loss(o, gt_labels, gt_lengths, mask=mask) for o in outputs if o['loss_weight'] > 0.]) 65 | else: 66 | return self._ce_loss(outputs, gt_labels, gt_lengths, record=record, mask=mask) 67 | 68 | 69 | class SoftCrossEntropyLoss(nn.Module): 70 | def __init__(self, reduction="batchmean", apply_softmax=True, kl_div=False): 71 | super().__init__() 72 | self.reduction = reduction 73 | self.apply_softmax = apply_softmax 74 | self.kl_div = kl_div 75 | 76 | def forward(self, input, target, mask=None, eps=1e-12): 77 | if self.apply_softmax: 78 | log_prob = F.log_softmax(input, dim=-1) 79 | else: 80 | log_prob = torch.log(input) 81 | if not self.kl_div: # cross entropy loss 82 | loss = - target * log_prob 83 | else: # KL divergence: F.kl_div(log_prob, target, reduction=self.reduction) 84 | loss = target * torch.log(target + eps) - target * log_prob 85 | loss = loss.sum(dim=-1) 86 | if mask is not None: 87 | loss = mask * loss 88 | if self.reduction == "batchmean": 89 | loss = loss.mean() 90 | elif self.reduction == "sum": 91 | loss = loss.sum() 92 | else: 93 | raise NotImplementedError(f'reduction={self.reduction} is not implemented for CE loss') 94 | return loss 95 | -------------------------------------------------------------------------------- /semimtr/losses/seqclr_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from semimtr.losses.losses import MultiCELosses 6 | 7 | 8 | class SeqCLRLoss(nn.Module): 9 | def __init__(self, temp=0.1, reduction="batchmean", record=True, supervised_flag=False): 10 | super().__init__() 11 | self.reduction = reduction 12 | self.temp = temp 13 | self.record = record 14 | self.supervised_flag = supervised_flag 15 | self.supervised_loss = MultiCELosses() 16 | 17 | @property 18 | def last_losses(self): 19 | return self.losses 20 | 21 | def _seqclr_loss(self, features0, features1, n_instances_per_view, n_instances_per_image): 22 | instances = torch.cat((features0, features1), dim=0) 23 | normalized_instances = F.normalize(instances, dim=1) 24 | similarity_matrix = normalized_instances @ normalized_instances.T 25 | similarity_matrix_exp = (similarity_matrix / self.temp).exp_() 26 | cross_entropy_denominator = similarity_matrix_exp.sum(dim=1) - similarity_matrix_exp.diag() 27 | cross_entropy_nominator = torch.cat(( 28 | similarity_matrix_exp.diagonal(offset=n_instances_per_view)[:n_instances_per_view], 29 | similarity_matrix_exp.diagonal(offset=-n_instances_per_view) 30 | ), dim=0) 31 | cross_entropy_similarity = cross_entropy_nominator / cross_entropy_denominator 32 | loss = - cross_entropy_similarity.log() 33 | 34 | if self.reduction == "batchmean": 35 | loss = loss.mean() 36 | elif self.reduction == "sum": 37 | 
loss = loss.sum() 38 | elif self.reduction == "mean_instances_per_image": 39 | loss = loss.sum() / n_instances_per_image 40 | return loss 41 | 42 | def forward(self, outputs, gt_dict, *args, **kwargs): 43 | if isinstance(outputs, (tuple, list)): 44 | raise NotImplementedError 45 | self.losses = {} 46 | ce_loss = 0 47 | if self.supervised_flag: 48 | ce_loss += self.supervised_loss(outputs['supervised_outputs_view0'], gt_dict, record=True) 49 | ce_view0_last_losses = self.supervised_loss.last_losses 50 | ce_loss += self.supervised_loss(outputs['supervised_outputs_view1'], gt_dict, record=True) 51 | ce_view1_last_losses = self.supervised_loss.last_losses 52 | self.losses.update({k: (v + ce_view1_last_losses[k]) / 2 for k, v in ce_view0_last_losses.items()}) 53 | 54 | loss_name = outputs.get('name') 55 | gt_lengths = gt_dict['length'] 56 | seqclr_loss = 0 57 | if loss_name == 'seqclr_fusion': 58 | pt_length = outputs['pt_lengths'] 59 | pt_length[gt_lengths != 0] = gt_lengths[gt_lengths != 0] # Use ground truth length if available 60 | # TODO: spread on gpus 61 | for features0, features1 in zip(outputs['instances_view0'], outputs['instances_view1']): 62 | features0 = MultiCELosses._flatten(sources=features0, lengths=pt_length) 63 | features1 = MultiCELosses._flatten(sources=features1, lengths=pt_length) 64 | n_instances_per_image = pt_length.float().mean() 65 | n_instances_per_view = features0.shape[0] 66 | seqclr_loss += self._seqclr_loss(features0, features1, n_instances_per_view, n_instances_per_image) 67 | seqclr_loss /= len(outputs['instances_view0']) # Average seqclr losses 68 | else: 69 | features0 = torch.flatten(outputs['instances_view0'], start_dim=0, end_dim=1) 70 | features1 = torch.flatten(outputs['instances_view1'], start_dim=0, end_dim=1) 71 | n_instances_per_image = outputs['instances_view0'].shape[1] 72 | n_instances_per_view = outputs['instances_view0'].shape[0] * n_instances_per_image 73 | seqclr_loss += self._seqclr_loss(features0, features1, n_instances_per_view, n_instances_per_image) 74 | seqclr_loss *= outputs['loss_weight'] 75 | 76 | if self.record and loss_name is not None: 77 | self.losses[f'{loss_name}_loss'] = seqclr_loss 78 | return seqclr_loss + ce_loss 79 | -------------------------------------------------------------------------------- /semimtr/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/semimtr-text-recognition/043d65b3caac416a65ccd10ecd965ce3bbaa62ad/semimtr/modules/__init__.py -------------------------------------------------------------------------------- /semimtr/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .transformer import PositionalEncoding 4 | 5 | class Attention(nn.Module): 6 | def __init__(self, in_channels=512, max_length=25, n_feature=256): 7 | super().__init__() 8 | self.max_length = max_length 9 | 10 | self.f0_embedding = nn.Embedding(max_length, in_channels) 11 | self.w0 = nn.Linear(max_length, n_feature) 12 | self.wv = nn.Linear(in_channels, in_channels) 13 | self.we = nn.Linear(in_channels, max_length) 14 | 15 | self.active = nn.Tanh() 16 | self.softmax = nn.Softmax(dim=2) 17 | 18 | def forward(self, enc_output): 19 | enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) 20 | reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) 21 | reading_order = 
reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) 22 | reading_order_embed = self.f0_embedding(reading_order) # b,25,512 23 | 24 | t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 25 | t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 26 | 27 | attn = self.we(t) # b,256,25 28 | attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 29 | g_output = torch.bmm(attn, enc_output) # b,25,512 30 | return g_output, attn.view(*attn.shape[:2], 8, 32) 31 | 32 | 33 | def encoder_layer(in_c, out_c, k=3, s=2, p=1): 34 | return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), 35 | nn.BatchNorm2d(out_c), 36 | nn.ReLU(True)) 37 | 38 | def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): 39 | align_corners = None if mode=='nearest' else True 40 | return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, 41 | mode=mode, align_corners=align_corners), 42 | nn.Conv2d(in_c, out_c, k, s, p), 43 | nn.BatchNorm2d(out_c), 44 | nn.ReLU(True)) 45 | 46 | 47 | class PositionAttention(nn.Module): 48 | def __init__(self, max_length, in_channels=512, num_channels=64, 49 | h=8, w=32, mode='nearest', **kwargs): 50 | super().__init__() 51 | self.max_length = max_length 52 | self.k_encoder = nn.Sequential( 53 | encoder_layer(in_channels, num_channels, s=(1, 2)), 54 | encoder_layer(num_channels, num_channels, s=(2, 2)), 55 | encoder_layer(num_channels, num_channels, s=(2, 2)), 56 | encoder_layer(num_channels, num_channels, s=(2, 2)) 57 | ) 58 | self.k_decoder = nn.Sequential( 59 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 60 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 61 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 62 | decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) 63 | ) 64 | 65 | self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length) 66 | self.project = nn.Linear(in_channels, in_channels) 67 | 68 | def forward(self, x): 69 | N, E, H, W = x.size() 70 | k, v = x, x # (N, E, H, W) 71 | 72 | # calculate key vector 73 | features = [] 74 | for i in range(0, len(self.k_encoder)): 75 | k = self.k_encoder[i](k) 76 | features.append(k) 77 | for i in range(0, len(self.k_decoder) - 1): 78 | k = self.k_decoder[i](k) 79 | k = k + features[len(self.k_decoder) - 2 - i] 80 | k = self.k_decoder[-1](k) 81 | 82 | # calculate query vector 83 | # TODO q=f(q,k) 84 | zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) 85 | q = self.pos_encoder(zeros) # (T, N, E) 86 | q = q.permute(1, 0, 2) # (N, T, E) 87 | q = self.project(q) # (N, T, E) 88 | 89 | # calculate attention 90 | attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) 91 | attn_scores = attn_scores / (E ** 0.5) 92 | attn_scores = torch.softmax(attn_scores, dim=-1) 93 | 94 | v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) 95 | attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) 96 | 97 | return attn_vecs, attn_scores.view(N, -1, H, W) 98 | -------------------------------------------------------------------------------- /semimtr/modules/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from semimtr.modules.model import _default_tfmer_cfg 4 | from semimtr.modules.resnet import resnet45 5 | from semimtr.modules.transformer import (PositionalEncoding, 6 | TransformerEncoder, 7 | TransformerEncoderLayer) 8 | from semimtr.utils.utils import if_none 9 | 10 | 11 | class 
ResTranformer(nn.Module): 12 | def __init__(self, config): 13 | super().__init__() 14 | self.resnet = resnet45() 15 | 16 | self.d_model = if_none(config.model_vision_d_model, _default_tfmer_cfg['d_model']) 17 | nhead = if_none(config.model_vision_nhead, _default_tfmer_cfg['nhead']) 18 | d_inner = if_none(config.model_vision_d_inner, _default_tfmer_cfg['d_inner']) 19 | dropout = if_none(config.model_vision_dropout, _default_tfmer_cfg['dropout']) 20 | activation = if_none(config.model_vision_activation, _default_tfmer_cfg['activation']) 21 | num_layers = if_none(config.model_vision_backbone_ln, 2) 22 | 23 | self.pos_encoder = PositionalEncoding(self.d_model, max_len=8 * 32) 24 | encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, 25 | dim_feedforward=d_inner, dropout=dropout, activation=activation) 26 | self.transformer = TransformerEncoder(encoder_layer, num_layers) 27 | 28 | def forward(self, images, *args): 29 | feature = self.resnet(images) 30 | n, c, h, w = feature.shape 31 | feature = feature.view(n, c, -1).permute(2, 0, 1) 32 | feature = self.pos_encoder(feature) 33 | feature = self.transformer(feature) 34 | feature = feature.permute(1, 2, 0).view(n, c, h, w) 35 | return feature 36 | -------------------------------------------------------------------------------- /semimtr/modules/model.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | 6 | from semimtr.utils.utils import CharsetMapper 7 | 8 | _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024 9 | dropout=0.1, activation='relu') 10 | 11 | 12 | class Model(nn.Module): 13 | 14 | def __init__(self, config): 15 | super().__init__() 16 | self.max_length = config.dataset_max_length + 1 17 | self.charset = CharsetMapper(config.dataset_charset_path, max_length=self.max_length) 18 | 19 | def load(self, source, device=None, strict=True, submodule=None, exclude=None): 20 | state = torch.load(source, map_location=device) 21 | if source.endswith('.ckpt'): 22 | model_dict = state['state_dict'] 23 | if list(model_dict.keys())[0].startswith('model.'): 24 | model_dict = collections.OrderedDict( 25 | {k[6:]: v for k, v in model_dict.items() if k.startswith('model.')}) 26 | else: 27 | model_dict = state 28 | if 'model' in model_dict: 29 | model_dict = model_dict['model'] 30 | 31 | if submodule is None: 32 | stat = self.load_state_dict(model_dict, strict=strict) 33 | else: 34 | submodule_dict = collections.OrderedDict( 35 | {k.split('.', 1)[1]: v for k, v in model_dict.items() 36 | if k.split('.', 1)[0] == submodule and k.split('.')[1] != exclude} 37 | ) 38 | stat = self.load_state_dict(submodule_dict, strict=strict and exclude is None) 39 | if stat.missing_keys or stat.unexpected_keys: 40 | logging.warning(f'Loading model with missing keys: {stat.missing_keys}' 41 | f' and unexpected keys: {stat.unexpected_keys}') 42 | 43 | def _get_length(self, logit, dim=-1): 44 | """ Greed decoder to obtain length from logit""" 45 | out = (logit.argmax(dim=-1) == self.charset.null_label) 46 | abn = out.any(dim) 47 | out = ((out.cumsum(dim) == 1) & out).max(dim)[1] 48 | out = out + 1 # additional end token 49 | out = torch.where(abn, out, out.new_tensor(logit.shape[1])) 50 | return out 51 | 52 | @staticmethod 53 | def _get_padding_mask(length, max_length): 54 | length = length.unsqueeze(-1) 55 | grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) 56 | return grid >= length 57 | 58 | 
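    # Illustrative example for _get_padding_mask (made-up values): with length=3 and
    # max_length=5 the grid is [0, 1, 2, 3, 4], so the returned mask is
    # [False, False, False, True, True] -- True marks the positions beyond the
    # sequence length.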
@staticmethod 59 | def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True): 60 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 61 | Unmasked positions are filled with float(0.0). 62 | """ 63 | mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1) 64 | if fw: mask = mask.transpose(0, 1) 65 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 66 | return mask 67 | 68 | @staticmethod 69 | def _get_location_mask(sz, device=None): 70 | mask = torch.eye(sz, device=device) 71 | mask = mask.float().masked_fill(mask == 1, float('-inf')) 72 | return mask 73 | -------------------------------------------------------------------------------- /semimtr/modules/model_abinet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from semimtr.modules.model_vision import BaseVision 5 | from semimtr.modules.model_language import BCNLanguage 6 | from semimtr.modules.model_alignment import BaseAlignment 7 | from semimtr.utils.utils import if_none 8 | 9 | 10 | class ABINetModel(nn.Module): 11 | def __init__(self, config): 12 | super().__init__() 13 | self.use_alignment = if_none(config.model_use_alignment, True) 14 | self.max_length = config.dataset_max_length + 1 # additional stop token 15 | # self.vision_no_grad = if_none(config.model_vision_no_grad, False) 16 | self.vision = BaseVision(config) 17 | self.language = BCNLanguage(config) 18 | if self.use_alignment: self.alignment = BaseAlignment(config) 19 | 20 | def forward(self, images, *args, **kwargs): 21 | v_res = self.vision(images, *args, **kwargs) 22 | v_tokens = torch.softmax(v_res['logits'], dim=-1) 23 | v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO:move to langauge model 24 | 25 | samples = {'label': v_tokens, 'length': v_lengths} 26 | l_res = self.language(samples, *args, **kwargs) 27 | if not self.use_alignment: 28 | return l_res, v_res 29 | l_feature, v_feature = l_res['feature'], v_res['feature'] 30 | 31 | a_res = self.alignment(l_feature, v_feature, *args, **kwargs) 32 | return a_res, l_res, v_res 33 | -------------------------------------------------------------------------------- /semimtr/modules/model_abinet_iter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from semimtr.modules.model_vision import BaseVision 5 | from semimtr.modules.model_language import BCNLanguage 6 | from semimtr.modules.model_alignment import BaseAlignment 7 | from semimtr.utils.utils import if_none 8 | 9 | 10 | class ABINetIterModel(nn.Module): 11 | def __init__(self, config): 12 | super().__init__() 13 | self.iter_size = if_none(config.model_iter_size, 1) 14 | self.max_length = config.dataset_max_length + 1 # additional stop token 15 | self.vision = BaseVision(config) 16 | self.language = BCNLanguage(config) 17 | self.alignment = BaseAlignment(config) 18 | 19 | def forward(self, images, *args, **kwargs): 20 | v_res = self.vision(images, *args, **kwargs) 21 | a_res = v_res 22 | all_l_res, all_a_res = [], [] 23 | for _ in range(self.iter_size): 24 | tokens = torch.softmax(a_res['logits'], dim=-1) 25 | lengths = a_res['pt_lengths'] 26 | lengths.clamp_(2, self.max_length) # TODO:move to langauge model 27 | samples = {'label': tokens, 'length': lengths} 28 | l_res = self.language(samples, *args, **kwargs) 29 | all_l_res.append(l_res) 30 | a_res = 
self.alignment(l_res['feature'], v_res['feature'], *args, **kwargs) 31 | all_a_res.append(a_res) 32 | if self.training: 33 | return all_a_res, all_l_res, v_res 34 | else: 35 | return a_res, all_l_res[-1], v_res 36 | -------------------------------------------------------------------------------- /semimtr/modules/model_alignment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from semimtr.modules.model import Model, _default_tfmer_cfg 5 | from semimtr.utils.utils import if_none 6 | 7 | 8 | class BaseAlignment(Model): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | d_model = if_none(config.model_alignment_d_model, _default_tfmer_cfg['d_model']) 12 | 13 | self.loss_weight = if_none(config.model_alignment_loss_weight, 1.0) 14 | self.max_length = config.dataset_max_length + 1 # additional stop token 15 | self.w_att = nn.Linear(2 * d_model, d_model) 16 | self.cls = nn.Linear(d_model, self.charset.num_classes) 17 | 18 | def forward(self, l_feature, v_feature, *args, **kwargs): 19 | """ 20 | Args: 21 | l_feature: (N, T, E) where T is length, N is batch size and d is dim of model 22 | v_feature: (N, T, E) shape the same as l_feature 23 | l_lengths: (N,) 24 | v_lengths: (N,) 25 | """ 26 | f = torch.cat((l_feature, v_feature), dim=2) 27 | f_att = torch.sigmoid(self.w_att(f)) 28 | output = f_att * v_feature + (1 - f_att) * l_feature 29 | 30 | logits = self.cls(output) # (N, T, C) 31 | pt_lengths = self._get_length(logits) 32 | 33 | return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight, 34 | 'alignment_feature': output, 'name': 'alignment'} 35 | -------------------------------------------------------------------------------- /semimtr/modules/model_fusion_consistency_regularization.py: -------------------------------------------------------------------------------- 1 | from semimtr.modules.model_abinet_iter import ABINetIterModel 2 | from semimtr.utils.utils import if_none 3 | 4 | 5 | class ConsistencyRegularizationFusionModel(ABINetIterModel): 6 | def __init__(self, config): 7 | super().__init__(config) 8 | self.loss_weight = if_none(config.model_teacher_student_loss_weight, 1.0) 9 | 10 | def forward(self, images, *args, forward_only_teacher=False, **kwargs): 11 | if forward_only_teacher: 12 | a_res_teacher, l_res_teacher, v_res_teacher = super().forward(images, *args, **kwargs) 13 | a_res_student, l_res_student, v_res_student = 0, 0, 0 14 | else: 15 | images_teacher_view, images_student_view = images[:, 0], images[:, 1] 16 | a_res_teacher, l_res_teacher, v_res_teacher = super().forward(images_teacher_view, *args, **kwargs) 17 | a_res_student, l_res_student, v_res_student = super().forward(images_student_view, *args, **kwargs) 18 | 19 | return {'teacher_outputs': [a_res_teacher, l_res_teacher, v_res_teacher], 20 | 'student_outputs': [a_res_student, l_res_student, v_res_student], 21 | 'loss_weight': self.loss_weight, 22 | 'name': 'teacher_student_fusion'} 23 | -------------------------------------------------------------------------------- /semimtr/modules/model_fusion_teacher_student_ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from semimtr.modules.model_abinet_iter import ABINetIterModel 5 | from semimtr.utils.utils import if_none 6 | 7 | 8 | class TeacherStudentFusionEMA(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.student = 
ABINetIterModel(config) 12 | self.teacher = ABINetIterModel(config) 13 | self.teacher.load_state_dict(self.student.state_dict()) 14 | self.loss_weight = if_none(config.model_teacher_student_loss_weight, 1.0) 15 | self.decay = if_none(config.model_teacher_student_ema_decay, 0.9999) 16 | 17 | def update_teacher(self): 18 | with torch.no_grad(): 19 | for param_student, param_teacher in zip(self.student.parameters(), self.teacher.parameters()): 20 | param_teacher.data.mul_(self.decay).add_((1 - self.decay) * param_student.detach().data) 21 | 22 | def forward(self, images, *args, **kwargs): 23 | with torch.no_grad(): 24 | a_res_teacher, l_res_teacher, v_res_teacher = self.teacher(images[:, 0], *args, **kwargs) 25 | a_res_student, l_res_student, v_res_student = self.student(images[:, 1], *args, **kwargs) 26 | 27 | return {'teacher_outputs': [a_res_teacher, l_res_teacher, v_res_teacher], 28 | 'student_outputs': [a_res_student, l_res_student, v_res_student], 29 | 'loss_weight': self.loss_weight, 30 | 'name': 'teacher_student_fusion'} 31 | -------------------------------------------------------------------------------- /semimtr/modules/model_language.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from semimtr.modules.model import _default_tfmer_cfg 5 | from semimtr.modules.model import Model 6 | from semimtr.modules.transformer import (PositionalEncoding, 7 | TransformerDecoder, 8 | TransformerDecoderLayer) 9 | from semimtr.utils.utils import if_none 10 | 11 | 12 | class BCNLanguage(Model): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | d_model = if_none(config.model_language_d_model, _default_tfmer_cfg['d_model']) 16 | nhead = if_none(config.model_language_nhead, _default_tfmer_cfg['nhead']) 17 | d_inner = if_none(config.model_language_d_inner, _default_tfmer_cfg['d_inner']) 18 | dropout = if_none(config.model_language_dropout, _default_tfmer_cfg['dropout']) 19 | activation = if_none(config.model_language_activation, _default_tfmer_cfg['activation']) 20 | num_layers = if_none(config.model_language_num_layers, 4) 21 | self.d_model = d_model 22 | self.detach = if_none(config.model_language_detach, True) 23 | self.use_self_attn = if_none(config.model_language_use_self_attn, False) 24 | self.loss_weight = if_none(config.model_language_loss_weight, 1.0) 25 | self.max_length = config.dataset_max_length + 1 # additional stop token 26 | self.debug = if_none(config.global_debug, False) 27 | 28 | self.proj = nn.Linear(self.charset.num_classes, d_model, False) 29 | self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length) 30 | self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, 32 | activation, self_attn=self.use_self_attn, debug=self.debug) 33 | self.model = TransformerDecoder(decoder_layer, num_layers) 34 | 35 | self.cls = nn.Linear(d_model, self.charset.num_classes) 36 | 37 | if config.model_language_checkpoint is not None: 38 | logging.info(f'Read language model from {config.model_language_checkpoint}.') 39 | self.load(config.model_language_checkpoint) 40 | 41 | def forward(self, samples, *args, **kwargs): 42 | """ 43 | Args: 44 | samples: dict 45 | tokens: (N, T, C) where T is length, N is batch size and C is classes number 46 | lengths: (N,) 47 | """ 48 | tokens, lengths = samples['label'], samples['length'] 49 | if self.detach: tokens = tokens.detach() 50 | embed = 
self.proj(tokens) # (N, T, E) 51 | embed = embed.permute(1, 0, 2) # (T, N, E) 52 | embed = self.token_encoder(embed) # (T, N, E) 53 | padding_mask = self._get_padding_mask(lengths, self.max_length) 54 | 55 | zeros = embed.new_zeros(*embed.shape) 56 | qeury = self.pos_encoder(zeros) 57 | location_mask = self._get_location_mask(self.max_length, tokens.device) 58 | output = self.model(qeury, embed, 59 | tgt_key_padding_mask=padding_mask, 60 | memory_mask=location_mask, 61 | memory_key_padding_mask=padding_mask) # (T, N, E) 62 | output = output.permute(1, 0, 2) # (N, T, E) 63 | 64 | logits = self.cls(output) # (N, T, C) 65 | pt_lengths = self._get_length(logits) 66 | 67 | res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, 68 | 'loss_weight': self.loss_weight, 'name': 'language'} 69 | return res 70 | -------------------------------------------------------------------------------- /semimtr/modules/model_seqclr_vision.py: -------------------------------------------------------------------------------- 1 | from semimtr.modules.model_vision import BaseVision 2 | from semimtr.modules.model import Model 3 | from semimtr.modules.seqclr_proj import SeqCLRProj 4 | from semimtr.utils.utils import if_none 5 | 6 | 7 | class SeqCLRModel(Model): 8 | def __init__(self, config): 9 | super().__init__(config) 10 | self.vision = BaseVision(config) 11 | self.seqclr_proj = SeqCLRProj(config) 12 | self.loss_weight = if_none(config.model_contrastive_loss_weight, 1.0) 13 | 14 | def forward(self, images, *args, **kwargs): 15 | v_res_view0 = self.vision(images[:, 0], *args, **kwargs) 16 | v_res_view1 = self.vision(images[:, 1], *args, **kwargs) 17 | 18 | projected_features_view0 = self.seqclr_proj(v_res_view0)[0] 19 | projected_features_view1 = self.seqclr_proj(v_res_view1)[0] 20 | 21 | return {'supervised_outputs_view0': v_res_view0, 22 | 'supervised_outputs_view1': v_res_view1, 23 | 'instances_view0': projected_features_view0['instances'], 24 | 'instances_view1': projected_features_view1['instances'], 25 | 'loss_weight': self.loss_weight, 26 | 'name': 'seqclr_vision'} 27 | -------------------------------------------------------------------------------- /semimtr/modules/model_vision.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from semimtr.modules.attention import PositionAttention, Attention 5 | from semimtr.modules.backbone import ResTranformer 6 | from semimtr.modules.model import Model 7 | from semimtr.modules.resnet import resnet45 8 | from semimtr.utils.utils import if_none 9 | 10 | 11 | class BaseVision(Model): 12 | def __init__(self, config): 13 | super().__init__(config) 14 | self.loss_weight = if_none(config.model_vision_loss_weight, 1.0) 15 | self.out_channels = if_none(config.model_vision_d_model, 512) 16 | 17 | if config.model_vision_backbone == 'transformer': 18 | self.backbone = ResTranformer(config) 19 | else: 20 | self.backbone = resnet45() 21 | 22 | if config.model_vision_attention == 'position': 23 | mode = if_none(config.model_vision_attention_mode, 'nearest') 24 | self.attention = PositionAttention( 25 | max_length=config.dataset_max_length + 1, # additional stop token 26 | mode=mode, 27 | ) 28 | elif config.model_vision_attention == 'attention': 29 | self.attention = Attention( 30 | max_length=config.dataset_max_length + 1, # additional stop token 31 | n_feature=8 * 32, 32 | ) 33 | else: 34 | raise NotImplementedError(f'{config.model_vision_attention} is not valid.') 35 | self.cls = 
nn.Linear(self.out_channels, self.charset.num_classes) 36 | 37 | if config.model_vision_checkpoint is not None: 38 | logging.info(f'Read vision model from {config.model_vision_checkpoint}.') 39 | self.load(config.model_vision_checkpoint, submodule=config.model_vision_checkpoint_submodule, 40 | exclude=config.model_vision_exclude) 41 | 42 | def forward(self, images, *args, **kwargs): 43 | features = self.backbone(images) # (N, E, H, W) 44 | attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) 45 | logits = self.cls(attn_vecs) # (N, T, C) 46 | pt_lengths = self._get_length(logits) 47 | 48 | return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, 49 | 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision', 50 | 'backbone_feature': features} 51 | -------------------------------------------------------------------------------- /semimtr/modules/projections.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BidirectionalLSTM(nn.Module): 6 | 7 | def __init__(self, input_size, hidden_size, output_size): 8 | super().__init__() 9 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 10 | self.linear = nn.Linear(hidden_size * 2, output_size) 11 | 12 | def forward(self, input): 13 | """ 14 | input : visual feature [batch_size x T x input_size] 15 | output : contextual feature [batch_size x T x output_size] 16 | """ 17 | self.rnn.flatten_parameters() 18 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 19 | output = self.linear(recurrent) # batch_size x T x output_size 20 | return output 21 | 22 | 23 | class AttnLinear(nn.Module): 24 | 25 | def __init__(self, input_size, hidden_size, output_size): 26 | super().__init__() 27 | self.w_att = nn.Linear(input_size, hidden_size) 28 | self.cls = nn.Linear(hidden_size, output_size) 29 | 30 | def forward(self, features): 31 | f_att = torch.sigmoid(self.w_att(features)) 32 | v_feature, l_feature = torch.chunk(features, 2, dim=-1) 33 | output = f_att * v_feature + (1 - f_att) * l_feature 34 | return self.cls(output) # (N, T, C) 35 | -------------------------------------------------------------------------------- /semimtr/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def conv1x1(in_planes, out_planes, stride=1): 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | "3x3 convolution with padding" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv1x1(inplanes, planes) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(planes, planes, stride) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | 
out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class ResNet(nn.Module): 49 | 50 | def __init__(self, block, layers): 51 | self.inplanes = 32 52 | super(ResNet, self).__init__() 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, 54 | bias=False) 55 | self.bn1 = nn.BatchNorm2d(32) 56 | self.relu = nn.ReLU(inplace=True) 57 | 58 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 59 | self.layer2 = self._make_layer(block, 64, layers[1], stride=1) 60 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 61 | self.layer4 = self._make_layer(block, 256, layers[3], stride=1) 62 | self.layer5 = self._make_layer(block, 512, layers[4], stride=1) 63 | 64 | for m in self.modules(): 65 | if isinstance(m, nn.Conv2d): 66 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 67 | m.weight.data.normal_(0, math.sqrt(2. / n)) 68 | elif isinstance(m, nn.BatchNorm2d): 69 | m.weight.data.fill_(1) 70 | m.bias.data.zero_() 71 | 72 | def _make_layer(self, block, planes, blocks, stride=1): 73 | downsample = None 74 | if stride != 1 or self.inplanes != planes * block.expansion: 75 | downsample = nn.Sequential( 76 | nn.Conv2d(self.inplanes, planes * block.expansion, 77 | kernel_size=1, stride=stride, bias=False), 78 | nn.BatchNorm2d(planes * block.expansion), 79 | ) 80 | 81 | layers = [] 82 | layers.append(block(self.inplanes, planes, stride, downsample)) 83 | self.inplanes = planes * block.expansion 84 | for i in range(1, blocks): 85 | layers.append(block(self.inplanes, planes)) 86 | 87 | return nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | x = self.conv1(x) 91 | x = self.bn1(x) 92 | x = self.relu(x) 93 | x = self.layer1(x) 94 | x = self.layer2(x) 95 | x = self.layer3(x) 96 | x = self.layer4(x) 97 | x = self.layer5(x) 98 | return x 99 | 100 | 101 | def resnet45(): 102 | return ResNet(BasicBlock, [3, 4, 6, 6, 3]) 103 | -------------------------------------------------------------------------------- /semimtr/modules/seqclr_proj.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from semimtr.modules.model import Model, _default_tfmer_cfg 5 | from semimtr.modules.projections import BidirectionalLSTM, AttnLinear 6 | from semimtr.utils.utils import if_none 7 | 8 | 9 | class SeqCLRProj(Model): 10 | def __init__(self, config): 11 | super().__init__(config) 12 | vision_d_model = if_none(config.model_vision_d_model, _default_tfmer_cfg['d_model']) 13 | self.working_layer = if_none(config.model_proj_layer, 'feature') 14 | if self.working_layer in ['feature', 'backbone_feature', 'alignment_feature']: 15 | projection_input_size = vision_d_model 16 | else: 17 | raise NotImplementedError(f'SeqCLR projection head does not support working layer of {self.working_layer}.') 18 | 19 | if config.model_proj_scheme is None: 20 | self.projection = nn.Identity() 21 | projection_output_size = projection_input_size 22 | elif config.model_proj_scheme == 'bilstm': 23 | projection_hidden_size = if_none(config.model_proj_hidden, projection_input_size) 24 | projection_output_size = if_none(config.model_proj_output, projection_input_size) 25 | self.projection = BidirectionalLSTM(projection_input_size, 26 | projection_hidden_size, 27 | projection_output_size) 28 | elif config.model_proj_scheme == 'linear_per_column': 29 | projection_output_size = if_none(config.model_proj_output, projection_input_size) 30 | self.projection = nn.Linear(projection_input_size, projection_output_size) 31 | elif 
config.model_proj_scheme == 'attn_linear_per_column': 32 | projection_hidden_size = if_none(config.model_proj_hidden, projection_input_size // 2) 33 | projection_output_size = if_none(config.model_proj_output, self.charset.num_classes) 34 | self.projection = AttnLinear(projection_input_size, 35 | projection_hidden_size, 36 | projection_output_size) 37 | else: 38 | raise NotImplementedError(f'The projection scheme of {config.model_proj_scheme} is not supported.') 39 | 40 | if config.model_instance_mapping_frame_to_instance: 41 | self.instance_mapping_func = nn.Identity() 42 | else: 43 | instance_mapping_fixed = if_none(config.model_instance_mapping_fixed, 'instances') 44 | w = if_none(config.model_instance_mapping_w, 5) 45 | if instance_mapping_fixed == 'instances': 46 | self.instance_mapping_func = nn.AdaptiveAvgPool2d((w, projection_output_size)) 47 | elif instance_mapping_fixed == 'frames': 48 | self.instance_mapping_func = AvgPool(kernel_size=w, stride=w) 49 | else: 50 | raise NotImplementedError(f'instance_mapping_fixed of {instance_mapping_fixed} is not supported') 51 | 52 | if config.model_proj_checkpoint is not None: 53 | logging.info(f'Read projection head model from {config.model_proj_checkpoint}.') 54 | self.load(config.model_proj_checkpoint) 55 | 56 | def _single_forward(self, output): 57 | features = output[self.working_layer] 58 | if self.working_layer == 'backbone_feature': 59 | features = features.permute(0, 2, 3, 1).flatten(1, 2) # (N, E, H, W) -> (N, H*W, E) 60 | projected_features = self.projection(features) 61 | projected_instances = self.instance_mapping_func(projected_features) 62 | return {'instances': projected_instances, 'name': 'projection_head'} 63 | 64 | def forward(self, output, *args): 65 | if isinstance(output, (tuple, list)): 66 | return [self._single_forward(o) for o in output] 67 | else: 68 | return [self._single_forward(output)] 69 | 70 | 71 | class AvgPool(nn.Module): 72 | def __init__(self, kernel_size, stride): 73 | super().__init__() 74 | self.avg_pool = nn.AvgPool1d(kernel_size=kernel_size, stride=stride) 75 | 76 | def forward(self, x): 77 | return self.avg_pool(x.permute(0, 2, 1)).permute(0, 2, 1).contiguous() 78 | -------------------------------------------------------------------------------- /semimtr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/semimtr-text-recognition/043d65b3caac416a65ccd10ecd965ce3bbaa62ad/semimtr/utils/__init__.py -------------------------------------------------------------------------------- /semimtr/utils/test.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import ConcatDataset 2 | from fastai.vision import * 3 | from semimtr.callbacks.callbacks import IterationCallback 4 | 5 | 6 | def test_on_each_ds(learner): 7 | test_dls = dataset_to_dataloader_list( 8 | learner.data.test_ds, 9 | batch_size=learner.data.test_dl.batch_size, 10 | device=learner.data.device, 11 | transforms=learner.data.test_dl.tfms, 12 | collate_fn=learner.data.test_dl.collate_fn 13 | ) 14 | last_metrics_list = [] 15 | ds_sizes = [] 16 | loss_dict = {} 17 | for dl in test_dls: 18 | dl_name = dl.dataset.name 19 | last_metrics = learner.validate(dl=dl) 20 | last_metrics_list.append(last_metrics) 21 | IterationCallback._metrics_to_logging(last_metrics, f'{dl_name} test', dl_len=len(dl.dataset)) 22 | ds_sizes.append(len(dl.dataset)) 23 | loss_dict[dl_name] = [ds_sizes[-1]] + last_metrics 24 | 25 
| last_metrics_average = np.average(last_metrics_list, axis=0, weights=ds_sizes) 26 | names = IterationCallback._metrics_to_logging(last_metrics_average, f'average test') 27 | loss_dict['Average'] = [sum(ds_sizes)] + list(last_metrics_average) 28 | df = pd.DataFrame.from_dict(loss_dict, orient='index', columns=['size'] + names) 29 | df.T.to_csv(learner.path / learner.model_dir / f'test_results.csv') 30 | 31 | 32 | def dataset_to_dataloader_list(dataset, batch_size, device, transforms, collate_fn): 33 | if isinstance(dataset, ConcatDataset): 34 | test_dls = [] 35 | for ds in dataset.datasets: 36 | test_dls.extend(dataset_to_dataloader_list(ds, batch_size, device, transforms, collate_fn)) 37 | return test_dls 38 | else: 39 | return [DeviceDataLoader(DataLoader(dataset, batch_size), device, transforms, collate_fn)] 40 | -------------------------------------------------------------------------------- /semimtr/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision import transforms 9 | from torchvision.transforms import Compose 10 | 11 | 12 | def sample_asym(magnitude, size=None): 13 | return np.random.beta(1, 4, size) * magnitude 14 | 15 | 16 | def sample_sym(magnitude, size=None): 17 | return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude 18 | 19 | 20 | def sample_uniform(low, high, size=None): 21 | return np.random.uniform(low, high, size=size) 22 | 23 | 24 | def get_interpolation(type='random'): 25 | if type == 'random': 26 | choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA] 27 | interpolation = choice[random.randint(0, len(choice) - 1)] 28 | elif type == 'nearest': 29 | interpolation = cv2.INTER_NEAREST 30 | elif type == 'linear': 31 | interpolation = cv2.INTER_LINEAR 32 | elif type == 'cubic': 33 | interpolation = cv2.INTER_CUBIC 34 | elif type == 'area': 35 | interpolation = cv2.INTER_AREA 36 | else: 37 | raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!') 38 | return interpolation 39 | 40 | 41 | class CVRandomRotation(object): 42 | def __init__(self, degrees=15): 43 | assert isinstance(degrees, numbers.Number), "degree should be a single number." 44 | assert degrees >= 0, "degree must be positive." 45 | self.degrees = degrees 46 | 47 | @staticmethod 48 | def get_params(degrees): 49 | return sample_sym(degrees) 50 | 51 | def __call__(self, img): 52 | angle = self.get_params(self.degrees) 53 | src_h, src_w = img.shape[:2] 54 | M = cv2.getRotationMatrix2D(center=(src_w / 2, src_h / 2), angle=angle, scale=1.0) 55 | abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1]) 56 | dst_w = int(src_h * abs_sin + src_w * abs_cos) 57 | dst_h = int(src_h * abs_cos + src_w * abs_sin) 58 | M[0, 2] += (dst_w - src_w) / 2 59 | M[1, 2] += (dst_h - src_h) / 2 60 | 61 | flags = get_interpolation() 62 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) 63 | 64 | 65 | class CVRandomAffine(object): 66 | def __init__(self, degrees, translate=None, scale=None, shear=None): 67 | assert isinstance(degrees, numbers.Number), "degree should be a single number." 68 | assert degrees >= 0, "degree must be positive." 69 | self.degrees = degrees 70 | 71 | if translate is not None: 72 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 73 | "translate should be a list or tuple and it must be of length 2." 
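# Note: translate values are fractions in [0, 1]; get_params below scales both
# offsets by the image height, so e.g. translate=(0.3, 0.3) allows shifts of up
# to 0.3 * height pixels horizontally and vertically.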
74 | for t in translate: 75 | if not (0.0 <= t <= 1.0): 76 | raise ValueError("translation values should be between 0 and 1") 77 | self.translate = translate 78 | 79 | if scale is not None: 80 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 81 | "scale should be a list or tuple and it must be of length 2." 82 | for s in scale: 83 | if s <= 0: 84 | raise ValueError("scale values should be positive") 85 | self.scale = scale 86 | 87 | if shear is not None: 88 | if isinstance(shear, numbers.Number): 89 | if shear < 0: 90 | raise ValueError("If shear is a single number, it must be positive.") 91 | self.shear = [shear] 92 | else: 93 | assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \ 94 | "shear should be a list or tuple and it must be of length 2." 95 | self.shear = shear 96 | else: 97 | self.shear = shear 98 | 99 | def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear): 100 | # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717 101 | from numpy import sin, cos, tan 102 | 103 | if isinstance(shear, numbers.Number): 104 | shear = [shear, 0] 105 | 106 | if not isinstance(shear, (tuple, list)) and len(shear) == 2: 107 | raise ValueError( 108 | "Shear should be a single value or a tuple/list containing " + 109 | "two values. Got {}".format(shear)) 110 | 111 | rot = math.radians(angle) 112 | sx, sy = [math.radians(s) for s in shear] 113 | 114 | cx, cy = center 115 | tx, ty = translate 116 | 117 | # RSS without scaling 118 | a = cos(rot - sy) / cos(sy) 119 | b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) 120 | c = sin(rot - sy) / cos(sy) 121 | d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) 122 | 123 | # Inverted rotation matrix with scale and shear 124 | # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 125 | M = [d, -b, 0, 126 | -c, a, 0] 127 | M = [x / scale for x in M] 128 | 129 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 130 | M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) 131 | M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) 132 | 133 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 134 | M[2] += cx 135 | M[5] += cy 136 | return M 137 | 138 | @staticmethod 139 | def get_params(degrees, translate, scale_ranges, shears, height): 140 | angle = sample_sym(degrees) 141 | if translate is not None: 142 | max_dx = translate[0] * height 143 | max_dy = translate[1] * height 144 | translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy))) 145 | else: 146 | translations = (0, 0) 147 | 148 | if scale_ranges is not None: 149 | scale = sample_uniform(scale_ranges[0], scale_ranges[1]) 150 | else: 151 | scale = 1.0 152 | 153 | if shears is not None: 154 | if len(shears) == 1: 155 | shear = [sample_sym(shears[0]), 0.] 
156 | elif len(shears) == 2: 157 | shear = [sample_sym(shears[0]), sample_sym(shears[1])] 158 | else: 159 | shear = 0.0 160 | 161 | return angle, translations, scale, shear 162 | 163 | def __call__(self, img): 164 | src_h, src_w = img.shape[:2] 165 | angle, translate, scale, shear = self.get_params( 166 | self.degrees, self.translate, self.scale, self.shear, src_h) 167 | 168 | M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle, (0, 0), scale, shear) 169 | M = np.array(M).reshape(2, 3) 170 | 171 | startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)] 172 | project = lambda x, y, a, b, c: int(a * x + b * y + c) 173 | endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints] 174 | 175 | rect = cv2.minAreaRect(np.array(endpoints)) 176 | bbox = cv2.boxPoints(rect).astype(dtype=np.int) 177 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() 178 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() 179 | 180 | dst_w = int(max_x - min_x) 181 | dst_h = int(max_y - min_y) 182 | M[0, 2] += (dst_w - src_w) / 2 183 | M[1, 2] += (dst_h - src_h) / 2 184 | 185 | # add translate 186 | dst_w += int(abs(translate[0])) 187 | dst_h += int(abs(translate[1])) 188 | if translate[0] < 0: M[0, 2] += abs(translate[0]) 189 | if translate[1] < 0: M[1, 2] += abs(translate[1]) 190 | 191 | flags = get_interpolation() 192 | return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) 193 | 194 | 195 | class CVRandomPerspective(object): 196 | def __init__(self, distortion=0.5): 197 | self.distortion = distortion 198 | 199 | def get_params(self, width, height, distortion): 200 | offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int) 201 | offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int) 202 | topleft = (offset_w[0], offset_h[0]) 203 | topright = (width - 1 - offset_w[1], offset_h[1]) 204 | botright = (width - 1 - offset_w[2], height - 1 - offset_h[2]) 205 | botleft = (offset_w[3], height - 1 - offset_h[3]) 206 | 207 | startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] 208 | endpoints = [topleft, topright, botright, botleft] 209 | return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32) 210 | 211 | def __call__(self, img): 212 | height, width = img.shape[:2] 213 | startpoints, endpoints = self.get_params(width, height, self.distortion) 214 | M = cv2.getPerspectiveTransform(startpoints, endpoints) 215 | 216 | # TODO: more robust way to crop image 217 | rect = cv2.minAreaRect(endpoints) 218 | bbox = cv2.boxPoints(rect).astype(dtype=np.int) 219 | max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() 220 | min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() 221 | min_x, min_y = max(min_x, 0), max(min_y, 0) 222 | 223 | flags = get_interpolation() 224 | img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE) 225 | img = img[min_y:, min_x:] 226 | return img 227 | 228 | 229 | class CVRescale(object): 230 | 231 | def __init__(self, factor=4, base_size=(128, 512)): 232 | """ Define image scales using gaussian pyramid and rescale image to target scale. 233 | 234 | Args: 235 | factor: the decayed factor from base size, factor=4 keeps target scale by default. 
236 | base_size: base size the build the bottom layer of pyramid 237 | """ 238 | if isinstance(factor, numbers.Number): 239 | self.factor = round(sample_uniform(0, factor)) 240 | elif isinstance(factor, (tuple, list)) and len(factor) == 2: 241 | self.factor = round(sample_uniform(factor[0], factor[1])) 242 | else: 243 | raise Exception('factor must be number or list with length 2') 244 | # assert factor is valid 245 | self.base_h, self.base_w = base_size[:2] 246 | 247 | def __call__(self, img): 248 | if self.factor == 0: return img 249 | src_h, src_w = img.shape[:2] 250 | cur_w, cur_h = self.base_w, self.base_h 251 | scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation()) 252 | for _ in range(self.factor): 253 | scale_img = cv2.pyrDown(scale_img) 254 | scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation()) 255 | return scale_img 256 | 257 | 258 | class CVGaussianNoise(object): 259 | def __init__(self, mean=0, var=20): 260 | self.mean = mean 261 | if isinstance(var, numbers.Number): 262 | self.var = max(int(sample_asym(var)), 1) 263 | elif isinstance(var, (tuple, list)) and len(var) == 2: 264 | self.var = int(sample_uniform(var[0], var[1])) 265 | else: 266 | raise Exception('degree must be number or list with length 2') 267 | 268 | def __call__(self, img): 269 | noise = np.random.normal(self.mean, self.var ** 0.5, img.shape) 270 | img = np.clip(img + noise, 0, 255).astype(np.uint8) 271 | return img 272 | 273 | 274 | class CVMotionBlur(object): 275 | def __init__(self, degrees=12, angle=90): 276 | if isinstance(degrees, numbers.Number): 277 | self.degree = max(int(sample_asym(degrees)), 1) 278 | elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: 279 | self.degree = int(sample_uniform(degrees[0], degrees[1])) 280 | else: 281 | raise Exception('degree must be number or list with length 2') 282 | self.angle = sample_uniform(-angle, angle) 283 | 284 | def __call__(self, img): 285 | M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1) 286 | motion_blur_kernel = np.zeros((self.degree, self.degree)) 287 | motion_blur_kernel[self.degree // 2, :] = 1 288 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree)) 289 | motion_blur_kernel = motion_blur_kernel / self.degree 290 | img = cv2.filter2D(img, -1, motion_blur_kernel) 291 | img = np.clip(img, 0, 255).astype(np.uint8) 292 | return img 293 | 294 | 295 | class CVGeometry(object): 296 | def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.), 297 | shear=(45, 15), distortion=0.5, p=0.5): 298 | self.p = p 299 | type_p = random.random() 300 | if type_p < 0.33: 301 | self.transforms = CVRandomRotation(degrees=degrees) 302 | elif type_p < 0.66: 303 | self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) 304 | else: 305 | self.transforms = CVRandomPerspective(distortion=distortion) 306 | 307 | def __call__(self, img): 308 | if random.random() < self.p: 309 | img = np.array(img) 310 | return Image.fromarray(self.transforms(img)) 311 | else: 312 | return img 313 | 314 | 315 | class CVDeterioration(object): 316 | def __init__(self, var, degrees, factor, p=0.5): 317 | self.p = p 318 | transforms = [] 319 | if var is not None: 320 | transforms.append(CVGaussianNoise(var=var)) 321 | if degrees is not None: 322 | transforms.append(CVMotionBlur(degrees=degrees)) 323 | if factor is not None: 324 | transforms.append(CVRescale(factor=factor)) 325 | 326 | random.shuffle(transforms) 327 | 
transforms = Compose(transforms) 328 | self.transforms = transforms 329 | 330 | def __call__(self, img): 331 | if random.random() < self.p: 332 | img = np.array(img) 333 | return Image.fromarray(self.transforms(img)) 334 | else: 335 | return img 336 | 337 | 338 | class CVColorJitter(object): 339 | def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5): 340 | self.p = p 341 | self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast, 342 | saturation=saturation, hue=hue) 343 | 344 | def __call__(self, img): 345 | if random.random() < self.p: 346 | return self.transforms(img) 347 | else: 348 | return img 349 | 350 | 351 | class ImageToArray(object): 352 | 353 | def __call__(self, img): 354 | return np.array(img) 355 | 356 | 357 | class ImageToPIL(object): 358 | 359 | def __call__(self, img): 360 | return Image.fromarray(img) 361 | -------------------------------------------------------------------------------- /semimtr/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import yaml 9 | from matplotlib import colors 10 | from matplotlib import pyplot as plt 11 | from torch import Tensor, nn 12 | from torch.utils.data import ConcatDataset 13 | 14 | 15 | class CharsetMapper(object): 16 | """A simple class to map ids into strings. 17 | 18 | It works only when the character set is 1:1 mapping between individual 19 | characters and individual ids. 20 | """ 21 | 22 | def __init__(self, 23 | filename='', 24 | max_length=30, 25 | null_char=u'\u2591'): 26 | """Creates a lookup table. 27 | 28 | Args: 29 | filename: Path to charset file which maps characters to ids. 30 | max_sequence_length: The max length of ids and string. 31 | null_char: A unicode character used to replace '' character. 32 | the default value is a light shade block '░'. 33 | """ 34 | self.null_char = null_char 35 | self.max_length = max_length 36 | 37 | self.label_to_char = self._read_charset(filename) 38 | self.char_to_label = dict(map(reversed, self.label_to_char.items())) 39 | self.num_classes = len(self.label_to_char) 40 | 41 | def _read_charset(self, filename): 42 | """Reads a charset definition from a tab separated text file. 43 | 44 | Args: 45 | filename: a path to the charset file. 46 | 47 | Returns: 48 | a dictionary with keys equal to character codes and values - unicode 49 | characters. 50 | """ 51 | import re 52 | pattern = re.compile(r'(\d+)\t(.+)') 53 | charset = {} 54 | self.null_label = 0 55 | charset[self.null_label] = self.null_char 56 | with open(filename, 'r') as f: 57 | for i, line in enumerate(f): 58 | m = pattern.match(line) 59 | assert m, f'Incorrect charset file. line #{i}: {line}' 60 | label = int(m.group(1)) + 1 61 | char = m.group(2) 62 | charset[label] = char 63 | return charset 64 | 65 | def trim(self, text): 66 | assert isinstance(text, str) 67 | return text.replace(self.null_char, '') 68 | 69 | def get_text(self, labels, length=None, padding=True, trim=False): 70 | """ Returns a string corresponding to a sequence of character ids. 
71 | """ 72 | length = length if length else self.max_length 73 | labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels] 74 | if padding: 75 | labels = labels + [self.null_label] * (length - len(labels)) 76 | text = ''.join([self.label_to_char[label] for label in labels]) 77 | if trim: text = self.trim(text) 78 | return text 79 | 80 | def get_labels(self, text, length=None, padding=True, case_sensitive=False): 81 | """ Returns the labels of the corresponding text. 82 | """ 83 | length = length if length else self.max_length 84 | if padding: 85 | text = text + self.null_char * (length - len(text)) 86 | if not case_sensitive: 87 | text = text.lower() 88 | labels = [self.char_to_label[char] for char in text] 89 | return labels 90 | 91 | def pad_labels(self, labels, length=None): 92 | length = length if length else self.max_length 93 | 94 | return labels + [self.null_label] * (length - len(labels)) 95 | 96 | @property 97 | def digits(self): 98 | return '0123456789' 99 | 100 | @property 101 | def digit_labels(self): 102 | return self.get_labels(self.digits, padding=False) 103 | 104 | @property 105 | def alphabets(self): 106 | all_chars = list(self.char_to_label.keys()) 107 | valid_chars = [] 108 | for c in all_chars: 109 | if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ': 110 | valid_chars.append(c) 111 | return ''.join(valid_chars) 112 | 113 | @property 114 | def alphabet_labels(self): 115 | return self.get_labels(self.alphabets, padding=False) 116 | 117 | 118 | class Timer(object): 119 | """A simple timer.""" 120 | 121 | def __init__(self): 122 | self.data_time = 0. 123 | self.data_diff = 0. 124 | self.data_total_time = 0. 125 | self.data_call = 0 126 | self.running_time = 0. 127 | self.running_diff = 0. 128 | self.running_total_time = 0. 
129 | self.running_call = 0 130 | 131 | def tic(self): 132 | self.start_time = time.time() 133 | self.running_time = self.start_time 134 | 135 | def toc_data(self): 136 | self.data_time = time.time() 137 | self.data_diff = self.data_time - self.running_time 138 | self.data_total_time += self.data_diff 139 | self.data_call += 1 140 | 141 | def toc_running(self): 142 | self.running_time = time.time() 143 | self.running_diff = self.running_time - self.data_time 144 | self.running_total_time += self.running_diff 145 | self.running_call += 1 146 | 147 | def total_time(self): 148 | return self.data_total_time + self.running_total_time 149 | 150 | def average_time(self): 151 | return self.average_data_time() + self.average_running_time() 152 | 153 | def average_data_time(self): 154 | return self.data_total_time / (self.data_call or 1) 155 | 156 | def average_running_time(self): 157 | return self.running_total_time / (self.running_call or 1) 158 | 159 | 160 | class Logger(object): 161 | _handle = None 162 | _root = None 163 | 164 | @staticmethod 165 | def init(output_dir, name, phase): 166 | format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \ 167 | '%(message)s'.format(name) 168 | logging.basicConfig(level=logging.INFO, format=format) 169 | 170 | try: 171 | os.makedirs(output_dir) 172 | except: 173 | pass 174 | config_path = os.path.join(output_dir, f'{phase}.txt') 175 | Logger._handle = logging.FileHandler(config_path) 176 | Logger._root = logging.getLogger() 177 | 178 | @staticmethod 179 | def enable_file(): 180 | if Logger._handle is None or Logger._root is None: 181 | raise Exception('Invoke Logger.init() first!') 182 | Logger._root.addHandler(Logger._handle) 183 | 184 | @staticmethod 185 | def disable_file(): 186 | if Logger._handle is None or Logger._root is None: 187 | raise Exception('Invoke Logger.init() first!') 188 | Logger._root.removeHandler(Logger._handle) 189 | 190 | 191 | class Config(object): 192 | 193 | def __init__(self, config_path, host=True): 194 | def __dict2attr(d, prefix=''): 195 | for k, v in d.items(): 196 | if isinstance(v, dict): 197 | __dict2attr(v, f'{prefix}{k}_') 198 | else: 199 | if k == 'phase': 200 | assert v in ['train', 'test'] 201 | if k == 'stage': 202 | assert v in ['pretrain-vision', 'pretrain-language', 'pretrain-fusion', 203 | 'train-semi-supervised', 'train-supervised'] 204 | self.__setattr__(f'{prefix}{k}', v) 205 | 206 | assert os.path.exists(config_path), '%s does not exists!' 
% config_path 207 | with open(config_path) as file: 208 | config_dict = yaml.safe_load(file) 209 | with open('configs/template.yaml') as file: 210 | default_config_dict = yaml.safe_load(file) 211 | __dict2attr(default_config_dict) 212 | if 'global' in config_dict.keys() and 'experiment_template' in config_dict['global'].keys(): 213 | with open(f"configs/{config_dict['global']['experiment_template']}") as file: 214 | default_exp_config_dict = yaml.safe_load(file) 215 | __dict2attr(default_exp_config_dict) 216 | __dict2attr(config_dict) 217 | self.global_workdir = os.path.join(self.global_workdir, self.global_name) 218 | 219 | def __getattr__(self, item): 220 | attr = self.__dict__.get(item) 221 | if attr is None: 222 | attr = dict() 223 | prefix = f'{item}_' 224 | for k, v in self.__dict__.items(): 225 | if k.startswith(prefix): 226 | n = k.replace(prefix, '') 227 | attr[n] = v 228 | return attr if len(attr) > 0 else None 229 | else: 230 | return attr 231 | 232 | def __repr__(self): 233 | str = 'ModelConfig(\n' 234 | for i, (k, v) in enumerate(sorted(vars(self).items())): 235 | str += f'\t({i}): {k} = {v}\n' 236 | str += ')' 237 | return str 238 | 239 | 240 | def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0): 241 | # normalize mask 242 | mask = (mask - mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps) 243 | if mask.shape != image.shape: 244 | mask = cv2.resize(mask, (image.shape[1], image.shape[0])) 245 | # get color map 246 | color_map = plt.get_cmap(cmap) 247 | mask = color_map(mask)[:, :, :3] 248 | # convert float to uint8 249 | mask = (mask * 255).astype(dtype=np.uint8) 250 | 251 | # set the basic color 252 | basic_color = np.array(colors.to_rgb(color)) * 255 253 | basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1]) 254 | basic_color = basic_color.astype(dtype=np.uint8) 255 | # blend with basic color 256 | blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1 - color_alpha, 0) 257 | # blend with mask 258 | blended_img = cv2.addWeighted(blended_img, alpha, mask, 1 - alpha, 0) 259 | 260 | return blended_img 261 | 262 | 263 | def onehot(label, depth, device=None): 264 | """ 265 | Args: 266 | label: shape (n1, n2, ..., ) 267 | depth: a scalar 268 | 269 | Returns: 270 | onehot: (n1, n2, ..., depth) 271 | """ 272 | if not isinstance(label, torch.Tensor): 273 | label = torch.tensor(label, device=device) 274 | onehot = torch.zeros(label.size() + torch.Size([depth]), device=device) 275 | onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1) 276 | 277 | return onehot 278 | 279 | 280 | class MyDataParallel(nn.DataParallel): 281 | 282 | def gather(self, outputs, target_device): 283 | r""" 284 | Gathers tensors from different GPUs on a specified device 285 | (-1 means the CPU). 
286 | """ 287 | 288 | def gather_map(outputs): 289 | out = outputs[0] 290 | if isinstance(out, (str, int, float)): 291 | return out 292 | if isinstance(out, list) and isinstance(out[0], str): 293 | return [o for out in outputs for o in out] 294 | if isinstance(out, torch.Tensor): 295 | return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs) 296 | if out is None: 297 | return None 298 | if isinstance(out, dict): 299 | if not all((len(out) == len(d) for d in outputs)): 300 | raise ValueError('All dicts must have the same number of keys') 301 | return type(out)(((k, gather_map([d[k] for d in outputs])) 302 | for k in out)) 303 | return type(out)(map(gather_map, zip(*outputs))) 304 | 305 | # Recursive function calls like this create reference cycles. 306 | # Setting the function to None clears the refcycle. 307 | try: 308 | res = gather_map(outputs) 309 | finally: 310 | gather_map = None 311 | return res 312 | 313 | 314 | class MyConcatDataset(ConcatDataset): 315 | def __getattr__(self, k): 316 | return getattr(self.datasets[0], k) 317 | 318 | 319 | def if_none(a, b): 320 | return b if a is None else a 321 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open('requirements.txt') as f: 4 | requirements = f.read().splitlines() 5 | 6 | setuptools.setup( 7 | name="semimtr-text-recognition", 8 | version="0.0.1", 9 | author="SemiMTR Text Recognition", 10 | author_email="aws-cv-text-ocr@amazon.com", 11 | description="This package contains the package for SemiMTR", 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/amazon-science/semimtr-text-recognition.git", 14 | packages=setuptools.find_packages(), 15 | install_requires=requirements, 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: Amazon Copyright", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.6', 22 | ) 23 | -------------------------------------------------------------------------------- /tools/create_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ 2 | 3 | import fire 4 | import os 5 | import lmdb 6 | import cv2 7 | 8 | import numpy as np 9 | 10 | 11 | def checkImageIsValid(imageBin): 12 | if imageBin is None: 13 | return False 14 | imageBuf = np.frombuffer(imageBin, dtype=np.uint8) 15 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 16 | imgH, imgW = img.shape[0], img.shape[1] 17 | if imgH * imgW == 0: 18 | return False 19 | return True 20 | 21 | 22 | def writeCache(env, cache): 23 | with env.begin(write=True) as txn: 24 | for k, v in cache.items(): 25 | txn.put(k, v) 26 | 27 | 28 | def createDataset(inputPath, gtFile, outputPath, checkValid=True): 29 | """ 30 | Create LMDB dataset for training and evaluation. 
31 | ARGS: 32 | inputPath : input folder path where starts imagePath 33 | outputPath : LMDB output path 34 | gtFile : list of image path and label 35 | checkValid : if true, check the validity of every image 36 | """ 37 | os.makedirs(outputPath, exist_ok=True) 38 | env = lmdb.open(outputPath, map_size=1099511627776) 39 | cache = {} 40 | cnt = 1 41 | 42 | with open(gtFile, 'r', encoding='utf-8') as data: 43 | datalist = data.readlines() 44 | 45 | nSamples = len(datalist) 46 | for i in range(nSamples): 47 | imagePath, label = datalist[i].strip('\n').split('\t') 48 | imagePath = os.path.join(inputPath, imagePath) 49 | 50 | # # only use alphanumeric data 51 | # if re.search('[^a-zA-Z0-9]', label): 52 | # continue 53 | 54 | if not os.path.exists(imagePath): 55 | print('%s does not exist' % imagePath) 56 | continue 57 | with open(imagePath, 'rb') as f: 58 | imageBin = f.read() 59 | if checkValid: 60 | try: 61 | if not checkImageIsValid(imageBin): 62 | print('%s is not a valid image' % imagePath) 63 | continue 64 | except: 65 | print('error occured', i) 66 | with open(outputPath + '/error_image_log.txt', 'a') as log: 67 | log.write('%s-th image data occured error\n' % str(i)) 68 | continue 69 | 70 | imageKey = 'image-%09d'.encode() % cnt 71 | labelKey = 'label-%09d'.encode() % cnt 72 | cache[imageKey] = imageBin 73 | cache[labelKey] = label.encode() 74 | 75 | if cnt % 1000 == 0: 76 | writeCache(env, cache) 77 | cache = {} 78 | print('Written %d / %d' % (cnt, nSamples)) 79 | cnt += 1 80 | nSamples = cnt-1 81 | cache['num-samples'.encode()] = str(nSamples).encode() 82 | writeCache(env, cache) 83 | print('Created dataset with %d samples' % nSamples) 84 | 85 | 86 | if __name__ == '__main__': 87 | fire.Fire(createDataset) 88 | -------------------------------------------------------------------------------- /tools/crop_by_word_bb_syn90k.py: -------------------------------------------------------------------------------- 1 | # Crop by word bounding box 2 | # Locate script with gt.mat 3 | # $ python crop_by_word_bb.py 4 | 5 | import os 6 | import re 7 | import cv2 8 | import scipy.io as sio 9 | from itertools import chain 10 | import numpy as np 11 | import math 12 | 13 | mat_contents = sio.loadmat('gt.mat') 14 | 15 | image_names = mat_contents['imnames'][0] 16 | cropped_indx = 0 17 | start_img_indx = 0 18 | gt_file = open('gt_oabc.txt', 'a') 19 | err_file = open('err_oabc.txt', 'a') 20 | 21 | for img_indx in range(start_img_indx, len(image_names)): 22 | 23 | 24 | # Get image name 25 | image_name_new = image_names[img_indx][0] 26 | # print(image_name_new) 27 | image_name = '/home/yxwang/pytorch/dataset/SynthText/img/'+ image_name_new 28 | # print('IMAGE : {}.{}'.format(img_indx, image_name)) 29 | print('evaluating {} image'.format(img_indx), end='\r') 30 | # Get text in image 31 | txt = mat_contents['txt'][0][img_indx] 32 | txt = [re.split(' \n|\n |\n| ', t.strip()) for t in txt] 33 | txt = list(chain(*txt)) 34 | txt = [t for t in txt if len(t) > 0 ] 35 | # print(txt) # ['Lines:', 'I', 'lost', 'Kevin', 'will', 'line', 'and', 'and', 'the', '(and', 'the', 'out', 'you', "don't", 'pkg'] 36 | # assert 1<0 37 | 38 | # Open image 39 | #img = Image.open(image_name) 40 | img = cv2.imread(image_name, cv2.IMREAD_COLOR) 41 | img_height, img_width, _ = img.shape 42 | 43 | # Validation 44 | if len(np.shape(mat_contents['wordBB'][0][img_indx])) == 2: 45 | wordBBlen = 1 46 | else: 47 | wordBBlen = mat_contents['wordBB'][0][img_indx].shape[-1] 48 | 49 | if wordBBlen == len(txt): 50 | # Crop image and save 51 | for 
word_indx in range(len(txt)): 52 | # print('txt--',txt) 53 | txt_temp = txt[word_indx] 54 | len_now = len(txt_temp) 55 | # txt_temp = re.sub('[^0-9a-zA-Z]+', '', txt_temp) 56 | # print('txt_temp-1-',txt_temp) 57 | txt_temp = re.sub('[^a-zA-Z]+', '', txt_temp) 58 | # print('txt_temp-2-',txt_temp) 59 | if len_now - len(txt_temp) != 0: 60 | print('txt_temp-2-', txt_temp) 61 | 62 | if len(np.shape(mat_contents['wordBB'][0][img_indx])) == 2: # only one word (2,4) 63 | wordBB = mat_contents['wordBB'][0][img_indx] 64 | else: # many words (2,4,num_words) 65 | wordBB = mat_contents['wordBB'][0][img_indx][:, :, word_indx] 66 | 67 | if np.shape(wordBB) != (2, 4): 68 | err_log = 'malformed box index: {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB) 69 | err_file.write(err_log) 70 | # print(err_log) 71 | continue 72 | 73 | pts1 = np.float32([[wordBB[0][0], wordBB[1][0]], 74 | [wordBB[0][3], wordBB[1][3]], 75 | [wordBB[0][1], wordBB[1][1]], 76 | [wordBB[0][2], wordBB[1][2]]]) 77 | height = math.sqrt((wordBB[0][0] - wordBB[0][3])**2 + (wordBB[1][0] - wordBB[1][3])**2) 78 | width = math.sqrt((wordBB[0][0] - wordBB[0][1])**2 + (wordBB[1][0] - wordBB[1][1])**2) 79 | 80 | # Coord validation check 81 | if (height * width) <= 0: 82 | err_log = 'empty file : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB) 83 | err_file.write(err_log) 84 | # print(err_log) 85 | continue 86 | elif (height * width) > (img_height * img_width): 87 | err_log = 'too big box : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB) 88 | err_file.write(err_log) 89 | # print(err_log) 90 | continue 91 | else: 92 | valid = True 93 | for i in range(2): 94 | for j in range(4): 95 | if wordBB[i][j] < 0 or wordBB[i][j] > img.shape[1 - i]: 96 | valid = False 97 | break 98 | if not valid: 99 | break 100 | if not valid: 101 | err_log = 'invalid coord : {}\t{}\t{}\t{}\t{}\n'.format( 102 | image_name, txt[word_indx], wordBB, (width, height), (img_width, img_height)) 103 | err_file.write(err_log) 104 | # print(err_log) 105 | continue 106 | 107 | pts2 = np.float32([[0, 0], 108 | [0, height], 109 | [width, 0], 110 | [width, height]]) 111 | 112 | x_min = np.int(round(min(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3]))) 113 | x_max = np.int(round(max(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3]))) 114 | y_min = np.int(round(min(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3]))) 115 | y_max = np.int(round(max(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3]))) 116 | # print(x_min, x_max, y_min, y_max) 117 | # print(img.shape) 118 | # assert 1<0 119 | if len(img.shape) == 3: 120 | img_cropped = img[ y_min:y_max:1, x_min:x_max:1, :] 121 | else: 122 | img_cropped = img[ y_min:y_max:1, x_min:x_max:1] 123 | dir_name = '/home/yxwang/pytorch/dataset/SynthText/cropped-oabc/{}'.format(image_name_new.split('/')[0]) 124 | # print('dir_name--',dir_name) 125 | if not os.path.exists(dir_name): 126 | os.mkdir(dir_name) 127 | cropped_file_name = "{}/{}_{}_{}.jpg".format(dir_name, cropped_indx, 128 | image_name.split('/')[-1][:-len('.jpg')], word_indx) 129 | # print('cropped_file_name--',cropped_file_name) 130 | # print('img_cropped--',img_cropped.shape) 131 | if img_cropped.shape[0] == 0 or img_cropped.shape[1] == 0: 132 | err_log = 'word_box_mismatch : {}\t{}\t{}\n'.format(image_name, mat_contents['txt'][0][ 133 | img_indx], mat_contents['wordBB'][0][img_indx]) 134 | err_file.write(err_log) 135 | # print(err_log) 136 | continue 137 | # print('img_cropped--',img_cropped) 138 | 139 | # img_cropped.save(cropped_file_name) 
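# Note: the crop above is the axis-aligned bounding box of the word
# quadrilateral (min/max over its four corners); the perspective source/target
# points pts1/pts2 computed earlier are not used to rectify the patch before it
# is written out.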
140 | cv2.imwrite(cropped_file_name, img_cropped) 141 | cropped_indx += 1 142 | gt_file.write('%s\t%s\n' % (cropped_file_name, txt[word_indx])) 143 | 144 | # if cropped_indx>10: 145 | # assert 1<0 146 | # assert 1 < 0 147 | else: 148 | err_log = 'word_box_mismatch : {}\t{}\t{}\n'.format(image_name, mat_contents['txt'][0][ 149 | img_indx], mat_contents['wordBB'][0][img_indx]) 150 | err_file.write(err_log) 151 | # print(err_log) 152 | gt_file.close() 153 | err_file.close() 154 | -------------------------------------------------------------------------------- /tools/prepare_wikitext103.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 82841986 is_char and is_digit" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# 82075350 regrex non-ascii and none-digit" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 86460763 left" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import os\n", 31 | "import random\n", 32 | "import re\n", 33 | "import pandas as pd" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "max_length = 25\n", 43 | "min_length = 1\n", 44 | "root = '../data'\n", 45 | "charset = 'abcdefghijklmnopqrstuvwxyz'\n", 46 | "digits = '0123456789'" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "def is_char(text, ratio=0.5):\n", 56 | " text = text.lower()\n", 57 | " length = max(len(text), 1)\n", 58 | " char_num = sum([t in charset for t in text])\n", 59 | " if char_num < min_length: return False\n", 60 | " if char_num / length < ratio: return False\n", 61 | " return True\n", 62 | "\n", 63 | "def is_digit(text, ratio=0.5):\n", 64 | " length = max(len(text), 1)\n", 65 | " digit_num = sum([t in digits for t in text])\n", 66 | " if digit_num / length < ratio: return False\n", 67 | " return True" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "# generate training dataset" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "with open('/tmp/wikitext-103/wiki.train.tokens', 'r') as file:\n", 84 | " lines = file.readlines()" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "inp, gt = [], []\n", 94 | "for line in lines:\n", 95 | " token = line.lower().split()\n", 96 | " for text in token:\n", 97 | " text = re.sub('[^0-9a-zA-Z]+', '', text)\n", 98 | " if len(text) < min_length:\n", 99 | " # print('short-text', text)\n", 100 | " continue\n", 101 | " if len(text) > max_length:\n", 102 | " # print('long-text', text)\n", 103 | " continue\n", 104 | " inp.append(text)\n", 105 | " gt.append(text)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "train_voc = os.path.join(root, 'WikiText-103.csv')\n", 115 | "pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(train_voc, index=None, sep='\\t')" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "metadata": {}, 122 | "outputs": [ 
123 | { 124 | "data": { 125 | "text/plain": [ 126 | "86460763" 127 | ] 128 | }, 129 | "execution_count": 7, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "len(inp)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "['valkyria',\n", 147 | " 'chronicles',\n", 148 | " 'iii',\n", 149 | " 'senj',\n", 150 | " 'no',\n", 151 | " 'valkyria',\n", 152 | " '3',\n", 153 | " 'unk',\n", 154 | " 'chronicles',\n", 155 | " 'japanese',\n", 156 | " '3',\n", 157 | " 'lit',\n", 158 | " 'valkyria',\n", 159 | " 'of',\n", 160 | " 'the',\n", 161 | " 'battlefield',\n", 162 | " '3',\n", 163 | " 'commonly',\n", 164 | " 'referred',\n", 165 | " 'to',\n", 166 | " 'as',\n", 167 | " 'valkyria',\n", 168 | " 'chronicles',\n", 169 | " 'iii',\n", 170 | " 'outside',\n", 171 | " 'japan',\n", 172 | " 'is',\n", 173 | " 'a',\n", 174 | " 'tactical',\n", 175 | " 'role',\n", 176 | " 'playing',\n", 177 | " 'video',\n", 178 | " 'game',\n", 179 | " 'developed',\n", 180 | " 'by',\n", 181 | " 'sega',\n", 182 | " 'and',\n", 183 | " 'mediavision',\n", 184 | " 'for',\n", 185 | " 'the',\n", 186 | " 'playstation',\n", 187 | " 'portable',\n", 188 | " 'released',\n", 189 | " 'in',\n", 190 | " 'january',\n", 191 | " '2011',\n", 192 | " 'in',\n", 193 | " 'japan',\n", 194 | " 'it',\n", 195 | " 'is',\n", 196 | " 'the',\n", 197 | " 'third',\n", 198 | " 'game',\n", 199 | " 'in',\n", 200 | " 'the',\n", 201 | " 'valkyria',\n", 202 | " 'series',\n", 203 | " 'employing',\n", 204 | " 'the',\n", 205 | " 'same',\n", 206 | " 'fusion',\n", 207 | " 'of',\n", 208 | " 'tactical',\n", 209 | " 'and',\n", 210 | " 'real',\n", 211 | " 'time',\n", 212 | " 'gameplay',\n", 213 | " 'as',\n", 214 | " 'its',\n", 215 | " 'predecessors',\n", 216 | " 'the',\n", 217 | " 'story',\n", 218 | " 'runs',\n", 219 | " 'parallel',\n", 220 | " 'to',\n", 221 | " 'the',\n", 222 | " 'first',\n", 223 | " 'game',\n", 224 | " 'and',\n", 225 | " 'follows',\n", 226 | " 'the',\n", 227 | " 'nameless',\n", 228 | " 'a',\n", 229 | " 'penal',\n", 230 | " 'military',\n", 231 | " 'unit',\n", 232 | " 'serving',\n", 233 | " 'the',\n", 234 | " 'nation',\n", 235 | " 'of',\n", 236 | " 'gallia',\n", 237 | " 'during',\n", 238 | " 'the',\n", 239 | " 'second',\n", 240 | " 'europan',\n", 241 | " 'war',\n", 242 | " 'who',\n", 243 | " 'perform',\n", 244 | " 'secret',\n", 245 | " 'black']" 246 | ] 247 | }, 248 | "execution_count": 8, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "inp[:100]" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "# generate evaluation dataset" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 9, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "def disturb(word, degree, p=0.3):\n", 271 | " if len(word) // 2 < degree: return word\n", 272 | " if is_digit(word): return word\n", 273 | " if random.random() < p: return word\n", 274 | " else:\n", 275 | " index = list(range(len(word)))\n", 276 | " random.shuffle(index)\n", 277 | " index = index[:degree]\n", 278 | " new_word = []\n", 279 | " for i in range(len(word)):\n", 280 | " if i not in index: \n", 281 | " new_word.append(word[i])\n", 282 | " continue\n", 283 | " if (word[i] not in charset) and (word[i] not in digits):\n", 284 | " # special token\n", 285 | " new_word.append(word[i])\n", 286 | " continue\n", 287 | " op 
= random.random()\n", 288 | " if op < 0.1: # add\n", 289 | " new_word.append(random.choice(charset))\n", 290 | " new_word.append(word[i])\n", 291 | " elif op < 0.2: continue # remove\n", 292 | " else: new_word.append(random.choice(charset)) # replace\n", 293 | " return ''.join(new_word)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 10, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "lines = inp\n", 303 | "degree = 1\n", 304 | "keep_num = 50000\n", 305 | "\n", 306 | "random.shuffle(lines)\n", 307 | "part_lines = lines[:keep_num]\n", 308 | "inp, gt = [], []\n", 309 | "\n", 310 | "for w in part_lines:\n", 311 | " w = w.strip().lower()\n", 312 | " new_w = disturb(w, degree)\n", 313 | " inp.append(new_w)\n", 314 | " gt.append(w)\n", 315 | " \n", 316 | "eval_voc = os.path.join(root, f'WikiText-103_eval_d{degree}.csv')\n", 317 | "pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(eval_voc, index=None, sep='\\t')" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 11, 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "data": { 327 | "text/plain": [ 328 | "[('high', 'high'),\n", 329 | " ('vctoria', 'victoria'),\n", 330 | " ('mains', 'mains'),\n", 331 | " ('bi', 'by'),\n", 332 | " ('13', '13'),\n", 333 | " ('ticnet', 'ticket'),\n", 334 | " ('basil', 'basic'),\n", 335 | " ('cut', 'cut'),\n", 336 | " ('aqarky', 'anarky'),\n", 337 | " ('the', 'the'),\n", 338 | " ('tqe', 'the'),\n", 339 | " ('oc', 'of'),\n", 340 | " ('diwpersal', 'dispersal'),\n", 341 | " ('traffic', 'traffic'),\n", 342 | " ('in', 'in'),\n", 343 | " ('the', 'the'),\n", 344 | " ('ti', 'to'),\n", 345 | " ('professionalms', 'professionals'),\n", 346 | " ('747', '747'),\n", 347 | " ('in', 'in'),\n", 348 | " ('and', 'and'),\n", 349 | " ('exezutive', 'executive'),\n", 350 | " ('n400', 'n400'),\n", 351 | " ('yusic', 'music'),\n", 352 | " ('s', 's'),\n", 353 | " ('henri', 'henry'),\n", 354 | " ('heard', 'heard'),\n", 355 | " ('thousand', 'thousand'),\n", 356 | " ('to', 'to'),\n", 357 | " ('arhy', 'army'),\n", 358 | " ('td', 'to'),\n", 359 | " ('a', 'a'),\n", 360 | " ('oall', 'hall'),\n", 361 | " ('qind', 'kind'),\n", 362 | " ('od', 'on'),\n", 363 | " ('samfria', 'samaria'),\n", 364 | " ('driveway', 'driveway'),\n", 365 | " ('which', 'which'),\n", 366 | " ('wotk', 'work'),\n", 367 | " ('ak', 'as'),\n", 368 | " ('persona', 'persona'),\n", 369 | " ('s', 's'),\n", 370 | " ('melbourne', 'melbourne'),\n", 371 | " ('apong', 'along'),\n", 372 | " ('fas', 'was'),\n", 373 | " ('thea', 'then'),\n", 374 | " ('permcy', 'percy'),\n", 375 | " ('nnd', 'and'),\n", 376 | " ('alan', 'alan'),\n", 377 | " ('13', '13'),\n", 378 | " ('matteos', 'matters'),\n", 379 | " ('against', 'against'),\n", 380 | " ('nefion', 'nexion'),\n", 381 | " ('held', 'held'),\n", 382 | " ('negative', 'negative'),\n", 383 | " ('gogd', 'good'),\n", 384 | " ('the', 'the'),\n", 385 | " ('thd', 'the'),\n", 386 | " ('groening', 'groening'),\n", 387 | " ('tqe', 'the'),\n", 388 | " ('cwould', 'would'),\n", 389 | " ('fb', 'ft'),\n", 390 | " ('uniten', 'united'),\n", 391 | " ('kone', 'one'),\n", 392 | " ('thiy', 'this'),\n", 393 | " ('lanren', 'lauren'),\n", 394 | " ('s', 's'),\n", 395 | " ('thhe', 'the'),\n", 396 | " ('is', 'is'),\n", 397 | " ('modep', 'model'),\n", 398 | " ('weird', 'weird'),\n", 399 | " ('angwer', 'answer'),\n", 400 | " ('imprisxnment', 'imprisonment'),\n", 401 | " ('marpery', 'margery'),\n", 402 | " ('eventuanly', 'eventually'),\n", 403 | " ('in', 'in'),\n", 404 | " ('donnoa', 'donna'),\n", 405 | " ('ik', 
'it'),\n", 406 | " ('reached', 'reached'),\n", 407 | " ('at', 'at'),\n", 408 | " ('excxted', 'excited'),\n", 409 | " ('ws', 'was'),\n", 410 | " ('raes', 'rates'),\n", 411 | " ('the', 'the'),\n", 412 | " ('firsq', 'first'),\n", 413 | " ('concluyed', 'concluded'),\n", 414 | " ('recdorded', 'recorded'),\n", 415 | " ('fhe', 'the'),\n", 416 | " ('uegiment', 'regiment'),\n", 417 | " ('a', 'a'),\n", 418 | " ('glanes', 'planes'),\n", 419 | " ('conyrol', 'control'),\n", 420 | " ('thr', 'the'),\n", 421 | " ('arrext', 'arrest'),\n", 422 | " ('bth', 'both'),\n", 423 | " ('forward', 'forward'),\n", 424 | " ('allowdd', 'allowed'),\n", 425 | " ('revealed', 'revealed'),\n", 426 | " ('mayagement', 'management'),\n", 427 | " ('normal', 'normal')]" 428 | ] 429 | }, 430 | "execution_count": 11, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "list(zip(inp, gt))[:100]" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [] 445 | } 446 | ], 447 | "metadata": { 448 | "kernelspec": { 449 | "display_name": "Python 3", 450 | "language": "python", 451 | "name": "python3" 452 | }, 453 | "language_info": { 454 | "codemirror_mode": { 455 | "name": "ipython", 456 | "version": 3 457 | }, 458 | "file_extension": ".py", 459 | "mimetype": "text/x-python", 460 | "name": "python", 461 | "nbconvert_exporter": "python", 462 | "pygments_lexer": "ipython3", 463 | "version": "3.7.4" 464 | } 465 | }, 466 | "nbformat": 4, 467 | "nbformat_minor": 4 468 | } 469 | -------------------------------------------------------------------------------- /tools/remove_labels_from_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb 3 | from joblib import Parallel, delayed, cpu_count 4 | from tqdm import tqdm 5 | 6 | 7 | def remove_labels_from_lmdb_dataset(input_lmdb_path, output_lmdb_path): 8 | """ 9 | Create LMDB dataset for training and evaluation. 
10 | ARGS: 11 | input_lmdb_path : path to the source (labeled) LMDB dataset 12 | output_lmdb_path : path where the unlabeled copy of the LMDB dataset is written 13 | """ 14 | os.makedirs(output_lmdb_path, exist_ok=True) 15 | cache = {} 16 | cnt = 1 17 | env_input = lmdb.open(input_lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 18 | env_output = lmdb.open(output_lmdb_path, map_size=1099511627776) 19 | with env_input.begin(write=False) as txn: 20 | n_samples = int(txn.get('num-samples'.encode())) 21 | for _ in tqdm(range(n_samples)): 22 | image_key_code = 'image-%09d'.encode() % cnt 23 | image_data = txn.get(image_key_code) # image bytes, copied over unchanged 24 | cache[image_key_code] = image_data 25 | 26 | label_key_code = 'label-%09d'.encode() % cnt 27 | cache[label_key_code] = 'unlabeleddata'.encode() # the original label is replaced by a placeholder 28 | 29 | if cnt % 1000 == 0: 30 | write_cache(env_output, cache) 31 | cache = {} 32 | cnt += 1 33 | cache['num-samples'.encode()] = str(n_samples).encode() 34 | write_cache(env_output, cache) 35 | 36 | 37 | def write_cache(env, cache): 38 | with env.begin(write=True) as txn: 39 | for k, v in cache.items(): 40 | txn.put(k, v) 41 | 42 | 43 | if __name__ == '__main__': 44 | labeled_data_root = "data/training/label/real" 45 | unlabeled_data_root = "data/training/label_without_labels/real" 46 | datasets = [ 47 | "10.MLT19", 48 | "11.ReCTS", 49 | "1.SVT", 50 | "2.IIIT", 51 | "3.IC13", 52 | "4.IC15", 53 | "5.COCO", 54 | "6.RCTW17", 55 | "7.Uber", 56 | "8.ArT", 57 | "9.LSVT", 58 | ] 59 | 60 | n_jobs = min(cpu_count(), len(datasets)) 61 | Parallel(n_jobs=n_jobs)(delayed(remove_labels_from_lmdb_dataset)( 62 | input_lmdb_path=os.path.join(labeled_data_root, dataset), 63 | output_lmdb_path=os.path.join(unlabeled_data_root, dataset)) for dataset in datasets) 64 | --------------------------------------------------------------------------------
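Note on tools/crop_by_word_bb_syn90k.py above: the loop builds the corner correspondences pts1/pts2 for every word box, but in the portion shown it crops with a plain axis-aligned slice (img[y_min:y_max, x_min:x_max]), so tilted words keep some background. If a rectified crop is preferred, those same two point sets can drive a perspective warp. A minimal sketch, not part of the repository: the function name rectified_word_crop is illustrative, and it assumes an image already loaded with cv2 and a (2, 4) wordBB in SynthText's corner order.

import math

import cv2
import numpy as np


def rectified_word_crop(img, wordBB):
    """Warp one SynthText word quadrilateral onto an upright width x height rectangle."""
    # Same corner ordering as pts1/pts2 in the script: top-left, bottom-left, top-right, bottom-right.
    pts1 = np.float32([[wordBB[0][0], wordBB[1][0]],
                       [wordBB[0][3], wordBB[1][3]],
                       [wordBB[0][1], wordBB[1][1]],
                       [wordBB[0][2], wordBB[1][2]]])
    height = math.sqrt((wordBB[0][0] - wordBB[0][3]) ** 2 + (wordBB[1][0] - wordBB[1][3]) ** 2)
    width = math.sqrt((wordBB[0][0] - wordBB[0][1]) ** 2 + (wordBB[1][0] - wordBB[1][1]) ** 2)
    pts2 = np.float32([[0, 0],
                       [0, height],
                       [width, 0],
                       [width, height]])
    # Homography that maps the tilted quadrilateral onto an axis-aligned rectangle.
    matrix = cv2.getPerspectiveTransform(pts1, pts2)
    return cv2.warpPerspective(img, matrix, (int(round(width)), int(round(height))))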
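Note on tools/prepare_wikitext103.ipynb above: both CSVs are written tab-separated with inp and gt columns (inp == gt for the training vocabulary; the evaluation file pairs a corrupted word with its ground truth), so they can be read back the same way they were saved. A minimal sketch using the paths the notebook itself writes, relative to the notebook's working directory; na_filter=False keeps words such as 'null' from being parsed as missing values.

import pandas as pd

train = pd.read_csv('../data/WikiText-103.csv', sep='\t', na_filter=False)
eval_d1 = pd.read_csv('../data/WikiText-103_eval_d1.csv', sep='\t', na_filter=False)

print(train.head())    # inp and gt are identical words
print(eval_d1.head())  # inp is the disturbed word, gt the original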
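Note on tools/remove_labels_from_lmdb.py above: the script copies every image-%09d record verbatim, overwrites every label-%09d record with the placeholder string 'unlabeleddata', and finally rewrites num-samples. A quick way to sanity-check one of the generated datasets is to read it back with the same key convention. A minimal sketch; the dataset path below is just one example built from the script's unlabeled_data_root and datasets constants.

import lmdb

output_lmdb_path = 'data/training/label_without_labels/real/2.IIIT'  # example output dataset

env = lmdb.open(output_lmdb_path, max_readers=32, readonly=True, lock=False,
                readahead=False, meminit=False)
with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    for i in range(1, n_samples + 1):
        label = txn.get('label-%09d'.encode() % i)
        image = txn.get('image-%09d'.encode() % i)
        assert label == 'unlabeleddata'.encode(), 'unexpected label at index %d: %r' % (i, label)
        assert image, 'missing image at index %d' % i
print('checked %d samples in %s' % (n_samples, output_lmdb_path))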