├── .gitignore ├── LICENSE ├── README.md ├── config ├── __init__.py ├── base.yml ├── resnet-transformer.yml ├── resnet_fpn_transformer.yml ├── vgg-convseq2seq.yml ├── vgg-seq2seq.yml └── vgg-transformer.yml ├── image ├── .keep ├── sample.png └── vietocr.jpg ├── setup.py ├── vietocr ├── __init__.py ├── loader │ ├── __init__.py │ ├── aug.py │ ├── dataloader.py │ └── dataloader_v1.py ├── model │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── resnet.py │ │ └── vgg.py │ ├── beam.py │ ├── seqmodel │ │ ├── __init__.py │ │ ├── convseq2seq.py │ │ ├── seq2seq.py │ │ └── transformer.py │ ├── trainer.py │ ├── transformerocr.py │ └── vocab.py ├── optim │ ├── __init__.py │ ├── labelsmoothingloss.py │ └── optim.py ├── predict.py ├── requirement.txt ├── tests │ ├── image │ │ ├── 001099025107.jpeg │ │ ├── 026301003919.jpeg │ │ ├── 036170002830.jpeg │ │ ├── 038078002355.jpeg │ │ ├── 038089010274.jpeg │ │ ├── 038144000109.jpeg │ │ ├── 060085000115.jpeg │ │ ├── 072183002222.jpeg │ │ ├── 079084000809.jpeg │ │ └── 079193002341.jpeg │ ├── sample.txt │ └── utest.py ├── tool │ ├── __init__.py │ ├── config.py │ ├── create_dataset.py │ ├── logger.py │ ├── predictor.py │ ├── translate.py │ └── utils.py └── train.py └── vietocr_gettingstart.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VietOCR 2 | **[DORI](https://dorify.net/vi) là end-to-end OCR platform, hỗ trợ các bạn đánh nhãn, huấn luyện, deploy mô hình dễ dàng.** 3 | - Tham khảo tại [dorify.net](https://dorify.net/vi). 4 | - Tài liệu hướng dẫn [tại đây](https://pbcquoc.github.io/dori_guideline/). 5 | 6 | ---- 7 | **Các bạn vui lòng cập nhật lên version mới nhất để không xảy ra lỗi.** 8 | 9 |

10 | 11 |

12 | 13 | In this project, I implement a Transformer OCR model for recognizing handwritten and printed Vietnamese text. The model architecture combines a CNN with a Transformer (the architecture behind the well-known BERT model). The TransformerOCR model has many advantages over the CRNN architecture that I implemented previously. You can read about the architecture and how to train the model on different datasets [here](https://pbcquoc.github.io/vietocr). 14 | 15 | The VietOCR model generalizes very well, and even achieves fairly high accuracy on a new dataset it has never been trained on. 16 | 17 |
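
To make the CNN plus Transformer combination above concrete, here is a minimal, self-contained PyTorch sketch. It is illustrative only and is not the library's actual code: `TinyOCRBackbone` is a made-up stand-in for the real backbones under `vietocr/model/backbone/`, and the layer sizes simply follow the `d_model: 256` / `nhead: 8` values in `config/base.yml`. What it shows is the data flow: the CNN turns an image into a width-wise sequence of feature vectors of shape `(W, N, C)`, the same layout the real `vgg.py` and `resnet.py` backbones produce, and that sequence drives a standard encoder-decoder Transformer.

```python
import torch
from torch import nn

class TinyOCRBackbone(nn.Module):
    """Toy CNN that collapses image height and keeps width as the sequence axis."""

    def __init__(self, hidden=256):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, hidden, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, None)),          # (N, hidden, 1, W')
        )

    def forward(self, x):                             # x: (N, 3, H, W)
        conv = self.features(x)                       # (N, hidden, 1, W')
        conv = conv.flatten(2)                        # (N, hidden, W')
        return conv.permute(2, 0, 1)                  # (W', N, hidden), as in vgg.py/resnet.py

backbone = TinyOCRBackbone()
src = backbone(torch.randn(4, 3, 32, 128))            # visual feature sequence: (W', N, 256)
tgt = torch.randn(20, 4, 256)                         # embedded target tokens: (T, N, 256)
transformer = nn.Transformer(d_model=256, nhead=8,
                             num_encoder_layers=6, num_decoder_layers=6)
out = transformer(src, tgt)                           # (T, N, 256); project to vocab size for OCR
print(src.shape, out.shape)
```

In the library itself this `(W, N, C)` sequence comes out of the backbones in `vietocr/model/backbone/` and is consumed by the sequence heads in `vietocr/model/seqmodel/` (transformer, seq2seq, convseq2seq).
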

18 | 19 |

20 | 21 | # Installation 22 | To install, run 23 | ``` 24 | pip install vietocr 25 | ``` 26 | # Quick Start 27 | See [this](https://github.com/pbcquoc/vietocr/blob/master/vietocr_gettingstart.ipynb) notebook to learn how to use the library. 28 | # How to create the train/test files 29 | The train/test file has 2 columns: the first column is the image file name and the second is the label (which must not contain the \t character); the two columns are separated by \t (see the short sketch below). 30 | ``` 31 | 20160518_0151_25432_1_tg_3_5.png để nghe phổ biến chủ trương của UBND tỉnh Phú Yên 32 | 20160421_0102_25464_2_tg_0_4.png môi trường lại đều đồng thanh 33 | ``` 34 | A sample file is available [here](https://vocr.vn/data/vietocr/data_line.zip) 35 | 36 | # Model Zoo 37 | This library implements two kinds of sequence model: attention-based seq2seq and transformer. Seq2seq is very fast at inference and widely used in industry, while the transformer is more accurate but noticeably slower to predict, so both options are provided. 38 | 39 | The models were trained on a dataset of 10M images covering many image types, such as synthetically generated text, handwriting, and real scanned documents. 40 | Pretrained models are provided. 41 | 42 | # Experimental results on the 10M dataset 43 | | Backbone | Config | Precision (full sequence) | Time | 44 | | ------------- |:-------------:| ---:|---:| 45 | | VGG19-bn - Transformer | vgg_transformer | 0.8800 | 86 ms @ 1080 Ti | 46 | | VGG19-bn - Seq2Seq | vgg_seq2seq | 0.8701 | 12 ms @ 1080 Ti | 47 | 48 | The vgg-transformer model's prediction time is much longer than the seq2seq model's, while there is no clear difference in accuracy between the two architectures. 49 | 50 | # Dataset 51 | I only provide a sample dataset of about 1M synthetically generated images. You can download it [here](https://drive.google.com/file/d/1T0cmkhTgu3ahyMIwGZeby612RpVdDxOR/view). 52 | # License 53 | I release this library under the terms of the [Apache 2.0 license]().
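
Following up on the train/test annotation format described above: each line is just `image_name<TAB>label`, which is exactly what `DataGen` in `vietocr/loader/dataloader_v1.py` splits on `\t`. The helper below is an illustrative sketch only, not part of the library; the function name `check_annotation` and the example paths are placeholders.

```python
import os

def check_annotation(annotation_path, image_root):
    """Read a VietOCR-style annotation file: one `image_name<TAB>label` per line."""
    samples = []
    with open(annotation_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip("\n")
            if not line:
                continue
            parts = line.split("\t")
            if len(parts) != 2:
                raise ValueError(f"line {line_no}: expected exactly one tab separator")
            image_name, label = parts
            if not os.path.exists(os.path.join(image_root, image_name)):
                print(f"line {line_no}: missing image {image_name}")
            samples.append((image_name, label))
    return samples

# Example usage (placeholder paths, matching the data_root/annotation layout in config/base.yml):
# samples = check_annotation("./img/annotation_train.txt", "./img/")
# print(len(samples), samples[0])
```
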
54 | 55 | # Liên hệ 56 | Nếu bạn có bất kì vấn đề gì, vui lòng tạo issue hoặc liên hệ mình tại pbcquoc@gmail.com 57 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/config/__init__.py -------------------------------------------------------------------------------- /config/base.yml: -------------------------------------------------------------------------------- 1 | # change to list chars of your dataset or use default vietnamese chars 2 | vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ ' 3 | 4 | # cpu, cuda, cuda:0 5 | device: cuda:0 6 | 7 | seq_modeling: transformer 8 | transformer: 9 | d_model: 256 10 | nhead: 8 11 | num_encoder_layers: 6 12 | num_decoder_layers: 6 13 | dim_feedforward: 2048 14 | max_seq_length: 1024 15 | pos_dropout: 0.1 16 | trans_dropout: 0.1 17 | 18 | optimizer: 19 | max_lr: 0.0003 20 | pct_start: 0.1 21 | 22 | trainer: 23 | batch_size: 32 24 | print_every: 200 25 | valid_every: 4000 26 | iters: 100000 27 | # where to save our model for prediction 28 | export: ./weights/transformerocr.pth 29 | checkpoint: ./checkpoint/transformerocr_checkpoint.pth 30 | log: ./train.log 31 | # null to disable compuate accuracy, or change to number of sample to enable validiation while training 32 | metrics: null 33 | 34 | dataset: 35 | # name of your dataset 36 | name: data 37 | # path to annotation and image 38 | data_root: ./img/ 39 | train_annotation: annotation_train.txt 40 | valid_annotation: annotation_val_small.txt 41 | # resize image to 32 height, larger height will increase accuracy 42 | image_height: 32 43 | image_min_width: 32 44 | image_max_width: 512 45 | 46 | dataloader: 47 | num_workers: 3 48 | pin_memory: True 49 | 50 | aug: 51 | image_aug: true 52 | masked_language_model: true 53 | 54 | predictor: 55 | # disable or enable beamsearch while prediction, use beamsearch will be slower 56 | beamsearch: False 57 | 58 | quiet: False 59 | -------------------------------------------------------------------------------- /config/resnet-transformer.yml: -------------------------------------------------------------------------------- 1 | pretrain: 2 | id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA 3 | md5: 7068030afe2e8fc639d0e1e2c25612b3 4 | cached: /tmp/tranformerorc.pth 5 | 6 | weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY 7 | 8 | backbone: resnet50 9 | cnn: 10 | ss: 11 | - [2, 2] 12 | - [2, 1] 13 | - [2, 1] 14 | - [2, 1] 15 | - [1, 1] 16 | hidden: 256 17 | -------------------------------------------------------------------------------- /config/resnet_fpn_transformer.yml: -------------------------------------------------------------------------------- 1 | pretrain: 2 | id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA 3 | md5: 7068030afe2e8fc639d0e1e2c25612b3 4 | cached: /tmp/tranformerorc.pth 5 | 6 | weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY 7 | 8 | backbone: resnet50_fpn 9 | cnn: {} 10 | -------------------------------------------------------------------------------- /config/vgg-convseq2seq.yml: -------------------------------------------------------------------------------- 1 | pretrain: 2 | id_or_url: 
13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA 3 | md5: fbefa85079ad9001a71eb1bf47a93785 4 | cached: /tmp/tranformerorc.pth 5 | 6 | # url or local path 7 | weights: https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA 8 | 9 | backbone: vgg19_bn 10 | cnn: 11 | # pooling stride size 12 | ss: 13 | - [2, 2] 14 | - [2, 2] 15 | - [2, 1] 16 | - [2, 1] 17 | - [1, 1] 18 | # pooling kernel size 19 | ks: 20 | - [2, 2] 21 | - [2, 2] 22 | - [2, 1] 23 | - [2, 1] 24 | - [1, 1] 25 | # dim of ouput feature map 26 | hidden: 256 27 | 28 | seq_modeling: convseq2seq 29 | transformer: 30 | emb_dim: 256 31 | hid_dim: 512 32 | enc_layers: 10 33 | dec_layers: 10 34 | enc_kernel_size: 3 35 | dec_kernel_size: 3 36 | dropout: 0.1 37 | pad_idx: 0 38 | device: cuda:1 39 | enc_max_length: 512 40 | dec_max_length: 512 41 | -------------------------------------------------------------------------------- /config/vgg-seq2seq.yml: -------------------------------------------------------------------------------- 1 | # for train 2 | pretrain: https://vocr.vn/data/vietocr/vgg_seq2seq.pth 3 | 4 | # url or local path (for predict) 5 | weights: https://vocr.vn/data/vietocr/vgg_seq2seq.pth 6 | 7 | backbone: vgg19_bn 8 | cnn: 9 | # pooling stride size 10 | ss: 11 | - [2, 2] 12 | - [2, 2] 13 | - [2, 1] 14 | - [2, 1] 15 | - [1, 1] 16 | # pooling kernel size 17 | ks: 18 | - [2, 2] 19 | - [2, 2] 20 | - [2, 1] 21 | - [2, 1] 22 | - [1, 1] 23 | # dim of ouput feature map 24 | hidden: 256 25 | 26 | seq_modeling: seq2seq 27 | transformer: 28 | encoder_hidden: 256 29 | decoder_hidden: 256 30 | img_channel: 256 31 | decoder_embedded: 256 32 | dropout: 0.1 33 | 34 | optimizer: 35 | max_lr: 0.001 36 | pct_start: 0.1 37 | -------------------------------------------------------------------------------- /config/vgg-transformer.yml: -------------------------------------------------------------------------------- 1 | # for training 2 | pretrain: https://vocr.vn/data/vietocr/vgg_transformer.pth 3 | 4 | # url or local path (predict) 5 | weights: https://vocr.vn/data/vietocr/vgg_transformer.pth 6 | 7 | backbone: vgg19_bn 8 | cnn: 9 | pretrained: True 10 | # pooling stride size 11 | ss: 12 | - [2, 2] 13 | - [2, 2] 14 | - [2, 1] 15 | - [2, 1] 16 | - [1, 1] 17 | # pooling kernel size 18 | ks: 19 | - [2, 2] 20 | - [2, 2] 21 | - [2, 1] 22 | - [2, 1] 23 | - [1, 1] 24 | # dim of ouput feature map 25 | hidden: 256 26 | 27 | -------------------------------------------------------------------------------- /image/.keep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /image/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/image/sample.png -------------------------------------------------------------------------------- /image/vietocr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/image/vietocr.jpg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="vietocr", 8 | version="0.3.13", 9 | author="pbcquoc", 10 | 
author_email="pbcquoc@gmail.com", 11 | description="Transformer base text detection", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/pbcquoc/vietocr", 15 | packages=setuptools.find_packages(), 16 | install_requires=[ 17 | "einops>=0.2.0", 18 | "gdown>=4.4.0", 19 | "albumentations>=1.4.2", 20 | "lmdb>=1.0.0", 21 | "scikit-image>=0.21.0", 22 | "pillow>=10.2.0", 23 | ], 24 | classifiers=[ 25 | "Programming Language :: Python :: 3", 26 | "License :: OSI Approved :: MIT License", 27 | "Operating System :: OS Independent", 28 | ], 29 | python_requires=">=3.6", 30 | ) 31 | -------------------------------------------------------------------------------- /vietocr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/__init__.py -------------------------------------------------------------------------------- /vietocr/loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/loader/__init__.py -------------------------------------------------------------------------------- /vietocr/loader/aug.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | import albumentations as A 5 | from albumentations.core.transforms_interface import ImageOnlyTransform 6 | import cv2 7 | import random 8 | 9 | 10 | class RandomDottedLine(ImageOnlyTransform): 11 | def __init__(self, num_lines=1, p=0.5): 12 | super(RandomDottedLine, self).__init__(p=p) 13 | self.num_lines = num_lines 14 | 15 | def apply(self, img, **params): 16 | h, w = img.shape[:2] 17 | for _ in range(self.num_lines): 18 | # Random start and end points 19 | x1, y1 = np.random.randint(0, w), np.random.randint(0, h) 20 | x2, y2 = np.random.randint(0, w), np.random.randint(0, h) 21 | # Random color 22 | color = tuple(np.random.randint(0, 256, size=3).tolist()) 23 | # Random thickness 24 | thickness = np.random.randint(1, 5) 25 | # Draw dotted or dashed line 26 | line_type = random.choice(["dotted", "dashed", "solid"]) 27 | if line_type != "solid": 28 | self._draw_dotted_line( 29 | img, (x1, y1), (x2, y2), color, thickness, line_type 30 | ) 31 | else: 32 | cv2.line(img, (x1, y1), (x2, y2), color, thickness) 33 | 34 | return img 35 | 36 | def _draw_dotted_line(self, img, pt1, pt2, color, thickness, line_type): 37 | # Calculate the distance between the points 38 | dist = np.hypot(pt2[0] - pt1[0], pt2[1] - pt1[1]) 39 | # Number of segments 40 | num_segments = max(int(dist // 5), 1) 41 | # Generate points along the line 42 | x_points = np.linspace(pt1[0], pt2[0], num_segments) 43 | y_points = np.linspace(pt1[1], pt2[1], num_segments) 44 | # Draw segments 45 | for i in range(num_segments - 1): 46 | if line_type == "dotted" and i % 2 == 0: 47 | pt_start = (int(x_points[i]), int(y_points[i])) 48 | pt_end = (int(x_points[i]), int(y_points[i])) 49 | cv2.circle(img, pt_start, thickness, color, -1) 50 | elif line_type == "dashed" and i % 4 < 2: 51 | pt_start = (int(x_points[i]), int(y_points[i])) 52 | pt_end = (int(x_points[i + 1]), int(y_points[i + 1])) 53 | cv2.line(img, pt_start, pt_end, color, thickness) 54 | return img 55 | 56 | def get_transform_init_args_names(self): 57 | return ("num_lines",) 58 | 59 | 60 | class 
ImgAugTransformV2: 61 | def __init__(self): 62 | self.aug = A.Compose( 63 | [ 64 | A.InvertImg(p=0.2), 65 | A.ColorJitter(p=0.2), 66 | A.MotionBlur(blur_limit=3, p=0.2), 67 | A.RandomBrightnessContrast(p=0.2), 68 | A.Perspective(scale=(0.01, 0.05)), 69 | RandomDottedLine(), 70 | ] 71 | ) 72 | 73 | def __call__(self, img): 74 | img = np.array(img) 75 | transformed = self.aug(image=img) 76 | img = transformed["image"] 77 | img = Image.fromarray(img) 78 | return img 79 | -------------------------------------------------------------------------------- /vietocr/loader/dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import random 4 | from PIL import Image 5 | from PIL import ImageFile 6 | 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | from collections import defaultdict 10 | import numpy as np 11 | import torch 12 | import lmdb 13 | import six 14 | import time 15 | from tqdm import tqdm 16 | 17 | from torch.utils.data import Dataset 18 | from torch.utils.data.sampler import Sampler 19 | from vietocr.tool.translate import process_image 20 | from vietocr.tool.create_dataset import createDataset 21 | from vietocr.tool.translate import resize 22 | 23 | 24 | class OCRDataset(Dataset): 25 | def __init__( 26 | self, 27 | lmdb_path, 28 | root_dir, 29 | annotation_path, 30 | vocab, 31 | image_height=32, 32 | image_min_width=32, 33 | image_max_width=512, 34 | transform=None, 35 | ): 36 | self.root_dir = root_dir 37 | self.annotation_path = os.path.join(root_dir, annotation_path) 38 | self.vocab = vocab 39 | self.transform = transform 40 | 41 | self.image_height = image_height 42 | self.image_min_width = image_min_width 43 | self.image_max_width = image_max_width 44 | 45 | self.lmdb_path = lmdb_path 46 | 47 | if os.path.isdir(self.lmdb_path): 48 | print( 49 | "{} exists. 
Remove folder if you want to create new dataset".format( 50 | self.lmdb_path 51 | ) 52 | ) 53 | sys.stdout.flush() 54 | else: 55 | createDataset(self.lmdb_path, root_dir, annotation_path) 56 | 57 | self.env = lmdb.open( 58 | self.lmdb_path, 59 | max_readers=8, 60 | readonly=True, 61 | lock=False, 62 | readahead=False, 63 | meminit=False, 64 | ) 65 | self.txn = self.env.begin(write=False) 66 | 67 | nSamples = int(self.txn.get("num-samples".encode())) 68 | self.nSamples = nSamples 69 | 70 | self.build_cluster_indices() 71 | 72 | def build_cluster_indices(self): 73 | self.cluster_indices = defaultdict(list) 74 | 75 | pbar = tqdm( 76 | range(self.__len__()), 77 | desc="{} build cluster".format(self.lmdb_path), 78 | ncols=100, 79 | position=0, 80 | leave=True, 81 | ) 82 | 83 | for i in pbar: 84 | bucket = self.get_bucket(i) 85 | self.cluster_indices[bucket].append(i) 86 | 87 | def get_bucket(self, idx): 88 | key = "dim-%09d" % idx 89 | 90 | dim_img = self.txn.get(key.encode()) 91 | dim_img = np.fromstring(dim_img, dtype=np.int32) 92 | imgH, imgW = dim_img 93 | 94 | new_w, image_height = resize( 95 | imgW, imgH, self.image_height, self.image_min_width, self.image_max_width 96 | ) 97 | 98 | return new_w 99 | 100 | def read_buffer(self, idx): 101 | img_file = "image-%09d" % idx 102 | label_file = "label-%09d" % idx 103 | path_file = "path-%09d" % idx 104 | 105 | imgbuf = self.txn.get(img_file.encode()) 106 | 107 | label = self.txn.get(label_file.encode()).decode() 108 | img_path = self.txn.get(path_file.encode()).decode() 109 | 110 | buf = six.BytesIO() 111 | buf.write(imgbuf) 112 | buf.seek(0) 113 | 114 | return buf, label, img_path 115 | 116 | def read_data(self, idx): 117 | buf, label, img_path = self.read_buffer(idx) 118 | 119 | img = Image.open(buf).convert("RGB") 120 | 121 | if self.transform: 122 | img = self.transform(img) 123 | 124 | img_bw = process_image( 125 | img, self.image_height, self.image_min_width, self.image_max_width 126 | ) 127 | 128 | word = self.vocab.encode(label) 129 | 130 | return img_bw, word, img_path 131 | 132 | def __getitem__(self, idx): 133 | img, word, img_path = self.read_data(idx) 134 | 135 | img_path = os.path.join(self.root_dir, img_path) 136 | 137 | sample = {"img": img, "word": word, "img_path": img_path} 138 | 139 | return sample 140 | 141 | def __len__(self): 142 | return self.nSamples 143 | 144 | 145 | class ClusterRandomSampler(Sampler): 146 | 147 | def __init__(self, data_source, batch_size, shuffle=True): 148 | self.data_source = data_source 149 | self.batch_size = batch_size 150 | self.shuffle = shuffle 151 | 152 | def flatten_list(self, lst): 153 | return [item for sublist in lst for item in sublist] 154 | 155 | def __iter__(self): 156 | batch_lists = [] 157 | for cluster, cluster_indices in self.data_source.cluster_indices.items(): 158 | if self.shuffle: 159 | random.shuffle(cluster_indices) 160 | 161 | batches = [ 162 | cluster_indices[i : i + self.batch_size] 163 | for i in range(0, len(cluster_indices), self.batch_size) 164 | ] 165 | batches = [_ for _ in batches if len(_) == self.batch_size] 166 | if self.shuffle: 167 | random.shuffle(batches) 168 | 169 | batch_lists.append(batches) 170 | 171 | lst = self.flatten_list(batch_lists) 172 | if self.shuffle: 173 | random.shuffle(lst) 174 | 175 | lst = self.flatten_list(lst) 176 | 177 | return iter(lst) 178 | 179 | def __len__(self): 180 | return len(self.data_source) 181 | 182 | 183 | class Collator(object): 184 | def __init__(self, masked_language_model=True): 185 | self.masked_language_model = 
masked_language_model 186 | 187 | def __call__(self, batch): 188 | filenames = [] 189 | img = [] 190 | target_weights = [] 191 | tgt_input = [] 192 | max_label_len = max(len(sample["word"]) for sample in batch) 193 | for sample in batch: 194 | img.append(sample["img"]) 195 | filenames.append(sample["img_path"]) 196 | label = sample["word"] 197 | label_len = len(label) 198 | 199 | tgt = np.concatenate( 200 | (label, np.zeros(max_label_len - label_len, dtype=np.int32)) 201 | ) 202 | tgt_input.append(tgt) 203 | 204 | one_mask_len = label_len - 1 205 | 206 | target_weights.append( 207 | np.concatenate( 208 | ( 209 | np.ones(one_mask_len, dtype=np.float32), 210 | np.zeros(max_label_len - one_mask_len, dtype=np.float32), 211 | ) 212 | ) 213 | ) 214 | 215 | img = np.array(img, dtype=np.float32) 216 | 217 | tgt_input = np.array(tgt_input, dtype=np.int64).T 218 | tgt_output = np.roll(tgt_input, -1, 0).T 219 | tgt_output[:, -1] = 0 220 | 221 | # random mask token 222 | if self.masked_language_model: 223 | mask = np.random.random(size=tgt_input.shape) < 0.05 224 | mask = mask & (tgt_input != 0) & (tgt_input != 1) & (tgt_input != 2) 225 | tgt_input[mask] = 3 226 | 227 | tgt_padding_mask = np.array(target_weights) == 0 228 | 229 | rs = { 230 | "img": torch.FloatTensor(img), 231 | "tgt_input": torch.LongTensor(tgt_input), 232 | "tgt_output": torch.LongTensor(tgt_output), 233 | "tgt_padding_mask": torch.BoolTensor(tgt_padding_mask), 234 | "filenames": filenames, 235 | } 236 | 237 | return rs 238 | -------------------------------------------------------------------------------- /vietocr/loader/dataloader_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import random 5 | from vietocr.model.vocab import Vocab 6 | from vietocr.tool.translate import process_image 7 | import os 8 | from collections import defaultdict 9 | import math 10 | from prefetch_generator import background 11 | 12 | 13 | class BucketData(object): 14 | def __init__(self, device): 15 | self.max_label_len = 0 16 | self.data_list = [] 17 | self.label_list = [] 18 | self.file_list = [] 19 | self.device = device 20 | 21 | def append(self, datum, label, filename): 22 | self.data_list.append(datum) 23 | self.label_list.append(label) 24 | self.file_list.append(filename) 25 | 26 | self.max_label_len = max(len(label), self.max_label_len) 27 | 28 | return len(self.data_list) 29 | 30 | def flush_out(self): 31 | """ 32 | Shape: 33 | - img: (N, C, H, W) 34 | - tgt_input: (T, N) 35 | - tgt_output: (N, T) 36 | - tgt_padding_mask: (N, T) 37 | """ 38 | # encoder part 39 | img = np.array(self.data_list, dtype=np.float32) 40 | 41 | # decoder part 42 | target_weights = [] 43 | tgt_input = [] 44 | for label in self.label_list: 45 | label_len = len(label) 46 | 47 | tgt = np.concatenate( 48 | (label, np.zeros(self.max_label_len - label_len, dtype=np.int32)) 49 | ) 50 | tgt_input.append(tgt) 51 | 52 | one_mask_len = label_len - 1 53 | 54 | target_weights.append( 55 | np.concatenate( 56 | ( 57 | np.ones(one_mask_len, dtype=np.float32), 58 | np.zeros(self.max_label_len - one_mask_len, dtype=np.float32), 59 | ) 60 | ) 61 | ) 62 | 63 | # reshape to fit input shape 64 | tgt_input = np.array(tgt_input, dtype=np.int64).T 65 | tgt_output = np.roll(tgt_input, -1, 0).T 66 | tgt_output[:, -1] = 0 67 | 68 | tgt_padding_mask = np.array(target_weights) == 0 69 | 70 | filenames = self.file_list 71 | 72 | self.data_list, self.label_list, self.file_list = [], [], [] 73 | 
self.max_label_len = 0 74 | 75 | rs = { 76 | "img": torch.FloatTensor(img).to(self.device), 77 | "tgt_input": torch.LongTensor(tgt_input).to(self.device), 78 | "tgt_output": torch.LongTensor(tgt_output).to(self.device), 79 | "tgt_padding_mask": torch.BoolTensor(tgt_padding_mask).to(self.device), 80 | "filenames": filenames, 81 | } 82 | 83 | return rs 84 | 85 | def __len__(self): 86 | return len(self.data_list) 87 | 88 | def __iadd__(self, other): 89 | self.data_list += other.data_list 90 | self.label_list += other.label_list 91 | self.max_label_len = max(self.max_label_len, other.max_label_len) 92 | self.max_width = max(self.max_width, other.max_width) 93 | 94 | def __add__(self, other): 95 | res = BucketData() 96 | res.data_list = self.data_list + other.data_list 97 | res.label_list = self.label_list + other.label_list 98 | res.max_width = max(self.max_width, other.max_width) 99 | res.max_label_len = max((self.max_label_len, other.max_label_len)) 100 | return res 101 | 102 | 103 | class DataGen(object): 104 | 105 | def __init__( 106 | self, 107 | data_root, 108 | annotation_fn, 109 | vocab, 110 | device, 111 | image_height=32, 112 | image_min_width=32, 113 | image_max_width=512, 114 | ): 115 | 116 | self.image_height = image_height 117 | self.image_min_width = image_min_width 118 | self.image_max_width = image_max_width 119 | 120 | self.data_root = data_root 121 | self.annotation_path = os.path.join(data_root, annotation_fn) 122 | 123 | self.vocab = vocab 124 | self.device = device 125 | 126 | self.clear() 127 | 128 | def clear(self): 129 | self.bucket_data = defaultdict(lambda: BucketData(self.device)) 130 | 131 | @background(max_prefetch=1) 132 | def gen(self, batch_size, last_batch=True): 133 | with open(self.annotation_path, "r") as ann_file: 134 | lines = ann_file.readlines() 135 | np.random.shuffle(lines) 136 | for l in lines: 137 | 138 | img_path, lex = l.strip().split("\t") 139 | 140 | img_path = os.path.join(self.data_root, img_path) 141 | 142 | try: 143 | img_bw, word = self.read_data(img_path, lex) 144 | except IOError: 145 | print("ioread image:{}".format(img_path)) 146 | 147 | width = img_bw.shape[-1] 148 | 149 | bs = self.bucket_data[width].append(img_bw, word, img_path) 150 | if bs >= batch_size: 151 | b = self.bucket_data[width].flush_out() 152 | yield b 153 | 154 | if last_batch: 155 | for bucket in self.bucket_data.values(): 156 | if len(bucket) > 0: 157 | b = bucket.flush_out() 158 | yield b 159 | 160 | self.clear() 161 | 162 | def read_data(self, img_path, lex): 163 | 164 | with open(img_path, "rb") as img_file: 165 | img = Image.open(img_file).convert("RGB") 166 | img_bw = process_image( 167 | img, self.image_height, self.image_min_width, self.image_max_width 168 | ) 169 | 170 | word = self.vocab.encode(lex) 171 | 172 | return img_bw, word 173 | -------------------------------------------------------------------------------- /vietocr/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/model/__init__.py -------------------------------------------------------------------------------- /vietocr/model/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/model/backbone/__init__.py -------------------------------------------------------------------------------- 
/vietocr/model/backbone/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import vietocr.model.backbone.vgg as vgg 5 | from vietocr.model.backbone.resnet import Resnet50 6 | 7 | 8 | class CNN(nn.Module): 9 | def __init__(self, backbone, **kwargs): 10 | super(CNN, self).__init__() 11 | 12 | if backbone == "vgg11_bn": 13 | self.model = vgg.vgg11_bn(**kwargs) 14 | elif backbone == "vgg19_bn": 15 | self.model = vgg.vgg19_bn(**kwargs) 16 | elif backbone == "resnet50": 17 | self.model = Resnet50(**kwargs) 18 | 19 | def forward(self, x): 20 | return self.model(x) 21 | 22 | def freeze(self): 23 | for name, param in self.model.features.named_parameters(): 24 | if name != "last_conv_1x1": 25 | param.requires_grad = False 26 | 27 | def unfreeze(self): 28 | for param in self.model.features.parameters(): 29 | param.requires_grad = True 30 | -------------------------------------------------------------------------------- /vietocr/model/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BasicBlock(nn.Module): 6 | expansion = 1 7 | 8 | def __init__(self, inplanes, planes, stride=1, downsample=None): 9 | super(BasicBlock, self).__init__() 10 | self.conv1 = self._conv3x3(inplanes, planes) 11 | self.bn1 = nn.BatchNorm2d(planes) 12 | self.conv2 = self._conv3x3(planes, planes) 13 | self.bn2 = nn.BatchNorm2d(planes) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.downsample = downsample 16 | self.stride = stride 17 | 18 | def _conv3x3(self, in_planes, out_planes, stride=1): 19 | "3x3 convolution with padding" 20 | return nn.Conv2d( 21 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 22 | ) 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | out += residual 37 | out = self.relu(out) 38 | 39 | return out 40 | 41 | 42 | class ResNet(nn.Module): 43 | 44 | def __init__(self, input_channel, output_channel, block, layers): 45 | super(ResNet, self).__init__() 46 | 47 | self.output_channel_block = [ 48 | int(output_channel / 4), 49 | int(output_channel / 2), 50 | output_channel, 51 | output_channel, 52 | ] 53 | 54 | self.inplanes = int(output_channel / 8) 55 | self.conv0_1 = nn.Conv2d( 56 | input_channel, 57 | int(output_channel / 16), 58 | kernel_size=3, 59 | stride=1, 60 | padding=1, 61 | bias=False, 62 | ) 63 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 64 | self.conv0_2 = nn.Conv2d( 65 | int(output_channel / 16), 66 | self.inplanes, 67 | kernel_size=3, 68 | stride=1, 69 | padding=1, 70 | bias=False, 71 | ) 72 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 73 | self.relu = nn.ReLU(inplace=True) 74 | 75 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 76 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 77 | self.conv1 = nn.Conv2d( 78 | self.output_channel_block[0], 79 | self.output_channel_block[0], 80 | kernel_size=3, 81 | stride=1, 82 | padding=1, 83 | bias=False, 84 | ) 85 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 86 | 87 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 88 | self.layer2 = self._make_layer( 89 | block, self.output_channel_block[1], layers[1], stride=1 90 | ) 91 | 
self.conv2 = nn.Conv2d( 92 | self.output_channel_block[1], 93 | self.output_channel_block[1], 94 | kernel_size=3, 95 | stride=1, 96 | padding=1, 97 | bias=False, 98 | ) 99 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 100 | 101 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 102 | self.layer3 = self._make_layer( 103 | block, self.output_channel_block[2], layers[2], stride=1 104 | ) 105 | self.conv3 = nn.Conv2d( 106 | self.output_channel_block[2], 107 | self.output_channel_block[2], 108 | kernel_size=3, 109 | stride=1, 110 | padding=1, 111 | bias=False, 112 | ) 113 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 114 | 115 | self.layer4 = self._make_layer( 116 | block, self.output_channel_block[3], layers[3], stride=1 117 | ) 118 | self.conv4_1 = nn.Conv2d( 119 | self.output_channel_block[3], 120 | self.output_channel_block[3], 121 | kernel_size=2, 122 | stride=(2, 1), 123 | padding=(0, 1), 124 | bias=False, 125 | ) 126 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 127 | self.conv4_2 = nn.Conv2d( 128 | self.output_channel_block[3], 129 | self.output_channel_block[3], 130 | kernel_size=2, 131 | stride=1, 132 | padding=0, 133 | bias=False, 134 | ) 135 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d( 142 | self.inplanes, 143 | planes * block.expansion, 144 | kernel_size=1, 145 | stride=stride, 146 | bias=False, 147 | ), 148 | nn.BatchNorm2d(planes * block.expansion), 149 | ) 150 | 151 | layers = [] 152 | layers.append(block(self.inplanes, planes, stride, downsample)) 153 | self.inplanes = planes * block.expansion 154 | for i in range(1, blocks): 155 | layers.append(block(self.inplanes, planes)) 156 | 157 | return nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | x = self.conv0_1(x) 161 | x = self.bn0_1(x) 162 | x = self.relu(x) 163 | x = self.conv0_2(x) 164 | x = self.bn0_2(x) 165 | x = self.relu(x) 166 | 167 | x = self.maxpool1(x) 168 | x = self.layer1(x) 169 | x = self.conv1(x) 170 | x = self.bn1(x) 171 | x = self.relu(x) 172 | 173 | x = self.maxpool2(x) 174 | x = self.layer2(x) 175 | x = self.conv2(x) 176 | x = self.bn2(x) 177 | x = self.relu(x) 178 | 179 | x = self.maxpool3(x) 180 | x = self.layer3(x) 181 | x = self.conv3(x) 182 | x = self.bn3(x) 183 | x = self.relu(x) 184 | 185 | x = self.layer4(x) 186 | x = self.conv4_1(x) 187 | x = self.bn4_1(x) 188 | x = self.relu(x) 189 | x = self.conv4_2(x) 190 | x = self.bn4_2(x) 191 | conv = self.relu(x) 192 | 193 | conv = conv.transpose(-1, -2) 194 | conv = conv.flatten(2) 195 | conv = conv.permute(-1, 0, 1) 196 | 197 | return conv 198 | 199 | 200 | def Resnet50(ss, hidden): 201 | return ResNet(3, hidden, BasicBlock, [1, 2, 5, 3]) 202 | -------------------------------------------------------------------------------- /vietocr/model/backbone/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | from einops import rearrange 5 | from torchvision.models._utils import IntermediateLayerGetter 6 | 7 | 8 | class Vgg(nn.Module): 9 | def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5): 10 | super(Vgg, self).__init__() 11 | 12 | if pretrained: 13 | weights = "DEFAULT" 14 | else: 15 | weights = None 16 | 17 | if name == "vgg11_bn": 18 
| cnn = models.vgg11_bn(weights=weights) 19 | elif name == "vgg19_bn": 20 | cnn = models.vgg19_bn(weights=weights) 21 | 22 | pool_idx = 0 23 | 24 | for i, layer in enumerate(cnn.features): 25 | if isinstance(layer, torch.nn.MaxPool2d): 26 | cnn.features[i] = torch.nn.AvgPool2d( 27 | kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0 28 | ) 29 | pool_idx += 1 30 | 31 | self.features = cnn.features 32 | self.dropout = nn.Dropout(dropout) 33 | self.last_conv_1x1 = nn.Conv2d(512, hidden, 1) 34 | 35 | def forward(self, x): 36 | """ 37 | Shape: 38 | - x: (N, C, H, W) 39 | - output: (W, N, C) 40 | """ 41 | 42 | conv = self.features(x) 43 | conv = self.dropout(conv) 44 | conv = self.last_conv_1x1(conv) 45 | 46 | # conv = rearrange(conv, 'b d h w -> b d (w h)') 47 | conv = conv.transpose(-1, -2) 48 | conv = conv.flatten(2) 49 | conv = conv.permute(-1, 0, 1) 50 | return conv 51 | 52 | 53 | def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5): 54 | return Vgg("vgg11_bn", ss, ks, hidden, pretrained, dropout) 55 | 56 | 57 | def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5): 58 | return Vgg("vgg19_bn", ss, ks, hidden, pretrained, dropout) 59 | -------------------------------------------------------------------------------- /vietocr/model/beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Beam: 5 | 6 | def __init__( 7 | self, 8 | beam_size=8, 9 | min_length=0, 10 | n_top=1, 11 | ranker=None, 12 | start_token_id=1, 13 | end_token_id=2, 14 | ): 15 | self.beam_size = beam_size 16 | self.min_length = min_length 17 | self.ranker = ranker 18 | 19 | self.end_token_id = end_token_id 20 | self.top_sentence_ended = False 21 | 22 | self.prev_ks = [] 23 | self.next_ys = [ 24 | torch.LongTensor(beam_size).fill_(start_token_id) 25 | ] # remove padding 26 | 27 | self.current_scores = torch.FloatTensor(beam_size).zero_() 28 | self.all_scores = [] 29 | 30 | # Time and k pair for finished. 31 | self.finished = [] 32 | self.n_top = n_top 33 | 34 | self.ranker = ranker 35 | 36 | def advance(self, next_log_probs): 37 | # next_probs : beam_size X vocab_size 38 | 39 | vocabulary_size = next_log_probs.size(1) 40 | # current_beam_size = next_log_probs.size(0) 41 | 42 | current_length = len(self.next_ys) 43 | if current_length < self.min_length: 44 | for beam_index in range(len(next_log_probs)): 45 | next_log_probs[beam_index][self.end_token_id] = -1e10 46 | 47 | if len(self.prev_ks) > 0: 48 | beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as( 49 | next_log_probs 50 | ) 51 | # Don't let EOS have children. 
52 | last_y = self.next_ys[-1] 53 | for beam_index in range(last_y.size(0)): 54 | if last_y[beam_index] == self.end_token_id: 55 | beam_scores[beam_index] = -1e10 # -1e20 raises error when executing 56 | else: 57 | beam_scores = next_log_probs[0] 58 | 59 | flat_beam_scores = beam_scores.view(-1) 60 | top_scores, top_score_ids = flat_beam_scores.topk( 61 | k=self.beam_size, dim=0, largest=True, sorted=True 62 | ) 63 | 64 | self.current_scores = top_scores 65 | self.all_scores.append(self.current_scores) 66 | 67 | prev_k = top_score_ids // vocabulary_size # (beam_size, ) 68 | next_y = top_score_ids - prev_k * vocabulary_size # (beam_size, ) 69 | 70 | self.prev_ks.append(prev_k) 71 | self.next_ys.append(next_y) 72 | 73 | for beam_index, last_token_id in enumerate(next_y): 74 | 75 | if last_token_id == self.end_token_id: 76 | 77 | # skip scoring 78 | self.finished.append( 79 | (self.current_scores[beam_index], len(self.next_ys) - 1, beam_index) 80 | ) 81 | 82 | if next_y[0] == self.end_token_id: 83 | self.top_sentence_ended = True 84 | 85 | def get_current_state(self): 86 | "Get the outputs for the current timestep." 87 | return torch.stack(self.next_ys, dim=1) 88 | 89 | def get_current_origin(self): 90 | "Get the backpointers for the current timestep." 91 | return self.prev_ks[-1] 92 | 93 | def done(self): 94 | return self.top_sentence_ended and len(self.finished) >= self.n_top 95 | 96 | def get_hypothesis(self, timestep, k): 97 | hypothesis = [] 98 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 99 | hypothesis.append(self.next_ys[j + 1][k]) 100 | # for RNN, [:, k, :], and for trnasformer, [k, :, :] 101 | k = self.prev_ks[j][k] 102 | 103 | return hypothesis[::-1] 104 | 105 | def sort_finished(self, minimum=None): 106 | if minimum is not None: 107 | i = 0 108 | # Add from beam until we have minimum outputs. 109 | while len(self.finished) < minimum: 110 | # global_scores = self.global_scorer.score(self, self.scores) 111 | # s = global_scores[i] 112 | s = self.current_scores[i] 113 | self.finished.append((s, len(self.next_ys) - 1, i)) 114 | i += 1 115 | 116 | self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True) 117 | scores = [sc for sc, _, _ in self.finished] 118 | ks = [(t, k) for _, t, k in self.finished] 119 | return scores, ks 120 | -------------------------------------------------------------------------------- /vietocr/model/seqmodel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/model/seqmodel/__init__.py -------------------------------------------------------------------------------- /vietocr/model/seqmodel/convseq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__( 9 | self, emb_dim, hid_dim, n_layers, kernel_size, dropout, device, max_length=512 10 | ): 11 | super().__init__() 12 | 13 | assert kernel_size % 2 == 1, "Kernel size must be odd!" 
14 | 15 | self.device = device 16 | 17 | self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device) 18 | 19 | # self.tok_embedding = nn.Embedding(input_dim, emb_dim) 20 | self.pos_embedding = nn.Embedding(max_length, emb_dim) 21 | 22 | self.emb2hid = nn.Linear(emb_dim, hid_dim) 23 | self.hid2emb = nn.Linear(hid_dim, emb_dim) 24 | 25 | self.convs = nn.ModuleList( 26 | [ 27 | nn.Conv1d( 28 | in_channels=hid_dim, 29 | out_channels=2 * hid_dim, 30 | kernel_size=kernel_size, 31 | padding=(kernel_size - 1) // 2, 32 | ) 33 | for _ in range(n_layers) 34 | ] 35 | ) 36 | 37 | self.dropout = nn.Dropout(dropout) 38 | 39 | def forward(self, src): 40 | 41 | # src = [batch size, src len] 42 | 43 | src = src.transpose(0, 1) 44 | 45 | batch_size = src.shape[0] 46 | src_len = src.shape[1] 47 | device = src.device 48 | 49 | # create position tensor 50 | pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(device) 51 | 52 | # pos = [0, 1, 2, 3, ..., src len - 1] 53 | 54 | # pos = [batch size, src len] 55 | 56 | # embed tokens and positions 57 | 58 | # tok_embedded = self.tok_embedding(src) 59 | tok_embedded = src 60 | 61 | pos_embedded = self.pos_embedding(pos) 62 | 63 | # tok_embedded = pos_embedded = [batch size, src len, emb dim] 64 | 65 | # combine embeddings by elementwise summing 66 | embedded = self.dropout(tok_embedded + pos_embedded) 67 | 68 | # embedded = [batch size, src len, emb dim] 69 | 70 | # pass embedded through linear layer to convert from emb dim to hid dim 71 | conv_input = self.emb2hid(embedded) 72 | 73 | # conv_input = [batch size, src len, hid dim] 74 | 75 | # permute for convolutional layer 76 | conv_input = conv_input.permute(0, 2, 1) 77 | 78 | # conv_input = [batch size, hid dim, src len] 79 | 80 | # begin convolutional blocks... 
81 | 82 | for i, conv in enumerate(self.convs): 83 | 84 | # pass through convolutional layer 85 | conved = conv(self.dropout(conv_input)) 86 | 87 | # conved = [batch size, 2 * hid dim, src len] 88 | 89 | # pass through GLU activation function 90 | conved = F.glu(conved, dim=1) 91 | 92 | # conved = [batch size, hid dim, src len] 93 | 94 | # apply residual connection 95 | conved = (conved + conv_input) * self.scale 96 | 97 | # conved = [batch size, hid dim, src len] 98 | 99 | # set conv_input to conved for next loop iteration 100 | conv_input = conved 101 | 102 | # ...end convolutional blocks 103 | 104 | # permute and convert back to emb dim 105 | conved = self.hid2emb(conved.permute(0, 2, 1)) 106 | 107 | # conved = [batch size, src len, emb dim] 108 | 109 | # elementwise sum output (conved) and input (embedded) to be used for attention 110 | combined = (conved + embedded) * self.scale 111 | 112 | # combined = [batch size, src len, emb dim] 113 | 114 | return conved, combined 115 | 116 | 117 | class Decoder(nn.Module): 118 | def __init__( 119 | self, 120 | output_dim, 121 | emb_dim, 122 | hid_dim, 123 | n_layers, 124 | kernel_size, 125 | dropout, 126 | trg_pad_idx, 127 | device, 128 | max_length=512, 129 | ): 130 | super().__init__() 131 | 132 | self.kernel_size = kernel_size 133 | self.trg_pad_idx = trg_pad_idx 134 | self.device = device 135 | 136 | self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device) 137 | 138 | self.tok_embedding = nn.Embedding(output_dim, emb_dim) 139 | self.pos_embedding = nn.Embedding(max_length, emb_dim) 140 | 141 | self.emb2hid = nn.Linear(emb_dim, hid_dim) 142 | self.hid2emb = nn.Linear(hid_dim, emb_dim) 143 | 144 | self.attn_hid2emb = nn.Linear(hid_dim, emb_dim) 145 | self.attn_emb2hid = nn.Linear(emb_dim, hid_dim) 146 | 147 | self.fc_out = nn.Linear(emb_dim, output_dim) 148 | 149 | self.convs = nn.ModuleList( 150 | [ 151 | nn.Conv1d( 152 | in_channels=hid_dim, 153 | out_channels=2 * hid_dim, 154 | kernel_size=kernel_size, 155 | ) 156 | for _ in range(n_layers) 157 | ] 158 | ) 159 | 160 | self.dropout = nn.Dropout(dropout) 161 | 162 | def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined): 163 | 164 | # embedded = [batch size, trg len, emb dim] 165 | # conved = [batch size, hid dim, trg len] 166 | # encoder_conved = encoder_combined = [batch size, src len, emb dim] 167 | 168 | # permute and convert back to emb dim 169 | conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1)) 170 | 171 | # conved_emb = [batch size, trg len, emb dim] 172 | 173 | combined = (conved_emb + embedded) * self.scale 174 | 175 | # combined = [batch size, trg len, emb dim] 176 | 177 | energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1)) 178 | 179 | # energy = [batch size, trg len, src len] 180 | 181 | attention = F.softmax(energy, dim=2) 182 | 183 | # attention = [batch size, trg len, src len] 184 | 185 | attended_encoding = torch.matmul(attention, encoder_combined) 186 | 187 | # attended_encoding = [batch size, trg len, emd dim] 188 | 189 | # convert from emb dim -> hid dim 190 | attended_encoding = self.attn_emb2hid(attended_encoding) 191 | 192 | # attended_encoding = [batch size, trg len, hid dim] 193 | 194 | # apply residual connection 195 | attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale 196 | 197 | # attended_combined = [batch size, hid dim, trg len] 198 | 199 | return attention, attended_combined 200 | 201 | def forward(self, trg, encoder_conved, encoder_combined): 202 | 203 | # trg = [batch size, trg len] 
204 | # encoder_conved = encoder_combined = [batch size, src len, emb dim] 205 | trg = trg.transpose(0, 1) 206 | 207 | batch_size = trg.shape[0] 208 | trg_len = trg.shape[1] 209 | device = trg.device 210 | 211 | # create position tensor 212 | pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(device) 213 | 214 | # pos = [batch size, trg len] 215 | 216 | # embed tokens and positions 217 | tok_embedded = self.tok_embedding(trg) 218 | pos_embedded = self.pos_embedding(pos) 219 | 220 | # tok_embedded = [batch size, trg len, emb dim] 221 | # pos_embedded = [batch size, trg len, emb dim] 222 | 223 | # combine embeddings by elementwise summing 224 | embedded = self.dropout(tok_embedded + pos_embedded) 225 | 226 | # embedded = [batch size, trg len, emb dim] 227 | 228 | # pass embedded through linear layer to go through emb dim -> hid dim 229 | conv_input = self.emb2hid(embedded) 230 | 231 | # conv_input = [batch size, trg len, hid dim] 232 | 233 | # permute for convolutional layer 234 | conv_input = conv_input.permute(0, 2, 1) 235 | 236 | # conv_input = [batch size, hid dim, trg len] 237 | 238 | batch_size = conv_input.shape[0] 239 | hid_dim = conv_input.shape[1] 240 | 241 | for i, conv in enumerate(self.convs): 242 | 243 | # apply dropout 244 | conv_input = self.dropout(conv_input) 245 | 246 | # need to pad so decoder can't "cheat" 247 | padding = ( 248 | torch.zeros(batch_size, hid_dim, self.kernel_size - 1) 249 | .fill_(self.trg_pad_idx) 250 | .to(device) 251 | ) 252 | 253 | padded_conv_input = torch.cat((padding, conv_input), dim=2) 254 | 255 | # padded_conv_input = [batch size, hid dim, trg len + kernel size - 1] 256 | 257 | # pass through convolutional layer 258 | conved = conv(padded_conv_input) 259 | 260 | # conved = [batch size, 2 * hid dim, trg len] 261 | 262 | # pass through GLU activation function 263 | conved = F.glu(conved, dim=1) 264 | 265 | # conved = [batch size, hid dim, trg len] 266 | 267 | # calculate attention 268 | attention, conved = self.calculate_attention( 269 | embedded, conved, encoder_conved, encoder_combined 270 | ) 271 | 272 | # attention = [batch size, trg len, src len] 273 | 274 | # apply residual connection 275 | conved = (conved + conv_input) * self.scale 276 | 277 | # conved = [batch size, hid dim, trg len] 278 | 279 | # set conv_input to conved for next loop iteration 280 | conv_input = conved 281 | 282 | conved = self.hid2emb(conved.permute(0, 2, 1)) 283 | 284 | # conved = [batch size, trg len, emb dim] 285 | 286 | output = self.fc_out(self.dropout(conved)) 287 | 288 | # output = [batch size, trg len, output dim] 289 | 290 | return output, attention 291 | 292 | 293 | class ConvSeq2Seq(nn.Module): 294 | def __init__( 295 | self, 296 | vocab_size, 297 | emb_dim, 298 | hid_dim, 299 | enc_layers, 300 | dec_layers, 301 | enc_kernel_size, 302 | dec_kernel_size, 303 | enc_max_length, 304 | dec_max_length, 305 | dropout, 306 | pad_idx, 307 | device, 308 | ): 309 | super().__init__() 310 | 311 | enc = Encoder( 312 | emb_dim, 313 | hid_dim, 314 | enc_layers, 315 | enc_kernel_size, 316 | dropout, 317 | device, 318 | enc_max_length, 319 | ) 320 | dec = Decoder( 321 | vocab_size, 322 | emb_dim, 323 | hid_dim, 324 | dec_layers, 325 | dec_kernel_size, 326 | dropout, 327 | pad_idx, 328 | device, 329 | dec_max_length, 330 | ) 331 | 332 | self.encoder = enc 333 | self.decoder = dec 334 | 335 | def forward_encoder(self, src): 336 | encoder_conved, encoder_combined = self.encoder(src) 337 | 338 | return encoder_conved, encoder_combined 339 | 340 | def 
forward_decoder(self, trg, memory): 341 | encoder_conved, encoder_combined = memory 342 | output, attention = self.decoder(trg, encoder_conved, encoder_combined) 343 | 344 | return output, (encoder_conved, encoder_combined) 345 | 346 | def forward(self, src, trg): 347 | 348 | # src = [batch size, src len] 349 | # trg = [batch size, trg len - 1] ( token sliced off the end) 350 | 351 | # calculate z^u (encoder_conved) and (z^u + e) (encoder_combined) 352 | # encoder_conved is output from final encoder conv. block 353 | # encoder_combined is encoder_conved plus (elementwise) src embedding plus 354 | # positional embeddings 355 | encoder_conved, encoder_combined = self.encoder(src) 356 | 357 | # encoder_conved = [batch size, src len, emb dim] 358 | # encoder_combined = [batch size, src len, emb dim] 359 | 360 | # calculate predictions of next words 361 | # output is a batch of predictions for each word in the trg sentence 362 | # attention a batch of attention scores across the src sentence for 363 | # each word in the trg sentence 364 | output, attention = self.decoder(trg, encoder_conved, encoder_combined) 365 | 366 | # output = [batch size, trg len - 1, output dim] 367 | # attention = [batch size, trg len - 1, src len] 368 | 369 | return output # , attention 370 | -------------------------------------------------------------------------------- /vietocr/model/seqmodel/seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout): 9 | super().__init__() 10 | 11 | self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True) 12 | self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim) 13 | self.dropout = nn.Dropout(dropout) 14 | 15 | def forward(self, src): 16 | """ 17 | src: src_len x batch_size x img_channel 18 | outputs: src_len x batch_size x hid_dim 19 | hidden: batch_size x hid_dim 20 | """ 21 | 22 | embedded = self.dropout(src) 23 | 24 | outputs, hidden = self.rnn(embedded) 25 | 26 | hidden = torch.tanh( 27 | self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)) 28 | ) 29 | 30 | return outputs, hidden 31 | 32 | 33 | class Attention(nn.Module): 34 | def __init__(self, enc_hid_dim, dec_hid_dim): 35 | super().__init__() 36 | 37 | self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim) 38 | self.v = nn.Linear(dec_hid_dim, 1, bias=False) 39 | 40 | def forward(self, hidden, encoder_outputs): 41 | """ 42 | hidden: batch_size x hid_dim 43 | encoder_outputs: src_len x batch_size x hid_dim, 44 | outputs: batch_size x src_len 45 | """ 46 | 47 | batch_size = encoder_outputs.shape[1] 48 | src_len = encoder_outputs.shape[0] 49 | 50 | hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) 51 | 52 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 53 | 54 | energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2))) 55 | 56 | attention = self.v(energy).squeeze(2) 57 | 58 | return F.softmax(attention, dim=1) 59 | 60 | 61 | class Decoder(nn.Module): 62 | def __init__( 63 | self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention 64 | ): 65 | super().__init__() 66 | 67 | self.output_dim = output_dim 68 | self.attention = attention 69 | 70 | self.embedding = nn.Embedding(output_dim, emb_dim) 71 | self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim) 72 | self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + 
emb_dim, output_dim) 73 | self.dropout = nn.Dropout(dropout) 74 | 75 | def forward(self, input, hidden, encoder_outputs): 76 | """ 77 | inputs: batch_size 78 | hidden: batch_size x hid_dim 79 | encoder_outputs: src_len x batch_size x hid_dim 80 | """ 81 | 82 | input = input.unsqueeze(0) 83 | 84 | embedded = self.dropout(self.embedding(input)) 85 | 86 | a = self.attention(hidden, encoder_outputs) 87 | 88 | a = a.unsqueeze(1) 89 | 90 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 91 | 92 | weighted = torch.bmm(a, encoder_outputs) 93 | 94 | weighted = weighted.permute(1, 0, 2) 95 | 96 | rnn_input = torch.cat((embedded, weighted), dim=2) 97 | 98 | output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) 99 | 100 | assert (output == hidden).all() 101 | 102 | embedded = embedded.squeeze(0) 103 | output = output.squeeze(0) 104 | weighted = weighted.squeeze(0) 105 | 106 | prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1)) 107 | 108 | return prediction, hidden.squeeze(0), a.squeeze(1) 109 | 110 | 111 | class Seq2Seq(nn.Module): 112 | def __init__( 113 | self, 114 | vocab_size, 115 | encoder_hidden, 116 | decoder_hidden, 117 | img_channel, 118 | decoder_embedded, 119 | dropout=0.1, 120 | ): 121 | super().__init__() 122 | 123 | attn = Attention(encoder_hidden, decoder_hidden) 124 | 125 | self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout) 126 | self.decoder = Decoder( 127 | vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn 128 | ) 129 | 130 | def forward_encoder(self, src): 131 | """ 132 | src: timestep x batch_size x channel 133 | hidden: batch_size x hid_dim 134 | encoder_outputs: src_len x batch_size x hid_dim 135 | """ 136 | 137 | encoder_outputs, hidden = self.encoder(src) 138 | 139 | return (hidden, encoder_outputs) 140 | 141 | def forward_decoder(self, tgt, memory): 142 | """ 143 | tgt: timestep x batch_size 144 | hidden: batch_size x hid_dim 145 | encouder: src_len x batch_size x hid_dim 146 | output: batch_size x 1 x vocab_size 147 | """ 148 | 149 | tgt = tgt[-1] 150 | hidden, encoder_outputs = memory 151 | output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs) 152 | output = output.unsqueeze(1) 153 | 154 | return output, (hidden, encoder_outputs) 155 | 156 | def forward(self, src, trg): 157 | """ 158 | src: time_step x batch_size 159 | trg: time_step x batch_size 160 | outputs: batch_size x time_step x vocab_size 161 | """ 162 | 163 | batch_size = src.shape[1] 164 | trg_len = trg.shape[0] 165 | trg_vocab_size = self.decoder.output_dim 166 | device = src.device 167 | 168 | outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device) 169 | encoder_outputs, hidden = self.encoder(src) 170 | 171 | for t in range(trg_len): 172 | input = trg[t] 173 | output, hidden, _ = self.decoder(input, hidden, encoder_outputs) 174 | 175 | outputs[t] = output 176 | 177 | outputs = outputs.transpose(0, 1).contiguous() 178 | 179 | return outputs 180 | 181 | def expand_memory(self, memory, beam_size): 182 | hidden, encoder_outputs = memory 183 | hidden = hidden.repeat(beam_size, 1) 184 | encoder_outputs = encoder_outputs.repeat(1, beam_size, 1) 185 | 186 | return (hidden, encoder_outputs) 187 | 188 | def get_memory(self, memory, i): 189 | hidden, encoder_outputs = memory 190 | hidden = hidden[[i]] 191 | encoder_outputs = encoder_outputs[:, [i], :] 192 | 193 | return (hidden, encoder_outputs) 194 | -------------------------------------------------------------------------------- /vietocr/model/seqmodel/transformer.py: 
-------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torchvision import models 3 | import math 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class LanguageTransformer(nn.Module): 9 | def __init__( 10 | self, 11 | vocab_size, 12 | d_model, 13 | nhead, 14 | num_encoder_layers, 15 | num_decoder_layers, 16 | dim_feedforward, 17 | max_seq_length, 18 | pos_dropout, 19 | trans_dropout, 20 | ): 21 | super().__init__() 22 | 23 | self.d_model = d_model 24 | self.embed_tgt = nn.Embedding(vocab_size, d_model) 25 | self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length) 26 | # self.learned_pos_enc = LearnedPositionalEncoding(d_model, pos_dropout, max_seq_length) 27 | 28 | self.transformer = nn.Transformer( 29 | d_model, 30 | nhead, 31 | num_encoder_layers, 32 | num_decoder_layers, 33 | dim_feedforward, 34 | trans_dropout, 35 | ) 36 | 37 | self.fc = nn.Linear(d_model, vocab_size) 38 | 39 | def forward( 40 | self, 41 | src, 42 | tgt, 43 | src_key_padding_mask=None, 44 | tgt_key_padding_mask=None, 45 | memory_key_padding_mask=None, 46 | ): 47 | """ 48 | Shape: 49 | - src: (W, N, C) 50 | - tgt: (T, N) 51 | - src_key_padding_mask: (N, S) 52 | - tgt_key_padding_mask: (N, T) 53 | - memory_key_padding_mask: (N, S) 54 | - output: (N, T, E) 55 | 56 | """ 57 | tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(src.device) 58 | 59 | src = self.pos_enc(src * math.sqrt(self.d_model)) 60 | # src = self.learned_pos_enc(src*math.sqrt(self.d_model)) 61 | 62 | tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model)) 63 | 64 | output = self.transformer( 65 | src, 66 | tgt, 67 | tgt_mask=tgt_mask, 68 | src_key_padding_mask=src_key_padding_mask, 69 | tgt_key_padding_mask=tgt_key_padding_mask.float(), 70 | memory_key_padding_mask=memory_key_padding_mask, 71 | ) 72 | # output = rearrange(output, 't n e -> n t e') 73 | output = output.transpose(0, 1) 74 | return self.fc(output) 75 | 76 | def gen_nopeek_mask(self, length): 77 | mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1) 78 | mask = ( 79 | mask.float() 80 | .masked_fill(mask == 0, float("-inf")) 81 | .masked_fill(mask == 1, float(0.0)) 82 | ) 83 | 84 | return mask 85 | 86 | def forward_encoder(self, src): 87 | src = self.pos_enc(src * math.sqrt(self.d_model)) 88 | memory = self.transformer.encoder(src) 89 | return memory 90 | 91 | def forward_decoder(self, tgt, memory): 92 | tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(tgt.device) 93 | tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model)) 94 | 95 | output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask) 96 | # output = rearrange(output, 't n e -> n t e') 97 | output = output.transpose(0, 1) 98 | 99 | return self.fc(output), memory 100 | 101 | def expand_memory(self, memory, beam_size): 102 | memory = memory.repeat(1, beam_size, 1) 103 | return memory 104 | 105 | def get_memory(self, memory, i): 106 | memory = memory[:, [i], :] 107 | return memory 108 | 109 | 110 | class PositionalEncoding(nn.Module): 111 | def __init__(self, d_model, dropout=0.1, max_len=100): 112 | super(PositionalEncoding, self).__init__() 113 | self.dropout = nn.Dropout(p=dropout) 114 | 115 | pe = torch.zeros(max_len, d_model) 116 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 117 | div_term = torch.exp( 118 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 119 | ) 120 | pe[:, 0::2] = torch.sin(position * div_term) 121 | pe[:, 1::2] = torch.cos(position * div_term) 
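        # pe now holds the fixed sinusoidal table:
        #   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        #   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        # The next line reshapes it to (max_len, 1, d_model) so it broadcasts over the
        # batch dimension of the (seq_len, batch, d_model) tensors used in forward().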
122 | pe = pe.unsqueeze(0).transpose(0, 1) 123 | self.register_buffer("pe", pe) 124 | 125 | def forward(self, x): 126 | x = x + self.pe[: x.size(0), :] 127 | 128 | return self.dropout(x) 129 | 130 | 131 | class LearnedPositionalEncoding(nn.Module): 132 | def __init__(self, d_model, dropout=0.1, max_len=100): 133 | super(LearnedPositionalEncoding, self).__init__() 134 | self.dropout = nn.Dropout(p=dropout) 135 | 136 | self.pos_embed = nn.Embedding(max_len, d_model) 137 | self.layernorm = LayerNorm(d_model) 138 | 139 | def forward(self, x): 140 | seq_len = x.size(0) 141 | pos = torch.arange(seq_len, dtype=torch.long, device=x.device) 142 | pos = pos.unsqueeze(-1).expand(x.size()[:2]) 143 | x = x + self.pos_embed(pos) 144 | return self.dropout(self.layernorm(x)) 145 | 146 | 147 | class LayerNorm(nn.Module): 148 | "A layernorm module in the TF style (epsilon inside the square root)." 149 | 150 | def __init__(self, d_model, variance_epsilon=1e-12): 151 | super().__init__() 152 | self.gamma = nn.Parameter(torch.ones(d_model)) 153 | self.beta = nn.Parameter(torch.zeros(d_model)) 154 | self.variance_epsilon = variance_epsilon 155 | 156 | def forward(self, x): 157 | u = x.mean(-1, keepdim=True) 158 | s = (x - u).pow(2).mean(-1, keepdim=True) 159 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 160 | return self.gamma * x + self.beta 161 | -------------------------------------------------------------------------------- /vietocr/model/trainer.py: -------------------------------------------------------------------------------- 1 | from vietocr.optim.optim import ScheduledOptim 2 | from vietocr.optim.labelsmoothingloss import LabelSmoothingLoss 3 | from torch.optim import Adam, SGD, AdamW 4 | from torch import nn 5 | from vietocr.tool.translate import build_model 6 | from vietocr.tool.translate import translate, batch_translate_beam_search 7 | from vietocr.tool.utils import download_weights 8 | from vietocr.tool.logger import Logger 9 | from vietocr.loader.aug import ImgAugTransformV2 10 | 11 | import yaml 12 | import torch 13 | from vietocr.loader.dataloader_v1 import DataGen 14 | from vietocr.loader.dataloader import OCRDataset, ClusterRandomSampler, Collator 15 | from torch.utils.data import DataLoader 16 | from einops import rearrange 17 | from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, OneCycleLR 18 | 19 | import torchvision 20 | 21 | from vietocr.tool.utils import compute_accuracy 22 | from PIL import Image 23 | import numpy as np 24 | import os 25 | import matplotlib.pyplot as plt 26 | import time 27 | 28 | 29 | class Trainer: 30 | def __init__(self, config, pretrained=True, augmentor=ImgAugTransformV2()): 31 | 32 | self.config = config 33 | self.model, self.vocab = build_model(config) 34 | 35 | self.device = config["device"] 36 | self.num_iters = config["trainer"]["iters"] 37 | self.beamsearch = config["predictor"]["beamsearch"] 38 | 39 | self.data_root = config["dataset"]["data_root"] 40 | self.train_annotation = config["dataset"]["train_annotation"] 41 | self.valid_annotation = config["dataset"]["valid_annotation"] 42 | self.dataset_name = config["dataset"]["name"] 43 | 44 | self.batch_size = config["trainer"]["batch_size"] 45 | self.print_every = config["trainer"]["print_every"] 46 | self.valid_every = config["trainer"]["valid_every"] 47 | 48 | self.image_aug = config["aug"]["image_aug"] 49 | self.masked_language_model = config["aug"]["masked_language_model"] 50 | 51 | self.checkpoint = config["trainer"]["checkpoint"] 52 | self.export_weights = 
config["trainer"]["export"] 53 | self.metrics = config["trainer"]["metrics"] 54 | logger = config["trainer"]["log"] 55 | 56 | if logger: 57 | self.logger = Logger(logger) 58 | 59 | if pretrained: 60 | weight_file = download_weights(config["pretrain"], quiet=config["quiet"]) 61 | self.load_weights(weight_file) 62 | 63 | self.iter = 0 64 | 65 | self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09) 66 | self.scheduler = OneCycleLR( 67 | self.optimizer, total_steps=self.num_iters, **config["optimizer"] 68 | ) 69 | # self.optimizer = ScheduledOptim( 70 | # Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09), 71 | # #config['transformer']['d_model'], 72 | # 512, 73 | # **config['optimizer']) 74 | 75 | self.criterion = LabelSmoothingLoss( 76 | len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1 77 | ) 78 | 79 | transforms = None 80 | if self.image_aug: 81 | transforms = augmentor 82 | 83 | self.train_gen = self.data_gen( 84 | "train_{}".format(self.dataset_name), 85 | self.data_root, 86 | self.train_annotation, 87 | self.masked_language_model, 88 | transform=transforms, 89 | ) 90 | if self.valid_annotation: 91 | self.valid_gen = self.data_gen( 92 | "valid_{}".format(self.dataset_name), 93 | self.data_root, 94 | self.valid_annotation, 95 | masked_language_model=False, 96 | ) 97 | 98 | self.train_losses = [] 99 | 100 | def train(self): 101 | total_loss = 0 102 | 103 | total_loader_time = 0 104 | total_gpu_time = 0 105 | best_acc = 0 106 | 107 | data_iter = iter(self.train_gen) 108 | for i in range(self.num_iters): 109 | self.iter += 1 110 | 111 | start = time.time() 112 | 113 | try: 114 | batch = next(data_iter) 115 | except StopIteration: 116 | data_iter = iter(self.train_gen) 117 | batch = next(data_iter) 118 | 119 | total_loader_time += time.time() - start 120 | 121 | start = time.time() 122 | loss = self.step(batch) 123 | total_gpu_time += time.time() - start 124 | 125 | total_loss += loss 126 | self.train_losses.append((self.iter, loss)) 127 | 128 | if self.iter % self.print_every == 0: 129 | info = "iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}".format( 130 | self.iter, 131 | total_loss / self.print_every, 132 | self.optimizer.param_groups[0]["lr"], 133 | total_loader_time, 134 | total_gpu_time, 135 | ) 136 | 137 | total_loss = 0 138 | total_loader_time = 0 139 | total_gpu_time = 0 140 | print(info) 141 | self.logger.log(info) 142 | 143 | if self.valid_annotation and self.iter % self.valid_every == 0: 144 | val_loss = self.validate() 145 | acc_full_seq, acc_per_char = self.precision(self.metrics) 146 | 147 | info = "iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}".format( 148 | self.iter, val_loss, acc_full_seq, acc_per_char 149 | ) 150 | print(info) 151 | self.logger.log(info) 152 | 153 | if acc_full_seq > best_acc: 154 | self.save_weights(self.export_weights) 155 | best_acc = acc_full_seq 156 | 157 | def validate(self): 158 | self.model.eval() 159 | 160 | total_loss = [] 161 | 162 | with torch.no_grad(): 163 | for step, batch in enumerate(self.valid_gen): 164 | batch = self.batch_to_device(batch) 165 | img, tgt_input, tgt_output, tgt_padding_mask = ( 166 | batch["img"], 167 | batch["tgt_input"], 168 | batch["tgt_output"], 169 | batch["tgt_padding_mask"], 170 | ) 171 | 172 | outputs = self.model(img, tgt_input, tgt_padding_mask) 173 | # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)')) 174 | 175 | outputs = outputs.flatten(0, 1) 176 
| tgt_output = tgt_output.flatten() 177 | loss = self.criterion(outputs, tgt_output) 178 | 179 | total_loss.append(loss.item()) 180 | 181 | del outputs 182 | del loss 183 | 184 | total_loss = np.mean(total_loss) 185 | self.model.train() 186 | 187 | return total_loss 188 | 189 | def predict(self, sample=None): 190 | pred_sents = [] 191 | actual_sents = [] 192 | img_files = [] 193 | 194 | for batch in self.valid_gen: 195 | batch = self.batch_to_device(batch) 196 | 197 | if self.beamsearch: 198 | translated_sentence = batch_translate_beam_search( 199 | batch["img"], self.model 200 | ) 201 | prob = None 202 | else: 203 | translated_sentence, prob = translate(batch["img"], self.model) 204 | 205 | pred_sent = self.vocab.batch_decode(translated_sentence.tolist()) 206 | actual_sent = self.vocab.batch_decode(batch["tgt_output"].tolist()) 207 | 208 | img_files.extend(batch["filenames"]) 209 | 210 | pred_sents.extend(pred_sent) 211 | actual_sents.extend(actual_sent) 212 | 213 | if sample != None and len(pred_sents) > sample: 214 | break 215 | 216 | return pred_sents, actual_sents, img_files, prob 217 | 218 | def precision(self, sample=None): 219 | 220 | pred_sents, actual_sents, _, _ = self.predict(sample=sample) 221 | 222 | acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode="full_sequence") 223 | acc_per_char = compute_accuracy(actual_sents, pred_sents, mode="per_char") 224 | 225 | return acc_full_seq, acc_per_char 226 | 227 | def visualize_prediction( 228 | self, sample=16, errorcase=False, fontname="serif", fontsize=16 229 | ): 230 | 231 | pred_sents, actual_sents, img_files, probs = self.predict(sample) 232 | 233 | if errorcase: 234 | wrongs = [] 235 | for i in range(len(img_files)): 236 | if pred_sents[i] != actual_sents[i]: 237 | wrongs.append(i) 238 | 239 | pred_sents = [pred_sents[i] for i in wrongs] 240 | actual_sents = [actual_sents[i] for i in wrongs] 241 | img_files = [img_files[i] for i in wrongs] 242 | probs = [probs[i] for i in wrongs] 243 | 244 | img_files = img_files[:sample] 245 | 246 | fontdict = {"family": fontname, "size": fontsize} 247 | 248 | for vis_idx in range(0, len(img_files)): 249 | img_path = img_files[vis_idx] 250 | pred_sent = pred_sents[vis_idx] 251 | actual_sent = actual_sents[vis_idx] 252 | prob = probs[vis_idx] 253 | 254 | img = Image.open(open(img_path, "rb")) 255 | plt.figure() 256 | plt.imshow(img) 257 | plt.title( 258 | "prob: {:.3f} - pred: {} - actual: {}".format( 259 | prob, pred_sent, actual_sent 260 | ), 261 | loc="left", 262 | fontdict=fontdict, 263 | ) 264 | plt.axis("off") 265 | 266 | plt.show() 267 | 268 | def visualize_dataset(self, sample=16, fontname="serif"): 269 | n = 0 270 | for batch in self.train_gen: 271 | for i in range(self.batch_size): 272 | img = batch["img"][i].numpy().transpose(1, 2, 0) 273 | sent = self.vocab.decode(batch["tgt_input"].T[i].tolist()) 274 | 275 | plt.figure() 276 | plt.title("sent: {}".format(sent), loc="center", fontname=fontname) 277 | plt.imshow(img) 278 | plt.axis("off") 279 | 280 | n += 1 281 | if n >= sample: 282 | plt.show() 283 | return 284 | 285 | def load_checkpoint(self, filename): 286 | checkpoint = torch.load(filename) 287 | 288 | optim = ScheduledOptim( 289 | Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09), 290 | self.config["transformer"]["d_model"], 291 | **self.config["optimizer"] 292 | ) 293 | 294 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 295 | self.model.load_state_dict(checkpoint["state_dict"]) 296 | self.iter = checkpoint["iter"] 297 | 298 | self.train_losses = 
checkpoint["train_losses"] 299 | 300 | def save_checkpoint(self, filename): 301 | state = { 302 | "iter": self.iter, 303 | "state_dict": self.model.state_dict(), 304 | "optimizer": self.optimizer.state_dict(), 305 | "train_losses": self.train_losses, 306 | } 307 | 308 | path, _ = os.path.split(filename) 309 | os.makedirs(path, exist_ok=True) 310 | 311 | torch.save(state, filename) 312 | 313 | def load_weights(self, filename): 314 | state_dict = torch.load(filename, map_location=torch.device(self.device)) 315 | 316 | for name, param in self.model.named_parameters(): 317 | if name not in state_dict: 318 | print("{} not found".format(name)) 319 | elif state_dict[name].shape != param.shape: 320 | print( 321 | "{} missmatching shape, required {} but found {}".format( 322 | name, param.shape, state_dict[name].shape 323 | ) 324 | ) 325 | del state_dict[name] 326 | 327 | self.model.load_state_dict(state_dict, strict=False) 328 | 329 | def save_weights(self, filename): 330 | path, _ = os.path.split(filename) 331 | os.makedirs(path, exist_ok=True) 332 | 333 | torch.save(self.model.state_dict(), filename) 334 | 335 | def batch_to_device(self, batch): 336 | img = batch["img"].to(self.device, non_blocking=True) 337 | tgt_input = batch["tgt_input"].to(self.device, non_blocking=True) 338 | tgt_output = batch["tgt_output"].to(self.device, non_blocking=True) 339 | tgt_padding_mask = batch["tgt_padding_mask"].to(self.device, non_blocking=True) 340 | 341 | batch = { 342 | "img": img, 343 | "tgt_input": tgt_input, 344 | "tgt_output": tgt_output, 345 | "tgt_padding_mask": tgt_padding_mask, 346 | "filenames": batch["filenames"], 347 | } 348 | 349 | return batch 350 | 351 | def data_gen( 352 | self, 353 | lmdb_path, 354 | data_root, 355 | annotation, 356 | masked_language_model=True, 357 | transform=None, 358 | ): 359 | dataset = OCRDataset( 360 | lmdb_path=lmdb_path, 361 | root_dir=data_root, 362 | annotation_path=annotation, 363 | vocab=self.vocab, 364 | transform=transform, 365 | image_height=self.config["dataset"]["image_height"], 366 | image_min_width=self.config["dataset"]["image_min_width"], 367 | image_max_width=self.config["dataset"]["image_max_width"], 368 | ) 369 | 370 | sampler = ClusterRandomSampler(dataset, self.batch_size, True) 371 | collate_fn = Collator(masked_language_model) 372 | 373 | gen = DataLoader( 374 | dataset, 375 | batch_size=self.batch_size, 376 | sampler=sampler, 377 | collate_fn=collate_fn, 378 | shuffle=False, 379 | drop_last=False, 380 | **self.config["dataloader"] 381 | ) 382 | 383 | return gen 384 | 385 | def data_gen_v1(self, lmdb_path, data_root, annotation): 386 | data_gen = DataGen( 387 | data_root, 388 | annotation, 389 | self.vocab, 390 | "cpu", 391 | image_height=self.config["dataset"]["image_height"], 392 | image_min_width=self.config["dataset"]["image_min_width"], 393 | image_max_width=self.config["dataset"]["image_max_width"], 394 | ) 395 | 396 | return data_gen 397 | 398 | def step(self, batch): 399 | self.model.train() 400 | 401 | batch = self.batch_to_device(batch) 402 | img, tgt_input, tgt_output, tgt_padding_mask = ( 403 | batch["img"], 404 | batch["tgt_input"], 405 | batch["tgt_output"], 406 | batch["tgt_padding_mask"], 407 | ) 408 | 409 | outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask) 410 | # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)')) 411 | outputs = outputs.view(-1, outputs.size(2)) # flatten(0, 1) 412 | tgt_output = tgt_output.view(-1) # flatten() 413 | 414 | loss = 
self.criterion(outputs, tgt_output) 415 | 416 | self.optimizer.zero_grad() 417 | 418 | loss.backward() 419 | 420 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) 421 | 422 | self.optimizer.step() 423 | self.scheduler.step() 424 | 425 | loss_item = loss.item() 426 | 427 | return loss_item 428 | -------------------------------------------------------------------------------- /vietocr/model/transformerocr.py: -------------------------------------------------------------------------------- 1 | from vietocr.model.backbone.cnn import CNN 2 | from vietocr.model.seqmodel.transformer import LanguageTransformer 3 | from vietocr.model.seqmodel.seq2seq import Seq2Seq 4 | from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq 5 | from torch import nn 6 | 7 | 8 | class VietOCR(nn.Module): 9 | def __init__( 10 | self, 11 | vocab_size, 12 | backbone, 13 | cnn_args, 14 | transformer_args, 15 | seq_modeling="transformer", 16 | ): 17 | 18 | super(VietOCR, self).__init__() 19 | 20 | self.cnn = CNN(backbone, **cnn_args) 21 | self.seq_modeling = seq_modeling 22 | 23 | if seq_modeling == "transformer": 24 | self.transformer = LanguageTransformer(vocab_size, **transformer_args) 25 | elif seq_modeling == "seq2seq": 26 | self.transformer = Seq2Seq(vocab_size, **transformer_args) 27 | elif seq_modeling == "convseq2seq": 28 | self.transformer = ConvSeq2Seq(vocab_size, **transformer_args) 29 | else: 30 | raise ValueError("Unsupported seq_modeling: {}".format(seq_modeling)) 31 | 32 | def forward(self, img, tgt_input, tgt_key_padding_mask): 33 | """ 34 | Shape: 35 | - img: (N, C, H, W) 36 | - tgt_input: (T, N) 37 | - tgt_key_padding_mask: (N, T) 38 | - output: b t v 39 | """ 40 | src = self.cnn(img) 41 | 42 | if self.seq_modeling == "transformer": 43 | outputs = self.transformer( 44 | src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask 45 | ) 46 | elif self.seq_modeling == "seq2seq": 47 | outputs = self.transformer(src, tgt_input) 48 | elif self.seq_modeling == "convseq2seq": 49 | outputs = self.transformer(src, tgt_input) 50 | return outputs 51 | -------------------------------------------------------------------------------- /vietocr/model/vocab.py: -------------------------------------------------------------------------------- 1 | class Vocab: 2 | def __init__(self, chars): 3 | self.pad = 0 4 | self.go = 1 5 | self.eos = 2 6 | self.mask_token = 3 7 | 8 | self.chars = chars 9 | 10 | self.c2i = {c: i + 4 for i, c in enumerate(chars)} 11 | 12 | self.i2c = {i + 4: c for i, c in enumerate(chars)} 13 | 14 | self.i2c[0] = "<pad>" 15 | self.i2c[1] = "<sos>" 16 | self.i2c[2] = "<eos>" 17 | self.i2c[3] = "*" 18 | 19 | def encode(self, chars): 20 | return [self.go] + [self.c2i[c] for c in chars] + [self.eos] 21 | 22 | def decode(self, ids): 23 | first = 1 if self.go in ids else 0 24 | last = ids.index(self.eos) if self.eos in ids else None 25 | sent = "".join([self.i2c[i] for i in ids[first:last]]) 26 | return sent 27 | 28 | def __len__(self): 29 | return len(self.c2i) + 4 30 | 31 | def batch_decode(self, arr): 32 | texts = [self.decode(ids) for ids in arr] 33 | return texts 34 | 35 | def __str__(self): 36 | return self.chars 37 | -------------------------------------------------------------------------------- /vietocr/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/optim/__init__.py -------------------------------------------------------------------------------- /vietocr/optim/labelsmoothingloss.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LabelSmoothingLoss(nn.Module): 6 | def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1): 7 | super(LabelSmoothingLoss, self).__init__() 8 | self.confidence = 1.0 - smoothing 9 | self.smoothing = smoothing 10 | self.cls = classes 11 | self.dim = dim 12 | self.padding_idx = padding_idx 13 | 14 | def forward(self, pred, target): 15 | pred = pred.log_softmax(dim=self.dim) 16 | with torch.no_grad(): 17 | # true_dist = pred.data.clone() 18 | true_dist = torch.zeros_like(pred) 19 | true_dist.fill_(self.smoothing / (self.cls - 2)) 20 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 21 | true_dist[:, self.padding_idx] = 0 22 | mask = torch.nonzero(target.data == self.padding_idx, as_tuple=False) 23 | if mask.dim() > 0: 24 | true_dist.index_fill_(0, mask.squeeze(), 0.0) 25 | 26 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) 27 | -------------------------------------------------------------------------------- /vietocr/optim/optim.py: -------------------------------------------------------------------------------- 1 | class ScheduledOptim: 2 | """A simple wrapper class for learning rate scheduling""" 3 | 4 | def __init__(self, optimizer, d_model, init_lr, n_warmup_steps): 5 | assert n_warmup_steps > 0, "must be greater than 0" 6 | 7 | self._optimizer = optimizer 8 | self.init_lr = init_lr 9 | self.d_model = d_model 10 | self.n_warmup_steps = n_warmup_steps 11 | self.n_steps = 0 12 | 13 | def step(self): 14 | "Step with the inner optimizer" 15 | self._update_learning_rate() 16 | self._optimizer.step() 17 | 18 | def zero_grad(self): 19 | "Zero out the gradients with the inner optimizer" 20 | self._optimizer.zero_grad() 21 | 22 | def _get_lr_scale(self): 23 | d_model = self.d_model 24 | n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps 25 | return (d_model**-0.5) * min( 26 | n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5) 27 | ) 28 | 29 | def state_dict(self): 30 | optimizer_state_dict = { 31 | "init_lr": self.init_lr, 32 | "d_model": self.d_model, 33 | "n_warmup_steps": self.n_warmup_steps, 34 | "n_steps": self.n_steps, 35 | "_optimizer": self._optimizer.state_dict(), 36 | } 37 | 38 | return optimizer_state_dict 39 | 40 | def load_state_dict(self, state_dict): 41 | self.init_lr = state_dict["init_lr"] 42 | self.d_model = state_dict["d_model"] 43 | self.n_warmup_steps = state_dict["n_warmup_steps"] 44 | self.n_steps = state_dict["n_steps"] 45 | 46 | self._optimizer.load_state_dict(state_dict["_optimizer"]) 47 | 48 | def _update_learning_rate(self): 49 | """Learning rate scheduling per step""" 50 | 51 | self.n_steps += 1 52 | 53 | for param_group in self._optimizer.param_groups: 54 | lr = self.init_lr * self._get_lr_scale() 55 | self.lr = lr 56 | 57 | param_group["lr"] = lr 58 | -------------------------------------------------------------------------------- /vietocr/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from PIL import Image 3 | 4 | from vietocr.tool.predictor import Predictor 5 | from vietocr.tool.config import Cfg 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--img", required=True, help="foo help") 11 | parser.add_argument("--config", required=True, help="foo help") 12 | 13 | args = parser.parse_args() 14 | config = Cfg.load_config_from_file(args.config) 15 | 16 | detector = 
Predictor(config) 17 | 18 | img = Image.open(args.img) 19 | s = detector.predict(img) 20 | 21 | print(s) 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /vietocr/requirement.txt: -------------------------------------------------------------------------------- 1 | einops==0.2.0 2 | -------------------------------------------------------------------------------- /vietocr/tests/image/001099025107.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/001099025107.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/026301003919.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/026301003919.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/036170002830.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/036170002830.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/038078002355.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/038078002355.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/038089010274.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/038089010274.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/038144000109.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/038144000109.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/060085000115.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/060085000115.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/072183002222.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/072183002222.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/079084000809.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/079084000809.jpeg -------------------------------------------------------------------------------- /vietocr/tests/image/079193002341.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/079193002341.jpeg -------------------------------------------------------------------------------- /vietocr/tests/sample.txt: -------------------------------------------------------------------------------- 1 | ./image/036170002830.jpeg HOÀNG THỊ THOI 2 | ./image/079193002341.jpeg TRỊNH THỊ THÚY HẰNG 3 | ./image/001099025107.jpeg NGUYỄN VĂN BÌNH 4 | ./image/060085000115.jpeg NGUYỄN MINH TOÀN 5 | ./image/026301003919.jpeg NGUYỄN THỊ KIỀU TRANG 6 | ./image/079084000809.jpeg LÊ NGỌC PHƯƠNG KHANH 7 | ./image/038144000109.jpeg ĐÀO THỊ TƠ 8 | ./image/072183002222.jpeg NGUYỄN THANH PHƯỚC 9 | ./image/038078002355.jpeg HÀ ĐÌNH LỢI 10 | ./image/038089010274.jpeg HÀ VĂN LUÂN 11 | -------------------------------------------------------------------------------- /vietocr/tests/utest.py: -------------------------------------------------------------------------------- 1 | from vietocr.loader.dataloader_v1 import DataGen 2 | from vietocr.model.vocab import Vocab 3 | 4 | 5 | def test_loader(): 6 | chars = "aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ " 7 | 8 | vocab = Vocab(chars) 9 | s_gen = DataGen("./vietocr/tests/", "sample.txt", vocab, "cpu", 32, 512) 10 | 11 | iterator = s_gen.gen(30) 12 | for batch in iterator: 13 | assert batch["img"].shape[1] == 3, "image must have 3 channels" 14 | assert batch["img"].shape[2] == 32, "the height must be 32" 15 | print( 16 | batch["img"].shape, 17 | batch["tgt_input"].shape, 18 | batch["tgt_output"].shape, 19 | batch["tgt_padding_mask"].shape, 20 | ) 21 | 22 | 23 | if __name__ == "__main__": 24 | test_loader() 25 | -------------------------------------------------------------------------------- /vietocr/tool/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tool/__init__.py -------------------------------------------------------------------------------- /vietocr/tool/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from vietocr.tool.utils import download_config 3 | 4 | url_config = { 5 | "vgg_transformer": "vgg-transformer.yml", 6 | "resnet_transformer": "resnet_transformer.yml", 7 | "resnet_fpn_transformer": "resnet_fpn_transformer.yml", 8 | "vgg_seq2seq": "vgg-seq2seq.yml", 9 | "vgg_convseq2seq": "vgg_convseq2seq.yml", 10 | "vgg_decoderseq2seq": "vgg_decoderseq2seq.yml", 11 | "base": "base.yml", 12 | } 13 | 14 | 15 | class Cfg(dict): 16 | def __init__(self, config_dict): 17 | super(Cfg, self).__init__(**config_dict) 18 | self.__dict__ = self 19 | 20 | @staticmethod 21 | def load_config_from_file(fname): 22 | # base_config = download_config(url_config['base']) 23 | base_config = {} 24 | with open(fname, encoding="utf-8") as f: 25 | config = yaml.safe_load(f) 26 | base_config.update(config) 27 | 28 | return Cfg(base_config) 29 | 30 | @staticmethod 31 | def load_config_from_name(name): 32 | base_config = download_config(url_config["base"]) 33 | config = download_config(url_config[name]) 34 | 35 | base_config.update(config) 36 | return Cfg(base_config) 37 | 38 | def save(self, fname): 39 | with open(fname, "w", encoding="utf-8") as outfile: 40 | yaml.dump(dict(self), 
outfile, default_flow_style=False, allow_unicode=True) 41 | -------------------------------------------------------------------------------- /vietocr/tool/create_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import lmdb # install lmdb by "pip install lmdb" 4 | import cv2 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | def checkImageIsValid(imageBin): 10 | isvalid = True 11 | imgH = None 12 | imgW = None 13 | 14 | imageBuf = np.fromstring(imageBin, dtype=np.uint8) 15 | try: 16 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 17 | 18 | imgH, imgW = img.shape[0], img.shape[1] 19 | if imgH * imgW == 0: 20 | isvalid = False 21 | except Exception as e: 22 | isvalid = False 23 | 24 | return isvalid, imgH, imgW 25 | 26 | 27 | def writeCache(env, cache): 28 | with env.begin(write=True) as txn: 29 | for k, v in cache.items(): 30 | txn.put(k.encode(), v) 31 | 32 | 33 | def createDataset(outputPath, root_dir, annotation_path): 34 | """ 35 | Create LMDB dataset for CRNN training. 36 | ARGS: 37 | outputPath : LMDB output path 38 | imagePathList : list of image path 39 | labelList : list of corresponding groundtruth texts 40 | lexiconList : (optional) list of lexicon lists 41 | checkValid : if true, check the validity of every image 42 | """ 43 | 44 | annotation_path = os.path.join(root_dir, annotation_path) 45 | with open(annotation_path, "r", encoding="utf-8") as ann_file: 46 | lines = ann_file.readlines() 47 | annotations = [l.strip().split("\t") for l in lines] 48 | 49 | nSamples = len(annotations) 50 | env = lmdb.open(outputPath, map_size=1099511627776) 51 | cache = {} 52 | cnt = 0 53 | error = 0 54 | 55 | pbar = tqdm(range(nSamples), ncols=100, desc="Create {}".format(outputPath)) 56 | for i in pbar: 57 | imageFile, label = annotations[i] 58 | imagePath = os.path.join(root_dir, imageFile) 59 | 60 | if not os.path.exists(imagePath): 61 | error += 1 62 | continue 63 | 64 | with open(imagePath, "rb") as f: 65 | imageBin = f.read() 66 | isvalid, imgH, imgW = checkImageIsValid(imageBin) 67 | 68 | if not isvalid: 69 | error += 1 70 | continue 71 | 72 | imageKey = "image-%09d" % cnt 73 | labelKey = "label-%09d" % cnt 74 | pathKey = "path-%09d" % cnt 75 | dimKey = "dim-%09d" % cnt 76 | 77 | cache[imageKey] = imageBin 78 | cache[labelKey] = label.encode() 79 | cache[pathKey] = imageFile.encode() 80 | cache[dimKey] = np.array([imgH, imgW], dtype=np.int32).tobytes() 81 | 82 | cnt += 1 83 | 84 | if cnt % 1000 == 0: 85 | writeCache(env, cache) 86 | cache = {} 87 | 88 | nSamples = cnt - 1 89 | cache["num-samples"] = str(nSamples).encode() 90 | writeCache(env, cache) 91 | 92 | if error > 0: 93 | print("Remove {} invalid images".format(error)) 94 | print("Created dataset with %d samples" % nSamples) 95 | sys.stdout.flush() 96 | -------------------------------------------------------------------------------- /vietocr/tool/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class Logger: 5 | def __init__(self, fname): 6 | path, _ = os.path.split(fname) 7 | os.makedirs(path, exist_ok=True) 8 | 9 | self.logger = open(fname, "w") 10 | 11 | def log(self, string): 12 | self.logger.write(string + "\n") 13 | self.logger.flush() 14 | 15 | def close(self): 16 | self.logger.close() 17 | -------------------------------------------------------------------------------- /vietocr/tool/predictor.py: 
-------------------------------------------------------------------------------- 1 | from vietocr.tool.translate import ( 2 | build_model, 3 | translate, 4 | translate_beam_search, 5 | process_input, 6 | predict, 7 | ) 8 | from vietocr.tool.utils import download_weights 9 | 10 | import torch 11 | from collections import defaultdict 12 | 13 | 14 | class Predictor: 15 | def __init__(self, config): 16 | 17 | device = config["device"] 18 | 19 | model, vocab = build_model(config) 20 | weights = "/tmp/weights.pth" 21 | 22 | if config["weights"].startswith("http"): 23 | weights = download_weights(config["weights"]) 24 | else: 25 | weights = config["weights"] 26 | 27 | model.load_state_dict(torch.load(weights, map_location=torch.device(device))) 28 | 29 | self.config = config 30 | self.model = model 31 | self.vocab = vocab 32 | self.device = device 33 | 34 | def predict(self, img, return_prob=False): 35 | img = process_input( 36 | img, 37 | self.config["dataset"]["image_height"], 38 | self.config["dataset"]["image_min_width"], 39 | self.config["dataset"]["image_max_width"], 40 | ) 41 | img = img.to(self.config["device"]) 42 | 43 | if self.config["predictor"]["beamsearch"]: 44 | sent = translate_beam_search(img, self.model) 45 | s = sent 46 | prob = None 47 | else: 48 | s, prob = translate(img, self.model) 49 | s = s[0].tolist() 50 | prob = prob[0] 51 | 52 | s = self.vocab.decode(s) 53 | 54 | if return_prob: 55 | return s, prob 56 | else: 57 | return s 58 | 59 | def predict_batch(self, imgs, return_prob=False): 60 | bucket = defaultdict(list) 61 | bucket_idx = defaultdict(list) 62 | bucket_pred = {} 63 | 64 | sents, probs = [0] * len(imgs), [0] * len(imgs) 65 | 66 | for i, img in enumerate(imgs): 67 | img = process_input( 68 | img, 69 | self.config["dataset"]["image_height"], 70 | self.config["dataset"]["image_min_width"], 71 | self.config["dataset"]["image_max_width"], 72 | ) 73 | 74 | bucket[img.shape[-1]].append(img) 75 | bucket_idx[img.shape[-1]].append(i) 76 | 77 | for k, batch in bucket.items(): 78 | batch = torch.cat(batch, 0).to(self.device) 79 | s, prob = translate(batch, self.model) 80 | prob = prob.tolist() 81 | 82 | s = s.tolist() 83 | s = self.vocab.batch_decode(s) 84 | 85 | bucket_pred[k] = (s, prob) 86 | 87 | for k in bucket_pred: 88 | idx = bucket_idx[k] 89 | sent, prob = bucket_pred[k] 90 | for i, j in enumerate(idx): 91 | sents[j] = sent[i] 92 | probs[j] = prob[i] 93 | 94 | if return_prob: 95 | return sents, probs 96 | else: 97 | return sents 98 | -------------------------------------------------------------------------------- /vietocr/tool/translate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | from PIL import Image 5 | from torch.nn.functional import log_softmax, softmax 6 | 7 | from vietocr.model.transformerocr import VietOCR 8 | from vietocr.model.vocab import Vocab 9 | from vietocr.model.beam import Beam 10 | 11 | 12 | def batch_translate_beam_search( 13 | img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2 14 | ): 15 | # img: NxCxHxW 16 | model.eval() 17 | device = img.device 18 | sents = [] 19 | 20 | with torch.no_grad(): 21 | src = model.cnn(img) 22 | print(src.shap) 23 | memories = model.transformer.forward_encoder(src) 24 | for i in range(src.size(0)): 25 | # memory = memories[:,i,:].repeat(1, beam_size, 1) # TxNxE 26 | memory = model.transformer.get_memory(memories, i) 27 | sent = beamsearch( 28 | memory, 29 | model, 30 | device, 31 | 
beam_size, 32 | candidates, 33 | max_seq_length, 34 | sos_token, 35 | eos_token, 36 | ) 37 | sents.append(sent) 38 | 39 | sents = np.asarray(sents) 40 | 41 | return sents 42 | 43 | 44 | def translate_beam_search( 45 | img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2 46 | ): 47 | # img: 1xCxHxW 48 | model.eval() 49 | device = img.device 50 | 51 | with torch.no_grad(): 52 | src = model.cnn(img) 53 | memory = model.transformer.forward_encoder(src) # TxNxE 54 | sent = beamsearch( 55 | memory, 56 | model, 57 | device, 58 | beam_size, 59 | candidates, 60 | max_seq_length, 61 | sos_token, 62 | eos_token, 63 | ) 64 | 65 | return sent 66 | 67 | 68 | def beamsearch( 69 | memory, 70 | model, 71 | device, 72 | beam_size=4, 73 | candidates=1, 74 | max_seq_length=128, 75 | sos_token=1, 76 | eos_token=2, 77 | ): 78 | # memory: Tx1xE 79 | model.eval() 80 | 81 | beam = Beam( 82 | beam_size=beam_size, 83 | min_length=0, 84 | n_top=candidates, 85 | ranker=None, 86 | start_token_id=sos_token, 87 | end_token_id=eos_token, 88 | ) 89 | 90 | with torch.no_grad(): 91 | # memory = memory.repeat(1, beam_size, 1) # TxNxE 92 | memory = model.transformer.expand_memory(memory, beam_size) 93 | 94 | for _ in range(max_seq_length): 95 | 96 | tgt_inp = beam.get_current_state().transpose(0, 1).to(device) # TxN 97 | decoder_outputs, memory = model.transformer.forward_decoder(tgt_inp, memory) 98 | 99 | log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0), dim=-1) 100 | beam.advance(log_prob.cpu()) 101 | 102 | if beam.done(): 103 | break 104 | 105 | scores, ks = beam.sort_finished(minimum=1) 106 | 107 | hypothesises = [] 108 | for i, (times, k) in enumerate(ks[:candidates]): 109 | hypothesis = beam.get_hypothesis(times, k) 110 | hypothesises.append(hypothesis) 111 | 112 | return [1] + [int(i) for i in hypothesises[0][:-1]] 113 | 114 | 115 | def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2): 116 | "data: BxCXHxW" 117 | model.eval() 118 | device = img.device 119 | 120 | with torch.no_grad(): 121 | src = model.cnn(img) 122 | memory = model.transformer.forward_encoder(src) 123 | 124 | translated_sentence = [[sos_token] * len(img)] 125 | char_probs = [[1] * len(img)] 126 | 127 | max_length = 0 128 | 129 | while max_length <= max_seq_length and not all( 130 | np.any(np.asarray(translated_sentence).T == eos_token, axis=1) 131 | ): 132 | 133 | tgt_inp = torch.LongTensor(translated_sentence).to(device) 134 | 135 | # output = model(img, tgt_inp, tgt_key_padding_mask=None) 136 | # output = model.transformer(src, tgt_inp, tgt_key_padding_mask=None) 137 | output, memory = model.transformer.forward_decoder(tgt_inp, memory) 138 | output = softmax(output, dim=-1) 139 | output = output.to("cpu") 140 | 141 | values, indices = torch.topk(output, 5) 142 | 143 | indices = indices[:, -1, 0] 144 | indices = indices.tolist() 145 | 146 | values = values[:, -1, 0] 147 | values = values.tolist() 148 | char_probs.append(values) 149 | 150 | translated_sentence.append(indices) 151 | max_length += 1 152 | 153 | del output 154 | 155 | translated_sentence = np.asarray(translated_sentence).T 156 | 157 | char_probs = np.asarray(char_probs).T 158 | char_probs = np.multiply(char_probs, translated_sentence > 3) 159 | char_probs = np.sum(char_probs, axis=-1) / (char_probs > 0).sum(-1) 160 | 161 | return translated_sentence, char_probs 162 | 163 | 164 | def build_model(config): 165 | vocab = Vocab(config["vocab"]) 166 | device = config["device"] 167 | 168 | model = VietOCR( 169 | len(vocab), 170 | 
config["backbone"], 171 | config["cnn"], 172 | config["transformer"], 173 | config["seq_modeling"], 174 | ) 175 | 176 | model = model.to(device) 177 | 178 | return model, vocab 179 | 180 | 181 | def resize(w, h, expected_height, image_min_width, image_max_width): 182 | new_w = int(expected_height * float(w) / float(h)) 183 | round_to = 10 184 | new_w = math.ceil(new_w / round_to) * round_to 185 | new_w = max(new_w, image_min_width) 186 | new_w = min(new_w, image_max_width) 187 | 188 | return new_w, expected_height 189 | 190 | 191 | def process_image(image, image_height, image_min_width, image_max_width): 192 | img = image.convert("RGB") 193 | 194 | w, h = img.size 195 | new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width) 196 | 197 | img = img.resize((new_w, image_height), Image.LANCZOS) 198 | 199 | img = np.asarray(img).transpose(2, 0, 1) 200 | img = img / 255 201 | return img 202 | 203 | 204 | def process_input(image, image_height, image_min_width, image_max_width): 205 | img = process_image(image, image_height, image_min_width, image_max_width) 206 | img = img[np.newaxis, ...] 207 | img = torch.FloatTensor(img) 208 | return img 209 | 210 | 211 | def predict(filename, config): 212 | img = Image.open(filename) 213 | img = process_input(img) 214 | 215 | img = img.to(config["device"]) 216 | 217 | model, vocab = build_model(config) 218 | s = translate(img, model)[0].tolist() 219 | s = vocab.decode(s) 220 | 221 | return s 222 | -------------------------------------------------------------------------------- /vietocr/tool/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import yaml 4 | import numpy as np 5 | import uuid 6 | import requests 7 | import tempfile 8 | from tqdm import tqdm 9 | 10 | 11 | def download_weights(uri, cached=None, md5=None, quiet=False): 12 | if uri.startswith("http"): 13 | return download(url=uri, quiet=quiet) 14 | return uri 15 | 16 | 17 | def download(url, quiet=False): 18 | tmp_dir = tempfile.gettempdir() 19 | filename = url.split("/")[-1] 20 | full_path = os.path.join(tmp_dir, filename) 21 | 22 | if os.path.exists(full_path): 23 | print("Model weight {} exsits. Ignore download!".format(full_path)) 24 | return full_path 25 | 26 | with requests.get(url, stream=True) as r: 27 | r.raise_for_status() 28 | with open(full_path, "wb") as f: 29 | for chunk in tqdm(r.iter_content(chunk_size=8192)): 30 | # If you have chunk encoded response uncomment if 31 | # and set chunk_size parameter to None. 
32 | # if chunk: 33 | f.write(chunk) 34 | return full_path 35 | 36 | 37 | def download_config(id): 38 | url = "https://vocr.vn/data/vietocr/config/{}".format(id) 39 | r = requests.get(url) 40 | config = yaml.safe_load(r.text) 41 | return config 42 | 43 | 44 | def compute_accuracy(ground_truth, predictions, mode="full_sequence"): 45 | """ 46 | Computes accuracy 47 | :param ground_truth: 48 | :param predictions: 49 | :param display: Whether to print values to stdout 50 | :param mode: if 'per_char' is selected then 51 | single_label_accuracy = correct_predicted_char_nums_of_single_sample / single_label_char_nums 52 | avg_label_accuracy = sum(single_label_accuracy) / label_nums 53 | if 'full_sequence' is selected then 54 | single_label_accuracy = 1 if the prediction result is exactly the same as label else 0 55 | avg_label_accuracy = sum(single_label_accuracy) / label_nums 56 | :return: avg_label_accuracy 57 | """ 58 | if mode == "per_char": 59 | 60 | accuracy = [] 61 | 62 | for index, label in enumerate(ground_truth): 63 | prediction = predictions[index] 64 | total_count = len(label) 65 | correct_count = 0 66 | try: 67 | for i, tmp in enumerate(label): 68 | if tmp == prediction[i]: 69 | correct_count += 1 70 | except IndexError: 71 | continue 72 | finally: 73 | try: 74 | accuracy.append(correct_count / total_count) 75 | except ZeroDivisionError: 76 | if len(prediction) == 0: 77 | accuracy.append(1) 78 | else: 79 | accuracy.append(0) 80 | avg_accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0) 81 | elif mode == "full_sequence": 82 | try: 83 | correct_count = 0 84 | for index, label in enumerate(ground_truth): 85 | prediction = predictions[index] 86 | if prediction == label: 87 | correct_count += 1 88 | avg_accuracy = correct_count / len(ground_truth) 89 | except ZeroDivisionError: 90 | if not predictions: 91 | avg_accuracy = 1 92 | else: 93 | avg_accuracy = 0 94 | else: 95 | raise NotImplementedError( 96 | "Other accuracy compute mode has not been implemented" 97 | ) 98 | 99 | return avg_accuracy 100 | -------------------------------------------------------------------------------- /vietocr/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from vietocr.model.trainer import Trainer 4 | from vietocr.tool.config import Cfg 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--config", required=True, help="see example at ") 10 | parser.add_argument("--checkpoint", required=False, help="your checkpoint") 11 | 12 | args = parser.parse_args() 13 | config = Cfg.load_config_from_file(args.config) 14 | 15 | trainer = Trainer(config) 16 | 17 | if args.checkpoint: 18 | trainer.load_checkpoint(args.checkpoint) 19 | 20 | trainer.train() 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | --------------------------------------------------------------------------------
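A minimal inference sketch assembled from the modules above (Cfg, Predictor). Illustrative only: the config, weight, and image paths are placeholders, and it assumes the chosen YAML provides every section the predictor reads (vocab, backbone, cnn, transformer, seq_modeling, dataset, predictor).

from PIL import Image

from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

# Placeholder paths -- point these at a real config file, weight file and line image.
config = Cfg.load_config_from_file("config/vgg-transformer.yml")
config["weights"] = "/path/to/transformerocr.pth"   # local path or http(s) URL
config["device"] = "cpu"                            # e.g. "cuda:0" when a GPU is available
config["predictor"]["beamsearch"] = False           # greedy decoding via translate()

detector = Predictor(config)

img = Image.open("/path/to/line_image.png")
text, prob = detector.predict(img, return_prob=True)
print(text, prob)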