├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   ├── base.yml
│   ├── resnet-transformer.yml
│   ├── resnet_fpn_transformer.yml
│   ├── vgg-convseq2seq.yml
│   ├── vgg-seq2seq.yml
│   └── vgg-transformer.yml
├── image
│   ├── .keep
│   ├── sample.png
│   └── vietocr.jpg
├── setup.py
├── vietocr
│   ├── __init__.py
│   ├── loader
│   │   ├── __init__.py
│   │   ├── aug.py
│   │   ├── dataloader.py
│   │   └── dataloader_v1.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── cnn.py
│   │   │   ├── resnet.py
│   │   │   └── vgg.py
│   │   ├── beam.py
│   │   ├── seqmodel
│   │   │   ├── __init__.py
│   │   │   ├── convseq2seq.py
│   │   │   ├── seq2seq.py
│   │   │   └── transformer.py
│   │   ├── trainer.py
│   │   ├── transformerocr.py
│   │   └── vocab.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── labelsmoothingloss.py
│   │   └── optim.py
│   ├── predict.py
│   ├── requirement.txt
│   ├── tests
│   │   ├── image
│   │   │   ├── 001099025107.jpeg
│   │   │   ├── 026301003919.jpeg
│   │   │   ├── 036170002830.jpeg
│   │   │   ├── 038078002355.jpeg
│   │   │   ├── 038089010274.jpeg
│   │   │   ├── 038144000109.jpeg
│   │   │   ├── 060085000115.jpeg
│   │   │   ├── 072183002222.jpeg
│   │   │   ├── 079084000809.jpeg
│   │   │   └── 079193002341.jpeg
│   │   ├── sample.txt
│   │   └── utest.py
│   ├── tool
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── create_dataset.py
│   │   ├── logger.py
│   │   ├── predictor.py
│   │   ├── translate.py
│   │   └── utils.py
│   └── train.py
└── vietocr_gettingstart.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VietOCR
2 | **[DORI](https://dorify.net/vi) is an end-to-end OCR platform that helps you label data, train models, and deploy them with ease.**
3 | - Learn more at [dorify.net](https://dorify.net/vi).
4 | - The user guide is available [here](https://pbcquoc.github.io/dori_guideline/).
5 |
6 | ----
7 | **Please update to the latest version to avoid errors.**
8 |
9 |
10 |
11 |
12 |
13 | This project implements a Transformer OCR model for recognizing Vietnamese handwritten and printed text. The architecture combines a CNN with a Transformer (the model that underlies the well-known BERT). TransformerOCR has many advantages over the CRNN architecture I implemented previously. You can read about the architecture and how to train the model on different datasets [here](https://pbcquoc.github.io/vietocr).
14 |
15 | VietOCR generalizes very well and reaches fairly high accuracy even on new datasets it has never been trained on.
16 |
17 |
18 |
19 |
20 |
21 | # Installation
22 | To install, run:
23 | ```
24 | pip install vietocr
25 | ```
26 | # Quick Start
27 | See [this](https://github.com/pbcquoc/vietocr/blob/master/vietocr_gettingstart.ipynb) notebook to learn how to use the library; a minimal prediction example is sketched below.
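A minimal prediction sketch, assuming the `Cfg` and `Predictor` helpers exposed by `vietocr/tool/config.py` and `vietocr/tool/predictor.py`, and using the sample image shipped in `image/`:
```
from PIL import Image

from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

# load the packaged vgg_transformer config and run on CPU
config = Cfg.load_config_from_name('vgg_transformer')
config['device'] = 'cpu'

detector = Predictor(config)

# predict the text in a cropped text-line image
img = Image.open('./image/sample.png')
print(detector.predict(img))
```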
28 | # Creating the train/test files
29 | Each train/test file has 2 columns: the first is the image file name and the second is the label (which must not contain the \t character). The two columns are separated by \t.
30 | ```
31 | 20160518_0151_25432_1_tg_3_5.png để nghe phổ biến chủ trương của UBND tỉnh Phú Yên
32 | 20160421_0102_25464_2_tg_0_4.png môi trường lại đều đồng thanh
33 | ```
34 | A sample file is available [here](https://vocr.vn/data/vietocr/data_line.zip); the sketch below shows how such a file can be written.
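As an illustration only, a small sketch that writes an annotation file in this format from a hypothetical `labels` mapping (the file names and labels are just the examples above):
```
labels = {
    '20160518_0151_25432_1_tg_3_5.png': 'để nghe phổ biến chủ trương của UBND tỉnh Phú Yên',
    '20160421_0102_25464_2_tg_0_4.png': 'môi trường lại đều đồng thanh',
}

# one line per sample: <image file>\t<label>
with open('annotation_train.txt', 'w', encoding='utf-8') as f:
    for name, label in labels.items():
        f.write('{}\t{}\n'.format(name, label))
```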
35 |
36 | # Model Zoo
37 | This library implements both kinds of sequence model: attention-based seq2seq and transformer. Seq2seq predicts very quickly and is widely used in industry, while the transformer is more accurate but noticeably slower at prediction time, so both are provided for you to choose from.
38 |
39 | The models were trained on a dataset of 10M images, covering many kinds of data: synthetic images, handwriting, and real scanned documents.
40 | Pretrained models are provided.
41 |
42 | # Experimental results on the 10M dataset
43 | | Backbone | Config | Full-sequence precision | Time |
44 | | ------------- |:-------------:| ---:|---:|
45 | | VGG19-bn - Transformer | vgg_transformer | 0.8800 | 86ms @ 1080ti |
46 | | VGG19-bn - Seq2Seq | vgg_seq2seq | 0.8701 | 12ms @ 1080ti |
47 |
48 | The vgg-transformer model is much slower at prediction than the seq2seq model, while there is no clear difference in accuracy between the two architectures.
49 |
50 | # Dataset
51 | Only a sample dataset of about 1M synthetic images is provided. You can download it [here](https://drive.google.com/file/d/1T0cmkhTgu3ahyMIwGZeby612RpVdDxOR/view).
52 | # License
53 | This library is released under the terms of the [Apache 2.0 license](LICENSE).
54 |
55 | # Contact
56 | If you have any problems, please open an issue or contact me at pbcquoc@gmail.com.
57 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/config/__init__.py
--------------------------------------------------------------------------------
/config/base.yml:
--------------------------------------------------------------------------------
1 | # change to list chars of your dataset or use default vietnamese chars
2 | vocab: 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
3 |
4 | # cpu, cuda, cuda:0
5 | device: cuda:0
6 |
7 | seq_modeling: transformer
8 | transformer:
9 | d_model: 256
10 | nhead: 8
11 | num_encoder_layers: 6
12 | num_decoder_layers: 6
13 | dim_feedforward: 2048
14 | max_seq_length: 1024
15 | pos_dropout: 0.1
16 | trans_dropout: 0.1
17 |
18 | optimizer:
19 | max_lr: 0.0003
20 | pct_start: 0.1
21 |
22 | trainer:
23 | batch_size: 32
24 | print_every: 200
25 | valid_every: 4000
26 | iters: 100000
27 | # where to save our model for prediction
28 | export: ./weights/transformerocr.pth
29 | checkpoint: ./checkpoint/transformerocr_checkpoint.pth
30 | log: ./train.log
31 | # null to disable accuracy computation, or set to the number of samples to enable validation while training
32 | metrics: null
33 |
34 | dataset:
35 | # name of your dataset
36 | name: data
37 | # path to annotation and image
38 | data_root: ./img/
39 | train_annotation: annotation_train.txt
40 | valid_annotation: annotation_val_small.txt
41 | # resize images to a height of 32; a larger height will increase accuracy
42 | image_height: 32
43 | image_min_width: 32
44 | image_max_width: 512
45 |
46 | dataloader:
47 | num_workers: 3
48 | pin_memory: True
49 |
50 | aug:
51 | image_aug: true
52 | masked_language_model: true
53 |
54 | predictor:
55 | # enable or disable beam search during prediction; beam search is slower
56 | beamsearch: False
57 |
58 | quiet: False
59 |
--------------------------------------------------------------------------------
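For context, a minimal training sketch showing how a model config (e.g. `config/vgg-transformer.yml`) is typically layered on top of `base.yml`. It assumes the `Cfg` helper in `vietocr/tool/config.py` and the `Trainer` in `vietocr/model/trainer.py`; the dataset paths below are placeholders:
```
from vietocr.tool.config import Cfg
from vietocr.model.trainer import Trainer

# merge an architecture config on top of the base config
config = Cfg.load_config_from_file('config/vgg-transformer.yml')

# override dataset and trainer settings from code (placeholder values)
config['dataset']['data_root'] = './img/'
config['dataset']['train_annotation'] = 'annotation_train.txt'
config['dataset']['valid_annotation'] = 'annotation_val_small.txt'
config['trainer']['batch_size'] = 32

trainer = Trainer(config, pretrained=True)
trainer.train()
```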
/config/resnet-transformer.yml:
--------------------------------------------------------------------------------
1 | pretrain:
2 | id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
3 | md5: 7068030afe2e8fc639d0e1e2c25612b3
4 | cached: /tmp/transformerocr.pth
5 |
6 | weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY
7 |
8 | backbone: resnet50
9 | cnn:
10 | ss:
11 | - [2, 2]
12 | - [2, 1]
13 | - [2, 1]
14 | - [2, 1]
15 | - [1, 1]
16 | hidden: 256
17 |
--------------------------------------------------------------------------------
/config/resnet_fpn_transformer.yml:
--------------------------------------------------------------------------------
1 | pretrain:
2 | id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
3 | md5: 7068030afe2e8fc639d0e1e2c25612b3
4 | cached: /tmp/transformerocr.pth
5 |
6 | weights: https://drive.google.com/uc?id=12dTOZ9VP7ZVzwQgVvqBWz5JO5RXXW5NY
7 |
8 | backbone: resnet50_fpn
9 | cnn: {}
10 |
--------------------------------------------------------------------------------
/config/vgg-convseq2seq.yml:
--------------------------------------------------------------------------------
1 | pretrain:
2 | id_or_url: 13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
3 | md5: fbefa85079ad9001a71eb1bf47a93785
4 | cached: /tmp/transformerocr.pth
5 |
6 | # url or local path
7 | weights: https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
8 |
9 | backbone: vgg19_bn
10 | cnn:
11 | # pooling stride size
12 | ss:
13 | - [2, 2]
14 | - [2, 2]
15 | - [2, 1]
16 | - [2, 1]
17 | - [1, 1]
18 | # pooling kernel size
19 | ks:
20 | - [2, 2]
21 | - [2, 2]
22 | - [2, 1]
23 | - [2, 1]
24 | - [1, 1]
25 | # dim of output feature map
26 | hidden: 256
27 |
28 | seq_modeling: convseq2seq
29 | transformer:
30 | emb_dim: 256
31 | hid_dim: 512
32 | enc_layers: 10
33 | dec_layers: 10
34 | enc_kernel_size: 3
35 | dec_kernel_size: 3
36 | dropout: 0.1
37 | pad_idx: 0
38 | device: cuda:1
39 | enc_max_length: 512
40 | dec_max_length: 512
41 |
--------------------------------------------------------------------------------
/config/vgg-seq2seq.yml:
--------------------------------------------------------------------------------
1 | # for train
2 | pretrain: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
3 |
4 | # url or local path (for predict)
5 | weights: https://vocr.vn/data/vietocr/vgg_seq2seq.pth
6 |
7 | backbone: vgg19_bn
8 | cnn:
9 | # pooling stride size
10 | ss:
11 | - [2, 2]
12 | - [2, 2]
13 | - [2, 1]
14 | - [2, 1]
15 | - [1, 1]
16 | # pooling kernel size
17 | ks:
18 | - [2, 2]
19 | - [2, 2]
20 | - [2, 1]
21 | - [2, 1]
22 | - [1, 1]
23 | # dim of output feature map
24 | hidden: 256
25 |
26 | seq_modeling: seq2seq
27 | transformer:
28 | encoder_hidden: 256
29 | decoder_hidden: 256
30 | img_channel: 256
31 | decoder_embedded: 256
32 | dropout: 0.1
33 |
34 | optimizer:
35 | max_lr: 0.001
36 | pct_start: 0.1
37 |
--------------------------------------------------------------------------------
/config/vgg-transformer.yml:
--------------------------------------------------------------------------------
1 | # for training
2 | pretrain: https://vocr.vn/data/vietocr/vgg_transformer.pth
3 |
4 | # url or local path (predict)
5 | weights: https://vocr.vn/data/vietocr/vgg_transformer.pth
6 |
7 | backbone: vgg19_bn
8 | cnn:
9 | pretrained: True
10 | # pooling stride size
11 | ss:
12 | - [2, 2]
13 | - [2, 2]
14 | - [2, 1]
15 | - [2, 1]
16 | - [1, 1]
17 | # pooling kernel size
18 | ks:
19 | - [2, 2]
20 | - [2, 2]
21 | - [2, 1]
22 | - [2, 1]
23 | - [1, 1]
24 | # dim of output feature map
25 | hidden: 256
26 |
27 |
--------------------------------------------------------------------------------
/image/.keep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/image/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/image/sample.png
--------------------------------------------------------------------------------
/image/vietocr.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/image/vietocr.jpg
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="vietocr",
8 | version="0.3.13",
9 | author="pbcquoc",
10 | author_email="pbcquoc@gmail.com",
11 | description="Transformer-based text recognition (OCR)",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/pbcquoc/vietocr",
15 | packages=setuptools.find_packages(),
16 | install_requires=[
17 | "einops>=0.2.0",
18 | "gdown>=4.4.0",
19 | "albumentations>=1.4.2",
20 | "lmdb>=1.0.0",
21 | "scikit-image>=0.21.0",
22 | "pillow>=10.2.0",
23 | ],
24 | classifiers=[
25 | "Programming Language :: Python :: 3",
26 | "License :: OSI Approved :: Apache Software License",
27 | "Operating System :: OS Independent",
28 | ],
29 | python_requires=">=3.6",
30 | )
31 |
--------------------------------------------------------------------------------
/vietocr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/__init__.py
--------------------------------------------------------------------------------
/vietocr/loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/loader/__init__.py
--------------------------------------------------------------------------------
/vietocr/loader/aug.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | import albumentations as A
5 | from albumentations.core.transforms_interface import ImageOnlyTransform
6 | import cv2
7 | import random
8 |
9 |
10 | class RandomDottedLine(ImageOnlyTransform):
11 | def __init__(self, num_lines=1, p=0.5):
12 | super(RandomDottedLine, self).__init__(p=p)
13 | self.num_lines = num_lines
14 |
15 | def apply(self, img, **params):
16 | h, w = img.shape[:2]
17 | for _ in range(self.num_lines):
18 | # Random start and end points
19 | x1, y1 = np.random.randint(0, w), np.random.randint(0, h)
20 | x2, y2 = np.random.randint(0, w), np.random.randint(0, h)
21 | # Random color
22 | color = tuple(np.random.randint(0, 256, size=3).tolist())
23 | # Random thickness
24 | thickness = np.random.randint(1, 5)
25 | # Draw dotted or dashed line
26 | line_type = random.choice(["dotted", "dashed", "solid"])
27 | if line_type != "solid":
28 | self._draw_dotted_line(
29 | img, (x1, y1), (x2, y2), color, thickness, line_type
30 | )
31 | else:
32 | cv2.line(img, (x1, y1), (x2, y2), color, thickness)
33 |
34 | return img
35 |
36 | def _draw_dotted_line(self, img, pt1, pt2, color, thickness, line_type):
37 | # Calculate the distance between the points
38 | dist = np.hypot(pt2[0] - pt1[0], pt2[1] - pt1[1])
39 | # Number of segments
40 | num_segments = max(int(dist // 5), 1)
41 | # Generate points along the line
42 | x_points = np.linspace(pt1[0], pt2[0], num_segments)
43 | y_points = np.linspace(pt1[1], pt2[1], num_segments)
44 | # Draw segments
45 | for i in range(num_segments - 1):
46 | if line_type == "dotted" and i % 2 == 0:
47 | pt_start = (int(x_points[i]), int(y_points[i]))
48 | # draw a filled dot at this point for the dotted style
49 | cv2.circle(img, pt_start, thickness, color, -1)
50 | elif line_type == "dashed" and i % 4 < 2:
51 | pt_start = (int(x_points[i]), int(y_points[i]))
52 | pt_end = (int(x_points[i + 1]), int(y_points[i + 1]))
53 | cv2.line(img, pt_start, pt_end, color, thickness)
54 | return img
55 |
56 | def get_transform_init_args_names(self):
57 | return ("num_lines",)
58 |
59 |
60 | class ImgAugTransformV2:
61 | def __init__(self):
62 | self.aug = A.Compose(
63 | [
64 | A.InvertImg(p=0.2),
65 | A.ColorJitter(p=0.2),
66 | A.MotionBlur(blur_limit=3, p=0.2),
67 | A.RandomBrightnessContrast(p=0.2),
68 | A.Perspective(scale=(0.01, 0.05)),
69 | RandomDottedLine(),
70 | ]
71 | )
72 |
73 | def __call__(self, img):
74 | img = np.array(img)
75 | transformed = self.aug(image=img)
76 | img = transformed["image"]
77 | img = Image.fromarray(img)
78 | return img
79 |
--------------------------------------------------------------------------------
/vietocr/loader/dataloader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import random
4 | from PIL import Image
5 | from PIL import ImageFile
6 |
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 |
9 | from collections import defaultdict
10 | import numpy as np
11 | import torch
12 | import lmdb
13 | import six
14 | import time
15 | from tqdm import tqdm
16 |
17 | from torch.utils.data import Dataset
18 | from torch.utils.data.sampler import Sampler
19 | from vietocr.tool.translate import process_image
20 | from vietocr.tool.create_dataset import createDataset
21 | from vietocr.tool.translate import resize
22 |
23 |
24 | class OCRDataset(Dataset):
25 | def __init__(
26 | self,
27 | lmdb_path,
28 | root_dir,
29 | annotation_path,
30 | vocab,
31 | image_height=32,
32 | image_min_width=32,
33 | image_max_width=512,
34 | transform=None,
35 | ):
36 | self.root_dir = root_dir
37 | self.annotation_path = os.path.join(root_dir, annotation_path)
38 | self.vocab = vocab
39 | self.transform = transform
40 |
41 | self.image_height = image_height
42 | self.image_min_width = image_min_width
43 | self.image_max_width = image_max_width
44 |
45 | self.lmdb_path = lmdb_path
46 |
47 | if os.path.isdir(self.lmdb_path):
48 | print(
49 | "{} exists. Remove folder if you want to create new dataset".format(
50 | self.lmdb_path
51 | )
52 | )
53 | sys.stdout.flush()
54 | else:
55 | createDataset(self.lmdb_path, root_dir, annotation_path)
56 |
57 | self.env = lmdb.open(
58 | self.lmdb_path,
59 | max_readers=8,
60 | readonly=True,
61 | lock=False,
62 | readahead=False,
63 | meminit=False,
64 | )
65 | self.txn = self.env.begin(write=False)
66 |
67 | nSamples = int(self.txn.get("num-samples".encode()))
68 | self.nSamples = nSamples
69 |
70 | self.build_cluster_indices()
71 |
72 | def build_cluster_indices(self):
73 | self.cluster_indices = defaultdict(list)
74 |
75 | pbar = tqdm(
76 | range(self.__len__()),
77 | desc="{} build cluster".format(self.lmdb_path),
78 | ncols=100,
79 | position=0,
80 | leave=True,
81 | )
82 |
83 | for i in pbar:
84 | bucket = self.get_bucket(i)
85 | self.cluster_indices[bucket].append(i)
86 |
87 | def get_bucket(self, idx):
88 | key = "dim-%09d" % idx
89 |
90 | dim_img = self.txn.get(key.encode())
91 | dim_img = np.frombuffer(dim_img, dtype=np.int32)  # np.fromstring is deprecated for binary data
92 | imgH, imgW = dim_img
93 |
94 | new_w, image_height = resize(
95 | imgW, imgH, self.image_height, self.image_min_width, self.image_max_width
96 | )
97 |
98 | return new_w
99 |
100 | def read_buffer(self, idx):
101 | img_file = "image-%09d" % idx
102 | label_file = "label-%09d" % idx
103 | path_file = "path-%09d" % idx
104 |
105 | imgbuf = self.txn.get(img_file.encode())
106 |
107 | label = self.txn.get(label_file.encode()).decode()
108 | img_path = self.txn.get(path_file.encode()).decode()
109 |
110 | buf = six.BytesIO()
111 | buf.write(imgbuf)
112 | buf.seek(0)
113 |
114 | return buf, label, img_path
115 |
116 | def read_data(self, idx):
117 | buf, label, img_path = self.read_buffer(idx)
118 |
119 | img = Image.open(buf).convert("RGB")
120 |
121 | if self.transform:
122 | img = self.transform(img)
123 |
124 | img_bw = process_image(
125 | img, self.image_height, self.image_min_width, self.image_max_width
126 | )
127 |
128 | word = self.vocab.encode(label)
129 |
130 | return img_bw, word, img_path
131 |
132 | def __getitem__(self, idx):
133 | img, word, img_path = self.read_data(idx)
134 |
135 | img_path = os.path.join(self.root_dir, img_path)
136 |
137 | sample = {"img": img, "word": word, "img_path": img_path}
138 |
139 | return sample
140 |
141 | def __len__(self):
142 | return self.nSamples
143 |
144 |
145 | class ClusterRandomSampler(Sampler):
146 |
147 | def __init__(self, data_source, batch_size, shuffle=True):
148 | self.data_source = data_source
149 | self.batch_size = batch_size
150 | self.shuffle = shuffle
151 |
152 | def flatten_list(self, lst):
153 | return [item for sublist in lst for item in sublist]
154 |
155 | def __iter__(self):
156 | batch_lists = []
157 | for cluster, cluster_indices in self.data_source.cluster_indices.items():
158 | if self.shuffle:
159 | random.shuffle(cluster_indices)
160 |
161 | batches = [
162 | cluster_indices[i : i + self.batch_size]
163 | for i in range(0, len(cluster_indices), self.batch_size)
164 | ]
165 | batches = [_ for _ in batches if len(_) == self.batch_size]
166 | if self.shuffle:
167 | random.shuffle(batches)
168 |
169 | batch_lists.append(batches)
170 |
171 | lst = self.flatten_list(batch_lists)
172 | if self.shuffle:
173 | random.shuffle(lst)
174 |
175 | lst = self.flatten_list(lst)
176 |
177 | return iter(lst)
178 |
179 | def __len__(self):
180 | return len(self.data_source)
181 |
182 |
183 | class Collator(object):
184 | def __init__(self, masked_language_model=True):
185 | self.masked_language_model = masked_language_model
186 |
187 | def __call__(self, batch):
188 | filenames = []
189 | img = []
190 | target_weights = []
191 | tgt_input = []
192 | max_label_len = max(len(sample["word"]) for sample in batch)
193 | for sample in batch:
194 | img.append(sample["img"])
195 | filenames.append(sample["img_path"])
196 | label = sample["word"]
197 | label_len = len(label)
198 |
199 | tgt = np.concatenate(
200 | (label, np.zeros(max_label_len - label_len, dtype=np.int32))
201 | )
202 | tgt_input.append(tgt)
203 |
204 | one_mask_len = label_len - 1
205 |
206 | target_weights.append(
207 | np.concatenate(
208 | (
209 | np.ones(one_mask_len, dtype=np.float32),
210 | np.zeros(max_label_len - one_mask_len, dtype=np.float32),
211 | )
212 | )
213 | )
214 |
215 | img = np.array(img, dtype=np.float32)
216 |
217 | tgt_input = np.array(tgt_input, dtype=np.int64).T
218 | tgt_output = np.roll(tgt_input, -1, 0).T
219 | tgt_output[:, -1] = 0
220 |
221 | # random mask token
222 | if self.masked_language_model:
223 | mask = np.random.random(size=tgt_input.shape) < 0.05
224 | mask = mask & (tgt_input != 0) & (tgt_input != 1) & (tgt_input != 2)
225 | tgt_input[mask] = 3
226 |
227 | tgt_padding_mask = np.array(target_weights) == 0
228 |
229 | rs = {
230 | "img": torch.FloatTensor(img),
231 | "tgt_input": torch.LongTensor(tgt_input),
232 | "tgt_output": torch.LongTensor(tgt_output),
233 | "tgt_padding_mask": torch.BoolTensor(tgt_padding_mask),
234 | "filenames": filenames,
235 | }
236 |
237 | return rs
238 |
--------------------------------------------------------------------------------
/vietocr/loader/dataloader_v1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | import random
5 | from vietocr.model.vocab import Vocab
6 | from vietocr.tool.translate import process_image
7 | import os
8 | from collections import defaultdict
9 | import math
10 | from prefetch_generator import background
11 |
12 |
13 | class BucketData(object):
14 | def __init__(self, device):
15 | self.max_label_len = 0
16 | self.data_list = []
17 | self.label_list = []
18 | self.file_list = []
19 | self.device = device
20 |
21 | def append(self, datum, label, filename):
22 | self.data_list.append(datum)
23 | self.label_list.append(label)
24 | self.file_list.append(filename)
25 |
26 | self.max_label_len = max(len(label), self.max_label_len)
27 |
28 | return len(self.data_list)
29 |
30 | def flush_out(self):
31 | """
32 | Shape:
33 | - img: (N, C, H, W)
34 | - tgt_input: (T, N)
35 | - tgt_output: (N, T)
36 | - tgt_padding_mask: (N, T)
37 | """
38 | # encoder part
39 | img = np.array(self.data_list, dtype=np.float32)
40 |
41 | # decoder part
42 | target_weights = []
43 | tgt_input = []
44 | for label in self.label_list:
45 | label_len = len(label)
46 |
47 | tgt = np.concatenate(
48 | (label, np.zeros(self.max_label_len - label_len, dtype=np.int32))
49 | )
50 | tgt_input.append(tgt)
51 |
52 | one_mask_len = label_len - 1
53 |
54 | target_weights.append(
55 | np.concatenate(
56 | (
57 | np.ones(one_mask_len, dtype=np.float32),
58 | np.zeros(self.max_label_len - one_mask_len, dtype=np.float32),
59 | )
60 | )
61 | )
62 |
63 | # reshape to fit input shape
64 | tgt_input = np.array(tgt_input, dtype=np.int64).T
65 | tgt_output = np.roll(tgt_input, -1, 0).T
66 | tgt_output[:, -1] = 0
67 |
68 | tgt_padding_mask = np.array(target_weights) == 0
69 |
70 | filenames = self.file_list
71 |
72 | self.data_list, self.label_list, self.file_list = [], [], []
73 | self.max_label_len = 0
74 |
75 | rs = {
76 | "img": torch.FloatTensor(img).to(self.device),
77 | "tgt_input": torch.LongTensor(tgt_input).to(self.device),
78 | "tgt_output": torch.LongTensor(tgt_output).to(self.device),
79 | "tgt_padding_mask": torch.BoolTensor(tgt_padding_mask).to(self.device),
80 | "filenames": filenames,
81 | }
82 |
83 | return rs
84 |
85 | def __len__(self):
86 | return len(self.data_list)
87 |
88 | def __iadd__(self, other):
89 | self.data_list += other.data_list
90 | self.label_list += other.label_list
91 | self.file_list += other.file_list
92 | self.max_label_len = max(self.max_label_len, other.max_label_len)
93 | return self
94 |
95 | def __add__(self, other):
96 | res = BucketData(self.device)
97 | res.data_list = self.data_list + other.data_list
98 | res.label_list = self.label_list + other.label_list
99 | res.max_label_len = max(self.max_label_len, other.max_label_len)
100 | return res
101 |
102 |
103 | class DataGen(object):
104 |
105 | def __init__(
106 | self,
107 | data_root,
108 | annotation_fn,
109 | vocab,
110 | device,
111 | image_height=32,
112 | image_min_width=32,
113 | image_max_width=512,
114 | ):
115 |
116 | self.image_height = image_height
117 | self.image_min_width = image_min_width
118 | self.image_max_width = image_max_width
119 |
120 | self.data_root = data_root
121 | self.annotation_path = os.path.join(data_root, annotation_fn)
122 |
123 | self.vocab = vocab
124 | self.device = device
125 |
126 | self.clear()
127 |
128 | def clear(self):
129 | self.bucket_data = defaultdict(lambda: BucketData(self.device))
130 |
131 | @background(max_prefetch=1)
132 | def gen(self, batch_size, last_batch=True):
133 | with open(self.annotation_path, "r") as ann_file:
134 | lines = ann_file.readlines()
135 | np.random.shuffle(lines)
136 | for l in lines:
137 |
138 | img_path, lex = l.strip().split("\t")
139 |
140 | img_path = os.path.join(self.data_root, img_path)
141 |
142 | try:
143 | img_bw, word = self.read_data(img_path, lex)
144 | except IOError:
145 | print("cannot read image: {}".format(img_path))
146 | continue  # skip samples whose image cannot be read
147 | width = img_bw.shape[-1]
148 |
149 | bs = self.bucket_data[width].append(img_bw, word, img_path)
150 | if bs >= batch_size:
151 | b = self.bucket_data[width].flush_out()
152 | yield b
153 |
154 | if last_batch:
155 | for bucket in self.bucket_data.values():
156 | if len(bucket) > 0:
157 | b = bucket.flush_out()
158 | yield b
159 |
160 | self.clear()
161 |
162 | def read_data(self, img_path, lex):
163 |
164 | with open(img_path, "rb") as img_file:
165 | img = Image.open(img_file).convert("RGB")
166 | img_bw = process_image(
167 | img, self.image_height, self.image_min_width, self.image_max_width
168 | )
169 |
170 | word = self.vocab.encode(lex)
171 |
172 | return img_bw, word
173 |
--------------------------------------------------------------------------------
/vietocr/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/model/__init__.py
--------------------------------------------------------------------------------
/vietocr/model/backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/model/backbone/__init__.py
--------------------------------------------------------------------------------
/vietocr/model/backbone/cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | import vietocr.model.backbone.vgg as vgg
5 | from vietocr.model.backbone.resnet import Resnet50
6 |
7 |
8 | class CNN(nn.Module):
9 | def __init__(self, backbone, **kwargs):
10 | super(CNN, self).__init__()
11 |
12 | if backbone == "vgg11_bn":
13 | self.model = vgg.vgg11_bn(**kwargs)
14 | elif backbone == "vgg19_bn":
15 | self.model = vgg.vgg19_bn(**kwargs)
16 | elif backbone == "resnet50":
17 | self.model = Resnet50(**kwargs)
18 |
19 | def forward(self, x):
20 | return self.model(x)
21 |
22 | def freeze(self):
23 | for name, param in self.model.features.named_parameters():
24 | if name != "last_conv_1x1":
25 | param.requires_grad = False
26 |
27 | def unfreeze(self):
28 | for param in self.model.features.parameters():
29 | param.requires_grad = True
30 |
--------------------------------------------------------------------------------
/vietocr/model/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class BasicBlock(nn.Module):
6 | expansion = 1
7 |
8 | def __init__(self, inplanes, planes, stride=1, downsample=None):
9 | super(BasicBlock, self).__init__()
10 | self.conv1 = self._conv3x3(inplanes, planes)
11 | self.bn1 = nn.BatchNorm2d(planes)
12 | self.conv2 = self._conv3x3(planes, planes)
13 | self.bn2 = nn.BatchNorm2d(planes)
14 | self.relu = nn.ReLU(inplace=True)
15 | self.downsample = downsample
16 | self.stride = stride
17 |
18 | def _conv3x3(self, in_planes, out_planes, stride=1):
19 | "3x3 convolution with padding"
20 | return nn.Conv2d(
21 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
22 | )
23 |
24 | def forward(self, x):
25 | residual = x
26 |
27 | out = self.conv1(x)
28 | out = self.bn1(out)
29 | out = self.relu(out)
30 |
31 | out = self.conv2(out)
32 | out = self.bn2(out)
33 |
34 | if self.downsample is not None:
35 | residual = self.downsample(x)
36 | out += residual
37 | out = self.relu(out)
38 |
39 | return out
40 |
41 |
42 | class ResNet(nn.Module):
43 |
44 | def __init__(self, input_channel, output_channel, block, layers):
45 | super(ResNet, self).__init__()
46 |
47 | self.output_channel_block = [
48 | int(output_channel / 4),
49 | int(output_channel / 2),
50 | output_channel,
51 | output_channel,
52 | ]
53 |
54 | self.inplanes = int(output_channel / 8)
55 | self.conv0_1 = nn.Conv2d(
56 | input_channel,
57 | int(output_channel / 16),
58 | kernel_size=3,
59 | stride=1,
60 | padding=1,
61 | bias=False,
62 | )
63 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
64 | self.conv0_2 = nn.Conv2d(
65 | int(output_channel / 16),
66 | self.inplanes,
67 | kernel_size=3,
68 | stride=1,
69 | padding=1,
70 | bias=False,
71 | )
72 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
73 | self.relu = nn.ReLU(inplace=True)
74 |
75 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
76 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
77 | self.conv1 = nn.Conv2d(
78 | self.output_channel_block[0],
79 | self.output_channel_block[0],
80 | kernel_size=3,
81 | stride=1,
82 | padding=1,
83 | bias=False,
84 | )
85 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
86 |
87 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
88 | self.layer2 = self._make_layer(
89 | block, self.output_channel_block[1], layers[1], stride=1
90 | )
91 | self.conv2 = nn.Conv2d(
92 | self.output_channel_block[1],
93 | self.output_channel_block[1],
94 | kernel_size=3,
95 | stride=1,
96 | padding=1,
97 | bias=False,
98 | )
99 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
100 |
101 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
102 | self.layer3 = self._make_layer(
103 | block, self.output_channel_block[2], layers[2], stride=1
104 | )
105 | self.conv3 = nn.Conv2d(
106 | self.output_channel_block[2],
107 | self.output_channel_block[2],
108 | kernel_size=3,
109 | stride=1,
110 | padding=1,
111 | bias=False,
112 | )
113 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
114 |
115 | self.layer4 = self._make_layer(
116 | block, self.output_channel_block[3], layers[3], stride=1
117 | )
118 | self.conv4_1 = nn.Conv2d(
119 | self.output_channel_block[3],
120 | self.output_channel_block[3],
121 | kernel_size=2,
122 | stride=(2, 1),
123 | padding=(0, 1),
124 | bias=False,
125 | )
126 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
127 | self.conv4_2 = nn.Conv2d(
128 | self.output_channel_block[3],
129 | self.output_channel_block[3],
130 | kernel_size=2,
131 | stride=1,
132 | padding=0,
133 | bias=False,
134 | )
135 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
136 |
137 | def _make_layer(self, block, planes, blocks, stride=1):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | nn.Conv2d(
142 | self.inplanes,
143 | planes * block.expansion,
144 | kernel_size=1,
145 | stride=stride,
146 | bias=False,
147 | ),
148 | nn.BatchNorm2d(planes * block.expansion),
149 | )
150 |
151 | layers = []
152 | layers.append(block(self.inplanes, planes, stride, downsample))
153 | self.inplanes = planes * block.expansion
154 | for i in range(1, blocks):
155 | layers.append(block(self.inplanes, planes))
156 |
157 | return nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | x = self.conv0_1(x)
161 | x = self.bn0_1(x)
162 | x = self.relu(x)
163 | x = self.conv0_2(x)
164 | x = self.bn0_2(x)
165 | x = self.relu(x)
166 |
167 | x = self.maxpool1(x)
168 | x = self.layer1(x)
169 | x = self.conv1(x)
170 | x = self.bn1(x)
171 | x = self.relu(x)
172 |
173 | x = self.maxpool2(x)
174 | x = self.layer2(x)
175 | x = self.conv2(x)
176 | x = self.bn2(x)
177 | x = self.relu(x)
178 |
179 | x = self.maxpool3(x)
180 | x = self.layer3(x)
181 | x = self.conv3(x)
182 | x = self.bn3(x)
183 | x = self.relu(x)
184 |
185 | x = self.layer4(x)
186 | x = self.conv4_1(x)
187 | x = self.bn4_1(x)
188 | x = self.relu(x)
189 | x = self.conv4_2(x)
190 | x = self.bn4_2(x)
191 | conv = self.relu(x)
192 |
193 | conv = conv.transpose(-1, -2)
194 | conv = conv.flatten(2)
195 | conv = conv.permute(-1, 0, 1)
196 |
197 | return conv
198 |
199 |
200 | def Resnet50(ss, hidden):
201 | return ResNet(3, hidden, BasicBlock, [1, 2, 5, 3])
202 |
--------------------------------------------------------------------------------
/vietocr/model/backbone/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision import models
4 | from einops import rearrange
5 | from torchvision.models._utils import IntermediateLayerGetter
6 |
7 |
8 | class Vgg(nn.Module):
9 | def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
10 | super(Vgg, self).__init__()
11 |
12 | if pretrained:
13 | weights = "DEFAULT"
14 | else:
15 | weights = None
16 |
17 | if name == "vgg11_bn":
18 | cnn = models.vgg11_bn(weights=weights)
19 | elif name == "vgg19_bn":
20 | cnn = models.vgg19_bn(weights=weights)
21 |
22 | pool_idx = 0
23 |
24 | for i, layer in enumerate(cnn.features):
25 | if isinstance(layer, torch.nn.MaxPool2d):
26 | cnn.features[i] = torch.nn.AvgPool2d(
27 | kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0
28 | )
29 | pool_idx += 1
30 |
31 | self.features = cnn.features
32 | self.dropout = nn.Dropout(dropout)
33 | self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)
34 |
35 | def forward(self, x):
36 | """
37 | Shape:
38 | - x: (N, C, H, W)
39 | - output: (W, N, C)
40 | """
41 |
42 | conv = self.features(x)
43 | conv = self.dropout(conv)
44 | conv = self.last_conv_1x1(conv)
45 |
46 | # conv = rearrange(conv, 'b d h w -> b d (w h)')
47 | conv = conv.transpose(-1, -2)
48 | conv = conv.flatten(2)
49 | conv = conv.permute(-1, 0, 1)
50 | return conv
51 |
52 |
53 | def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
54 | return Vgg("vgg11_bn", ss, ks, hidden, pretrained, dropout)
55 |
56 |
57 | def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
58 | return Vgg("vgg19_bn", ss, ks, hidden, pretrained, dropout)
59 |
--------------------------------------------------------------------------------
/vietocr/model/beam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Beam:
5 |
6 | def __init__(
7 | self,
8 | beam_size=8,
9 | min_length=0,
10 | n_top=1,
11 | ranker=None,
12 | start_token_id=1,
13 | end_token_id=2,
14 | ):
15 | self.beam_size = beam_size
16 | self.min_length = min_length
17 | self.ranker = ranker
18 |
19 | self.end_token_id = end_token_id
20 | self.top_sentence_ended = False
21 |
22 | self.prev_ks = []
23 | self.next_ys = [
24 | torch.LongTensor(beam_size).fill_(start_token_id)
25 | ] # remove padding
26 |
27 | self.current_scores = torch.FloatTensor(beam_size).zero_()
28 | self.all_scores = []
29 |
30 | # Time and k pair for finished.
31 | self.finished = []
32 | self.n_top = n_top
33 |
34 | self.ranker = ranker
35 |
36 | def advance(self, next_log_probs):
37 | # next_probs : beam_size X vocab_size
38 |
39 | vocabulary_size = next_log_probs.size(1)
40 | # current_beam_size = next_log_probs.size(0)
41 |
42 | current_length = len(self.next_ys)
43 | if current_length < self.min_length:
44 | for beam_index in range(len(next_log_probs)):
45 | next_log_probs[beam_index][self.end_token_id] = -1e10
46 |
47 | if len(self.prev_ks) > 0:
48 | beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(
49 | next_log_probs
50 | )
51 | # Don't let EOS have children.
52 | last_y = self.next_ys[-1]
53 | for beam_index in range(last_y.size(0)):
54 | if last_y[beam_index] == self.end_token_id:
55 | beam_scores[beam_index] = -1e10 # -1e20 raises error when executing
56 | else:
57 | beam_scores = next_log_probs[0]
58 |
59 | flat_beam_scores = beam_scores.view(-1)
60 | top_scores, top_score_ids = flat_beam_scores.topk(
61 | k=self.beam_size, dim=0, largest=True, sorted=True
62 | )
63 |
64 | self.current_scores = top_scores
65 | self.all_scores.append(self.current_scores)
66 |
67 | prev_k = top_score_ids // vocabulary_size # (beam_size, )
68 | next_y = top_score_ids - prev_k * vocabulary_size # (beam_size, )
69 |
70 | self.prev_ks.append(prev_k)
71 | self.next_ys.append(next_y)
72 |
73 | for beam_index, last_token_id in enumerate(next_y):
74 |
75 | if last_token_id == self.end_token_id:
76 |
77 | # skip scoring
78 | self.finished.append(
79 | (self.current_scores[beam_index], len(self.next_ys) - 1, beam_index)
80 | )
81 |
82 | if next_y[0] == self.end_token_id:
83 | self.top_sentence_ended = True
84 |
85 | def get_current_state(self):
86 | "Get the outputs for the current timestep."
87 | return torch.stack(self.next_ys, dim=1)
88 |
89 | def get_current_origin(self):
90 | "Get the backpointers for the current timestep."
91 | return self.prev_ks[-1]
92 |
93 | def done(self):
94 | return self.top_sentence_ended and len(self.finished) >= self.n_top
95 |
96 | def get_hypothesis(self, timestep, k):
97 | hypothesis = []
98 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
99 | hypothesis.append(self.next_ys[j + 1][k])
100 | # for RNN, [:, k, :], and for transformer, [k, :, :]
101 | k = self.prev_ks[j][k]
102 |
103 | return hypothesis[::-1]
104 |
105 | def sort_finished(self, minimum=None):
106 | if minimum is not None:
107 | i = 0
108 | # Add from beam until we have minimum outputs.
109 | while len(self.finished) < minimum:
110 | # global_scores = self.global_scorer.score(self, self.scores)
111 | # s = global_scores[i]
112 | s = self.current_scores[i]
113 | self.finished.append((s, len(self.next_ys) - 1, i))
114 | i += 1
115 |
116 | self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True)
117 | scores = [sc for sc, _, _ in self.finished]
118 | ks = [(t, k) for _, t, k in self.finished]
119 | return scores, ks
120 |
--------------------------------------------------------------------------------
/vietocr/model/seqmodel/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/model/seqmodel/__init__.py
--------------------------------------------------------------------------------
/vietocr/model/seqmodel/convseq2seq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 |
6 |
7 | class Encoder(nn.Module):
8 | def __init__(
9 | self, emb_dim, hid_dim, n_layers, kernel_size, dropout, device, max_length=512
10 | ):
11 | super().__init__()
12 |
13 | assert kernel_size % 2 == 1, "Kernel size must be odd!"
14 |
15 | self.device = device
16 |
17 | self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
18 |
19 | # self.tok_embedding = nn.Embedding(input_dim, emb_dim)
20 | self.pos_embedding = nn.Embedding(max_length, emb_dim)
21 |
22 | self.emb2hid = nn.Linear(emb_dim, hid_dim)
23 | self.hid2emb = nn.Linear(hid_dim, emb_dim)
24 |
25 | self.convs = nn.ModuleList(
26 | [
27 | nn.Conv1d(
28 | in_channels=hid_dim,
29 | out_channels=2 * hid_dim,
30 | kernel_size=kernel_size,
31 | padding=(kernel_size - 1) // 2,
32 | )
33 | for _ in range(n_layers)
34 | ]
35 | )
36 |
37 | self.dropout = nn.Dropout(dropout)
38 |
39 | def forward(self, src):
40 |
41 | # src = [batch size, src len]
42 |
43 | src = src.transpose(0, 1)
44 |
45 | batch_size = src.shape[0]
46 | src_len = src.shape[1]
47 | device = src.device
48 |
49 | # create position tensor
50 | pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(device)
51 |
52 | # pos = [0, 1, 2, 3, ..., src len - 1]
53 |
54 | # pos = [batch size, src len]
55 |
56 | # embed tokens and positions
57 |
58 | # tok_embedded = self.tok_embedding(src)
59 | tok_embedded = src
60 |
61 | pos_embedded = self.pos_embedding(pos)
62 |
63 | # tok_embedded = pos_embedded = [batch size, src len, emb dim]
64 |
65 | # combine embeddings by elementwise summing
66 | embedded = self.dropout(tok_embedded + pos_embedded)
67 |
68 | # embedded = [batch size, src len, emb dim]
69 |
70 | # pass embedded through linear layer to convert from emb dim to hid dim
71 | conv_input = self.emb2hid(embedded)
72 |
73 | # conv_input = [batch size, src len, hid dim]
74 |
75 | # permute for convolutional layer
76 | conv_input = conv_input.permute(0, 2, 1)
77 |
78 | # conv_input = [batch size, hid dim, src len]
79 |
80 | # begin convolutional blocks...
81 |
82 | for i, conv in enumerate(self.convs):
83 |
84 | # pass through convolutional layer
85 | conved = conv(self.dropout(conv_input))
86 |
87 | # conved = [batch size, 2 * hid dim, src len]
88 |
89 | # pass through GLU activation function
90 | conved = F.glu(conved, dim=1)
91 |
92 | # conved = [batch size, hid dim, src len]
93 |
94 | # apply residual connection
95 | conved = (conved + conv_input) * self.scale
96 |
97 | # conved = [batch size, hid dim, src len]
98 |
99 | # set conv_input to conved for next loop iteration
100 | conv_input = conved
101 |
102 | # ...end convolutional blocks
103 |
104 | # permute and convert back to emb dim
105 | conved = self.hid2emb(conved.permute(0, 2, 1))
106 |
107 | # conved = [batch size, src len, emb dim]
108 |
109 | # elementwise sum output (conved) and input (embedded) to be used for attention
110 | combined = (conved + embedded) * self.scale
111 |
112 | # combined = [batch size, src len, emb dim]
113 |
114 | return conved, combined
115 |
116 |
117 | class Decoder(nn.Module):
118 | def __init__(
119 | self,
120 | output_dim,
121 | emb_dim,
122 | hid_dim,
123 | n_layers,
124 | kernel_size,
125 | dropout,
126 | trg_pad_idx,
127 | device,
128 | max_length=512,
129 | ):
130 | super().__init__()
131 |
132 | self.kernel_size = kernel_size
133 | self.trg_pad_idx = trg_pad_idx
134 | self.device = device
135 |
136 | self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
137 |
138 | self.tok_embedding = nn.Embedding(output_dim, emb_dim)
139 | self.pos_embedding = nn.Embedding(max_length, emb_dim)
140 |
141 | self.emb2hid = nn.Linear(emb_dim, hid_dim)
142 | self.hid2emb = nn.Linear(hid_dim, emb_dim)
143 |
144 | self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
145 | self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)
146 |
147 | self.fc_out = nn.Linear(emb_dim, output_dim)
148 |
149 | self.convs = nn.ModuleList(
150 | [
151 | nn.Conv1d(
152 | in_channels=hid_dim,
153 | out_channels=2 * hid_dim,
154 | kernel_size=kernel_size,
155 | )
156 | for _ in range(n_layers)
157 | ]
158 | )
159 |
160 | self.dropout = nn.Dropout(dropout)
161 |
162 | def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
163 |
164 | # embedded = [batch size, trg len, emb dim]
165 | # conved = [batch size, hid dim, trg len]
166 | # encoder_conved = encoder_combined = [batch size, src len, emb dim]
167 |
168 | # permute and convert back to emb dim
169 | conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))
170 |
171 | # conved_emb = [batch size, trg len, emb dim]
172 |
173 | combined = (conved_emb + embedded) * self.scale
174 |
175 | # combined = [batch size, trg len, emb dim]
176 |
177 | energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))
178 |
179 | # energy = [batch size, trg len, src len]
180 |
181 | attention = F.softmax(energy, dim=2)
182 |
183 | # attention = [batch size, trg len, src len]
184 |
185 | attended_encoding = torch.matmul(attention, encoder_combined)
186 |
187 |         # attended_encoding = [batch size, trg len, emb dim]
188 |
189 | # convert from emb dim -> hid dim
190 | attended_encoding = self.attn_emb2hid(attended_encoding)
191 |
192 | # attended_encoding = [batch size, trg len, hid dim]
193 |
194 | # apply residual connection
195 | attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale
196 |
197 | # attended_combined = [batch size, hid dim, trg len]
198 |
199 | return attention, attended_combined
200 |
201 | def forward(self, trg, encoder_conved, encoder_combined):
202 |
203 | # trg = [batch size, trg len]
204 | # encoder_conved = encoder_combined = [batch size, src len, emb dim]
205 | trg = trg.transpose(0, 1)
206 |
207 | batch_size = trg.shape[0]
208 | trg_len = trg.shape[1]
209 | device = trg.device
210 |
211 | # create position tensor
212 | pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(device)
213 |
214 | # pos = [batch size, trg len]
215 |
216 | # embed tokens and positions
217 | tok_embedded = self.tok_embedding(trg)
218 | pos_embedded = self.pos_embedding(pos)
219 |
220 | # tok_embedded = [batch size, trg len, emb dim]
221 | # pos_embedded = [batch size, trg len, emb dim]
222 |
223 | # combine embeddings by elementwise summing
224 | embedded = self.dropout(tok_embedded + pos_embedded)
225 |
226 | # embedded = [batch size, trg len, emb dim]
227 |
228 | # pass embedded through linear layer to go through emb dim -> hid dim
229 | conv_input = self.emb2hid(embedded)
230 |
231 | # conv_input = [batch size, trg len, hid dim]
232 |
233 | # permute for convolutional layer
234 | conv_input = conv_input.permute(0, 2, 1)
235 |
236 | # conv_input = [batch size, hid dim, trg len]
237 |
238 | batch_size = conv_input.shape[0]
239 | hid_dim = conv_input.shape[1]
240 |
241 | for i, conv in enumerate(self.convs):
242 |
243 | # apply dropout
244 | conv_input = self.dropout(conv_input)
245 |
246 | # need to pad so decoder can't "cheat"
247 | padding = (
248 | torch.zeros(batch_size, hid_dim, self.kernel_size - 1)
249 | .fill_(self.trg_pad_idx)
250 | .to(device)
251 | )
252 |
253 | padded_conv_input = torch.cat((padding, conv_input), dim=2)
254 |
255 | # padded_conv_input = [batch size, hid dim, trg len + kernel size - 1]
256 |
257 | # pass through convolutional layer
258 | conved = conv(padded_conv_input)
259 |
260 | # conved = [batch size, 2 * hid dim, trg len]
261 |
262 | # pass through GLU activation function
263 | conved = F.glu(conved, dim=1)
264 |
265 | # conved = [batch size, hid dim, trg len]
266 |
267 | # calculate attention
268 | attention, conved = self.calculate_attention(
269 | embedded, conved, encoder_conved, encoder_combined
270 | )
271 |
272 | # attention = [batch size, trg len, src len]
273 |
274 | # apply residual connection
275 | conved = (conved + conv_input) * self.scale
276 |
277 | # conved = [batch size, hid dim, trg len]
278 |
279 | # set conv_input to conved for next loop iteration
280 | conv_input = conved
281 |
282 | conved = self.hid2emb(conved.permute(0, 2, 1))
283 |
284 | # conved = [batch size, trg len, emb dim]
285 |
286 | output = self.fc_out(self.dropout(conved))
287 |
288 | # output = [batch size, trg len, output dim]
289 |
290 | return output, attention
291 |
292 |
293 | class ConvSeq2Seq(nn.Module):
294 | def __init__(
295 | self,
296 | vocab_size,
297 | emb_dim,
298 | hid_dim,
299 | enc_layers,
300 | dec_layers,
301 | enc_kernel_size,
302 | dec_kernel_size,
303 | enc_max_length,
304 | dec_max_length,
305 | dropout,
306 | pad_idx,
307 | device,
308 | ):
309 | super().__init__()
310 |
311 | enc = Encoder(
312 | emb_dim,
313 | hid_dim,
314 | enc_layers,
315 | enc_kernel_size,
316 | dropout,
317 | device,
318 | enc_max_length,
319 | )
320 | dec = Decoder(
321 | vocab_size,
322 | emb_dim,
323 | hid_dim,
324 | dec_layers,
325 | dec_kernel_size,
326 | dropout,
327 | pad_idx,
328 | device,
329 | dec_max_length,
330 | )
331 |
332 | self.encoder = enc
333 | self.decoder = dec
334 |
335 | def forward_encoder(self, src):
336 | encoder_conved, encoder_combined = self.encoder(src)
337 |
338 | return encoder_conved, encoder_combined
339 |
340 | def forward_decoder(self, trg, memory):
341 | encoder_conved, encoder_combined = memory
342 | output, attention = self.decoder(trg, encoder_conved, encoder_combined)
343 |
344 | return output, (encoder_conved, encoder_combined)
345 |
346 | def forward(self, src, trg):
347 |
348 | # src = [batch size, src len]
349 |         # trg = [batch size, trg len - 1] (<eos> token sliced off the end)
350 |
351 | # calculate z^u (encoder_conved) and (z^u + e) (encoder_combined)
352 | # encoder_conved is output from final encoder conv. block
353 | # encoder_combined is encoder_conved plus (elementwise) src embedding plus
354 | # positional embeddings
355 | encoder_conved, encoder_combined = self.encoder(src)
356 |
357 | # encoder_conved = [batch size, src len, emb dim]
358 | # encoder_combined = [batch size, src len, emb dim]
359 |
360 | # calculate predictions of next words
361 | # output is a batch of predictions for each word in the trg sentence
362 |         # attention is a batch of attention scores across the src sentence for
363 | # each word in the trg sentence
364 | output, attention = self.decoder(trg, encoder_conved, encoder_combined)
365 |
366 | # output = [batch size, trg len - 1, output dim]
367 | # attention = [batch size, trg len - 1, src len]
368 |
369 | return output # , attention
370 |
--------------------------------------------------------------------------------
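A minimal shape-check for the ConvSeq2Seq above — a sketch with made-up sizes, not code from the repository (in practice the values come from the transformer section of the config files). emb_dim has to match the channel width of the incoming CNN features, because the encoder uses those features directly in place of token embeddings:

import torch

model = ConvSeq2Seq(
    vocab_size=233, emb_dim=256, hid_dim=256,
    enc_layers=2, dec_layers=2,
    enc_kernel_size=3, dec_kernel_size=3,
    enc_max_length=512, dec_max_length=512,
    dropout=0.1, pad_idx=0, device="cpu",
)

src = torch.rand(64, 1, 256)           # src len x batch x emb dim (CNN feature columns)
trg = torch.randint(0, 233, (10, 1))   # trg len x batch (target token indices)
out = model(src, trg)                  # batch x trg len x vocab size
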
/vietocr/model/seqmodel/seq2seq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 |
6 |
7 | class Encoder(nn.Module):
8 | def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
9 | super().__init__()
10 |
11 | self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
12 | self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
13 | self.dropout = nn.Dropout(dropout)
14 |
15 | def forward(self, src):
16 | """
17 | src: src_len x batch_size x img_channel
18 | outputs: src_len x batch_size x hid_dim
19 | hidden: batch_size x hid_dim
20 | """
21 |
22 | embedded = self.dropout(src)
23 |
24 | outputs, hidden = self.rnn(embedded)
25 |
26 | hidden = torch.tanh(
27 | self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
28 | )
29 |
30 | return outputs, hidden
31 |
32 |
33 | class Attention(nn.Module):
34 | def __init__(self, enc_hid_dim, dec_hid_dim):
35 | super().__init__()
36 |
37 | self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
38 | self.v = nn.Linear(dec_hid_dim, 1, bias=False)
39 |
40 | def forward(self, hidden, encoder_outputs):
41 | """
42 | hidden: batch_size x hid_dim
43 | encoder_outputs: src_len x batch_size x hid_dim,
44 | outputs: batch_size x src_len
45 | """
46 |
47 | batch_size = encoder_outputs.shape[1]
48 | src_len = encoder_outputs.shape[0]
49 |
50 | hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
51 |
52 | encoder_outputs = encoder_outputs.permute(1, 0, 2)
53 |
54 | energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
55 |
56 | attention = self.v(energy).squeeze(2)
57 |
58 | return F.softmax(attention, dim=1)
59 |
60 |
61 | class Decoder(nn.Module):
62 | def __init__(
63 | self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention
64 | ):
65 | super().__init__()
66 |
67 | self.output_dim = output_dim
68 | self.attention = attention
69 |
70 | self.embedding = nn.Embedding(output_dim, emb_dim)
71 | self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
72 | self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
73 | self.dropout = nn.Dropout(dropout)
74 |
75 | def forward(self, input, hidden, encoder_outputs):
76 | """
77 | inputs: batch_size
78 | hidden: batch_size x hid_dim
79 | encoder_outputs: src_len x batch_size x hid_dim
80 | """
81 |
82 | input = input.unsqueeze(0)
83 |
84 | embedded = self.dropout(self.embedding(input))
85 |
86 | a = self.attention(hidden, encoder_outputs)
87 |
88 | a = a.unsqueeze(1)
89 |
90 | encoder_outputs = encoder_outputs.permute(1, 0, 2)
91 |
92 | weighted = torch.bmm(a, encoder_outputs)
93 |
94 | weighted = weighted.permute(1, 0, 2)
95 |
96 | rnn_input = torch.cat((embedded, weighted), dim=2)
97 |
98 | output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
99 |
100 | assert (output == hidden).all()
101 |
102 | embedded = embedded.squeeze(0)
103 | output = output.squeeze(0)
104 | weighted = weighted.squeeze(0)
105 |
106 | prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))
107 |
108 | return prediction, hidden.squeeze(0), a.squeeze(1)
109 |
110 |
111 | class Seq2Seq(nn.Module):
112 | def __init__(
113 | self,
114 | vocab_size,
115 | encoder_hidden,
116 | decoder_hidden,
117 | img_channel,
118 | decoder_embedded,
119 | dropout=0.1,
120 | ):
121 | super().__init__()
122 |
123 | attn = Attention(encoder_hidden, decoder_hidden)
124 |
125 | self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
126 | self.decoder = Decoder(
127 | vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn
128 | )
129 |
130 | def forward_encoder(self, src):
131 | """
132 | src: timestep x batch_size x channel
133 | hidden: batch_size x hid_dim
134 | encoder_outputs: src_len x batch_size x hid_dim
135 | """
136 |
137 | encoder_outputs, hidden = self.encoder(src)
138 |
139 | return (hidden, encoder_outputs)
140 |
141 | def forward_decoder(self, tgt, memory):
142 | """
143 | tgt: timestep x batch_size
144 | hidden: batch_size x hid_dim
145 |         encoder_outputs: src_len x batch_size x hid_dim
146 | output: batch_size x 1 x vocab_size
147 | """
148 |
149 | tgt = tgt[-1]
150 | hidden, encoder_outputs = memory
151 | output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
152 | output = output.unsqueeze(1)
153 |
154 | return output, (hidden, encoder_outputs)
155 |
156 | def forward(self, src, trg):
157 | """
158 | src: time_step x batch_size
159 | trg: time_step x batch_size
160 | outputs: batch_size x time_step x vocab_size
161 | """
162 |
163 | batch_size = src.shape[1]
164 | trg_len = trg.shape[0]
165 | trg_vocab_size = self.decoder.output_dim
166 | device = src.device
167 |
168 | outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
169 | encoder_outputs, hidden = self.encoder(src)
170 |
171 | for t in range(trg_len):
172 | input = trg[t]
173 | output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
174 |
175 | outputs[t] = output
176 |
177 | outputs = outputs.transpose(0, 1).contiguous()
178 |
179 | return outputs
180 |
181 | def expand_memory(self, memory, beam_size):
182 | hidden, encoder_outputs = memory
183 | hidden = hidden.repeat(beam_size, 1)
184 | encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)
185 |
186 | return (hidden, encoder_outputs)
187 |
188 | def get_memory(self, memory, i):
189 | hidden, encoder_outputs = memory
190 | hidden = hidden[[i]]
191 | encoder_outputs = encoder_outputs[:, [i], :]
192 |
193 | return (hidden, encoder_outputs)
194 |
--------------------------------------------------------------------------------
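The same kind of shape-check for the attention-based Seq2Seq above, again with made-up sizes; img_channel must equal the CNN feature dimension, since the encoder feeds the feature sequence straight into its bidirectional GRU:

import torch

model = Seq2Seq(
    vocab_size=233,
    encoder_hidden=256,
    decoder_hidden=256,
    img_channel=256,
    decoder_embedded=256,
)

src = torch.rand(64, 1, 256)           # time_step x batch_size x img_channel
trg = torch.randint(0, 233, (10, 1))   # time_step x batch_size (token indices)
out = model(src, trg)                  # batch_size x time_step x vocab_size
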
/vietocr/model/seqmodel/transformer.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 | from torchvision import models
3 | import math
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class LanguageTransformer(nn.Module):
9 | def __init__(
10 | self,
11 | vocab_size,
12 | d_model,
13 | nhead,
14 | num_encoder_layers,
15 | num_decoder_layers,
16 | dim_feedforward,
17 | max_seq_length,
18 | pos_dropout,
19 | trans_dropout,
20 | ):
21 | super().__init__()
22 |
23 | self.d_model = d_model
24 | self.embed_tgt = nn.Embedding(vocab_size, d_model)
25 | self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length)
26 | # self.learned_pos_enc = LearnedPositionalEncoding(d_model, pos_dropout, max_seq_length)
27 |
28 | self.transformer = nn.Transformer(
29 | d_model,
30 | nhead,
31 | num_encoder_layers,
32 | num_decoder_layers,
33 | dim_feedforward,
34 | trans_dropout,
35 | )
36 |
37 | self.fc = nn.Linear(d_model, vocab_size)
38 |
39 | def forward(
40 | self,
41 | src,
42 | tgt,
43 | src_key_padding_mask=None,
44 | tgt_key_padding_mask=None,
45 | memory_key_padding_mask=None,
46 | ):
47 | """
48 | Shape:
49 | - src: (W, N, C)
50 | - tgt: (T, N)
51 | - src_key_padding_mask: (N, S)
52 | - tgt_key_padding_mask: (N, T)
53 | - memory_key_padding_mask: (N, S)
54 | - output: (N, T, E)
55 |
56 | """
57 | tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(src.device)
58 |
59 | src = self.pos_enc(src * math.sqrt(self.d_model))
60 | # src = self.learned_pos_enc(src*math.sqrt(self.d_model))
61 |
62 | tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
63 |
64 | output = self.transformer(
65 | src,
66 | tgt,
67 | tgt_mask=tgt_mask,
68 | src_key_padding_mask=src_key_padding_mask,
69 | tgt_key_padding_mask=tgt_key_padding_mask.float(),
70 | memory_key_padding_mask=memory_key_padding_mask,
71 | )
72 | # output = rearrange(output, 't n e -> n t e')
73 | output = output.transpose(0, 1)
74 | return self.fc(output)
75 |
76 | def gen_nopeek_mask(self, length):
77 | mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
78 | mask = (
79 | mask.float()
80 | .masked_fill(mask == 0, float("-inf"))
81 | .masked_fill(mask == 1, float(0.0))
82 | )
83 |
84 | return mask
85 |
86 | def forward_encoder(self, src):
87 | src = self.pos_enc(src * math.sqrt(self.d_model))
88 | memory = self.transformer.encoder(src)
89 | return memory
90 |
91 | def forward_decoder(self, tgt, memory):
92 | tgt_mask = self.gen_nopeek_mask(tgt.shape[0]).to(tgt.device)
93 | tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
94 |
95 | output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask)
96 | # output = rearrange(output, 't n e -> n t e')
97 | output = output.transpose(0, 1)
98 |
99 | return self.fc(output), memory
100 |
101 | def expand_memory(self, memory, beam_size):
102 | memory = memory.repeat(1, beam_size, 1)
103 | return memory
104 |
105 | def get_memory(self, memory, i):
106 | memory = memory[:, [i], :]
107 | return memory
108 |
109 |
110 | class PositionalEncoding(nn.Module):
111 | def __init__(self, d_model, dropout=0.1, max_len=100):
112 | super(PositionalEncoding, self).__init__()
113 | self.dropout = nn.Dropout(p=dropout)
114 |
115 | pe = torch.zeros(max_len, d_model)
116 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
117 | div_term = torch.exp(
118 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
119 | )
120 | pe[:, 0::2] = torch.sin(position * div_term)
121 | pe[:, 1::2] = torch.cos(position * div_term)
122 | pe = pe.unsqueeze(0).transpose(0, 1)
123 | self.register_buffer("pe", pe)
124 |
125 | def forward(self, x):
126 | x = x + self.pe[: x.size(0), :]
127 |
128 | return self.dropout(x)
129 |
130 |
131 | class LearnedPositionalEncoding(nn.Module):
132 | def __init__(self, d_model, dropout=0.1, max_len=100):
133 | super(LearnedPositionalEncoding, self).__init__()
134 | self.dropout = nn.Dropout(p=dropout)
135 |
136 | self.pos_embed = nn.Embedding(max_len, d_model)
137 | self.layernorm = LayerNorm(d_model)
138 |
139 | def forward(self, x):
140 | seq_len = x.size(0)
141 | pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
142 | pos = pos.unsqueeze(-1).expand(x.size()[:2])
143 | x = x + self.pos_embed(pos)
144 | return self.dropout(self.layernorm(x))
145 |
146 |
147 | class LayerNorm(nn.Module):
148 | "A layernorm module in the TF style (epsilon inside the square root)."
149 |
150 | def __init__(self, d_model, variance_epsilon=1e-12):
151 | super().__init__()
152 | self.gamma = nn.Parameter(torch.ones(d_model))
153 | self.beta = nn.Parameter(torch.zeros(d_model))
154 | self.variance_epsilon = variance_epsilon
155 |
156 | def forward(self, x):
157 | u = x.mean(-1, keepdim=True)
158 | s = (x - u).pow(2).mean(-1, keepdim=True)
159 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
160 | return self.gamma * x + self.beta
161 |
--------------------------------------------------------------------------------
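A shape-check sketch for LanguageTransformer with made-up sizes (nhead must divide d_model, and max_seq_length bounds both the source width and the target length). The key padding mask passed here is all zeros, i.e. nothing is masked, since the module converts it to a float additive mask:

import torch

model = LanguageTransformer(
    vocab_size=233, d_model=256, nhead=8,
    num_encoder_layers=6, num_decoder_layers=6,
    dim_feedforward=2048, max_seq_length=1024,
    pos_dropout=0.1, trans_dropout=0.1,
)

src = torch.rand(170, 1, 256)           # (W, N, C): CNN feature columns
tgt = torch.randint(0, 233, (10, 1))    # (T, N): target token indices
tgt_padding_mask = torch.zeros(1, 10)   # (N, T): zero means "not masked"
out = model(src, tgt, tgt_key_padding_mask=tgt_padding_mask)   # (N, T, vocab_size)
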
/vietocr/model/trainer.py:
--------------------------------------------------------------------------------
1 | from vietocr.optim.optim import ScheduledOptim
2 | from vietocr.optim.labelsmoothingloss import LabelSmoothingLoss
3 | from torch.optim import Adam, SGD, AdamW
4 | from torch import nn
5 | from vietocr.tool.translate import build_model
6 | from vietocr.tool.translate import translate, batch_translate_beam_search
7 | from vietocr.tool.utils import download_weights
8 | from vietocr.tool.logger import Logger
9 | from vietocr.loader.aug import ImgAugTransformV2
10 |
11 | import yaml
12 | import torch
13 | from vietocr.loader.dataloader_v1 import DataGen
14 | from vietocr.loader.dataloader import OCRDataset, ClusterRandomSampler, Collator
15 | from torch.utils.data import DataLoader
16 | from einops import rearrange
17 | from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, OneCycleLR
18 |
19 | import torchvision
20 |
21 | from vietocr.tool.utils import compute_accuracy
22 | from PIL import Image
23 | import numpy as np
24 | import os
25 | import matplotlib.pyplot as plt
26 | import time
27 |
28 |
29 | class Trainer:
30 | def __init__(self, config, pretrained=True, augmentor=ImgAugTransformV2()):
31 |
32 | self.config = config
33 | self.model, self.vocab = build_model(config)
34 |
35 | self.device = config["device"]
36 | self.num_iters = config["trainer"]["iters"]
37 | self.beamsearch = config["predictor"]["beamsearch"]
38 |
39 | self.data_root = config["dataset"]["data_root"]
40 | self.train_annotation = config["dataset"]["train_annotation"]
41 | self.valid_annotation = config["dataset"]["valid_annotation"]
42 | self.dataset_name = config["dataset"]["name"]
43 |
44 | self.batch_size = config["trainer"]["batch_size"]
45 | self.print_every = config["trainer"]["print_every"]
46 | self.valid_every = config["trainer"]["valid_every"]
47 |
48 | self.image_aug = config["aug"]["image_aug"]
49 | self.masked_language_model = config["aug"]["masked_language_model"]
50 |
51 | self.checkpoint = config["trainer"]["checkpoint"]
52 | self.export_weights = config["trainer"]["export"]
53 | self.metrics = config["trainer"]["metrics"]
54 | logger = config["trainer"]["log"]
55 |
56 | if logger:
57 | self.logger = Logger(logger)
58 |
59 | if pretrained:
60 | weight_file = download_weights(config["pretrain"], quiet=config["quiet"])
61 | self.load_weights(weight_file)
62 |
63 | self.iter = 0
64 |
65 | self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
66 | self.scheduler = OneCycleLR(
67 | self.optimizer, total_steps=self.num_iters, **config["optimizer"]
68 | )
69 | # self.optimizer = ScheduledOptim(
70 | # Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
71 | # #config['transformer']['d_model'],
72 | # 512,
73 | # **config['optimizer'])
74 |
75 | self.criterion = LabelSmoothingLoss(
76 | len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1
77 | )
78 |
79 | transforms = None
80 | if self.image_aug:
81 | transforms = augmentor
82 |
83 | self.train_gen = self.data_gen(
84 | "train_{}".format(self.dataset_name),
85 | self.data_root,
86 | self.train_annotation,
87 | self.masked_language_model,
88 | transform=transforms,
89 | )
90 | if self.valid_annotation:
91 | self.valid_gen = self.data_gen(
92 | "valid_{}".format(self.dataset_name),
93 | self.data_root,
94 | self.valid_annotation,
95 | masked_language_model=False,
96 | )
97 |
98 | self.train_losses = []
99 |
100 | def train(self):
101 | total_loss = 0
102 |
103 | total_loader_time = 0
104 | total_gpu_time = 0
105 | best_acc = 0
106 |
107 | data_iter = iter(self.train_gen)
108 | for i in range(self.num_iters):
109 | self.iter += 1
110 |
111 | start = time.time()
112 |
113 | try:
114 | batch = next(data_iter)
115 | except StopIteration:
116 | data_iter = iter(self.train_gen)
117 | batch = next(data_iter)
118 |
119 | total_loader_time += time.time() - start
120 |
121 | start = time.time()
122 | loss = self.step(batch)
123 | total_gpu_time += time.time() - start
124 |
125 | total_loss += loss
126 | self.train_losses.append((self.iter, loss))
127 |
128 | if self.iter % self.print_every == 0:
129 | info = "iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}".format(
130 | self.iter,
131 | total_loss / self.print_every,
132 | self.optimizer.param_groups[0]["lr"],
133 | total_loader_time,
134 | total_gpu_time,
135 | )
136 |
137 | total_loss = 0
138 | total_loader_time = 0
139 | total_gpu_time = 0
140 | print(info)
141 | self.logger.log(info)
142 |
143 | if self.valid_annotation and self.iter % self.valid_every == 0:
144 | val_loss = self.validate()
145 | acc_full_seq, acc_per_char = self.precision(self.metrics)
146 |
147 | info = "iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}".format(
148 | self.iter, val_loss, acc_full_seq, acc_per_char
149 | )
150 | print(info)
151 | self.logger.log(info)
152 |
153 | if acc_full_seq > best_acc:
154 | self.save_weights(self.export_weights)
155 | best_acc = acc_full_seq
156 |
157 | def validate(self):
158 | self.model.eval()
159 |
160 | total_loss = []
161 |
162 | with torch.no_grad():
163 | for step, batch in enumerate(self.valid_gen):
164 | batch = self.batch_to_device(batch)
165 | img, tgt_input, tgt_output, tgt_padding_mask = (
166 | batch["img"],
167 | batch["tgt_input"],
168 | batch["tgt_output"],
169 | batch["tgt_padding_mask"],
170 | )
171 |
172 | outputs = self.model(img, tgt_input, tgt_padding_mask)
173 | # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
174 |
175 | outputs = outputs.flatten(0, 1)
176 | tgt_output = tgt_output.flatten()
177 | loss = self.criterion(outputs, tgt_output)
178 |
179 | total_loss.append(loss.item())
180 |
181 | del outputs
182 | del loss
183 |
184 | total_loss = np.mean(total_loss)
185 | self.model.train()
186 |
187 | return total_loss
188 |
189 | def predict(self, sample=None):
190 | pred_sents = []
191 | actual_sents = []
192 | img_files = []
193 |
194 | for batch in self.valid_gen:
195 | batch = self.batch_to_device(batch)
196 |
197 | if self.beamsearch:
198 | translated_sentence = batch_translate_beam_search(
199 | batch["img"], self.model
200 | )
201 | prob = None
202 | else:
203 | translated_sentence, prob = translate(batch["img"], self.model)
204 |
205 | pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
206 | actual_sent = self.vocab.batch_decode(batch["tgt_output"].tolist())
207 |
208 | img_files.extend(batch["filenames"])
209 |
210 | pred_sents.extend(pred_sent)
211 | actual_sents.extend(actual_sent)
212 |
213 |             if sample is not None and len(pred_sents) > sample:
214 | break
215 |
216 | return pred_sents, actual_sents, img_files, prob
217 |
218 | def precision(self, sample=None):
219 |
220 | pred_sents, actual_sents, _, _ = self.predict(sample=sample)
221 |
222 | acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode="full_sequence")
223 | acc_per_char = compute_accuracy(actual_sents, pred_sents, mode="per_char")
224 |
225 | return acc_full_seq, acc_per_char
226 |
227 | def visualize_prediction(
228 | self, sample=16, errorcase=False, fontname="serif", fontsize=16
229 | ):
230 |
231 | pred_sents, actual_sents, img_files, probs = self.predict(sample)
232 |
233 | if errorcase:
234 | wrongs = []
235 | for i in range(len(img_files)):
236 | if pred_sents[i] != actual_sents[i]:
237 | wrongs.append(i)
238 |
239 | pred_sents = [pred_sents[i] for i in wrongs]
240 | actual_sents = [actual_sents[i] for i in wrongs]
241 | img_files = [img_files[i] for i in wrongs]
242 | probs = [probs[i] for i in wrongs]
243 |
244 | img_files = img_files[:sample]
245 |
246 | fontdict = {"family": fontname, "size": fontsize}
247 |
248 | for vis_idx in range(0, len(img_files)):
249 | img_path = img_files[vis_idx]
250 | pred_sent = pred_sents[vis_idx]
251 | actual_sent = actual_sents[vis_idx]
252 | prob = probs[vis_idx]
253 |
254 | img = Image.open(open(img_path, "rb"))
255 | plt.figure()
256 | plt.imshow(img)
257 | plt.title(
258 | "prob: {:.3f} - pred: {} - actual: {}".format(
259 | prob, pred_sent, actual_sent
260 | ),
261 | loc="left",
262 | fontdict=fontdict,
263 | )
264 | plt.axis("off")
265 |
266 | plt.show()
267 |
268 | def visualize_dataset(self, sample=16, fontname="serif"):
269 | n = 0
270 | for batch in self.train_gen:
271 | for i in range(self.batch_size):
272 | img = batch["img"][i].numpy().transpose(1, 2, 0)
273 | sent = self.vocab.decode(batch["tgt_input"].T[i].tolist())
274 |
275 | plt.figure()
276 | plt.title("sent: {}".format(sent), loc="center", fontname=fontname)
277 | plt.imshow(img)
278 | plt.axis("off")
279 |
280 | n += 1
281 | if n >= sample:
282 | plt.show()
283 | return
284 |
285 | def load_checkpoint(self, filename):
286 | checkpoint = torch.load(filename)
287 |
288 | optim = ScheduledOptim(
289 | Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
290 | self.config["transformer"]["d_model"],
291 | **self.config["optimizer"]
292 | )
293 |
294 | self.optimizer.load_state_dict(checkpoint["optimizer"])
295 | self.model.load_state_dict(checkpoint["state_dict"])
296 | self.iter = checkpoint["iter"]
297 |
298 | self.train_losses = checkpoint["train_losses"]
299 |
300 | def save_checkpoint(self, filename):
301 | state = {
302 | "iter": self.iter,
303 | "state_dict": self.model.state_dict(),
304 | "optimizer": self.optimizer.state_dict(),
305 | "train_losses": self.train_losses,
306 | }
307 |
308 | path, _ = os.path.split(filename)
309 | os.makedirs(path, exist_ok=True)
310 |
311 | torch.save(state, filename)
312 |
313 | def load_weights(self, filename):
314 | state_dict = torch.load(filename, map_location=torch.device(self.device))
315 |
316 | for name, param in self.model.named_parameters():
317 | if name not in state_dict:
318 | print("{} not found".format(name))
319 | elif state_dict[name].shape != param.shape:
320 | print(
321 |                     "{} has a mismatched shape: required {} but found {}".format(
322 | name, param.shape, state_dict[name].shape
323 | )
324 | )
325 | del state_dict[name]
326 |
327 | self.model.load_state_dict(state_dict, strict=False)
328 |
329 | def save_weights(self, filename):
330 | path, _ = os.path.split(filename)
331 | os.makedirs(path, exist_ok=True)
332 |
333 | torch.save(self.model.state_dict(), filename)
334 |
335 | def batch_to_device(self, batch):
336 | img = batch["img"].to(self.device, non_blocking=True)
337 | tgt_input = batch["tgt_input"].to(self.device, non_blocking=True)
338 | tgt_output = batch["tgt_output"].to(self.device, non_blocking=True)
339 | tgt_padding_mask = batch["tgt_padding_mask"].to(self.device, non_blocking=True)
340 |
341 | batch = {
342 | "img": img,
343 | "tgt_input": tgt_input,
344 | "tgt_output": tgt_output,
345 | "tgt_padding_mask": tgt_padding_mask,
346 | "filenames": batch["filenames"],
347 | }
348 |
349 | return batch
350 |
351 | def data_gen(
352 | self,
353 | lmdb_path,
354 | data_root,
355 | annotation,
356 | masked_language_model=True,
357 | transform=None,
358 | ):
359 | dataset = OCRDataset(
360 | lmdb_path=lmdb_path,
361 | root_dir=data_root,
362 | annotation_path=annotation,
363 | vocab=self.vocab,
364 | transform=transform,
365 | image_height=self.config["dataset"]["image_height"],
366 | image_min_width=self.config["dataset"]["image_min_width"],
367 | image_max_width=self.config["dataset"]["image_max_width"],
368 | )
369 |
370 | sampler = ClusterRandomSampler(dataset, self.batch_size, True)
371 | collate_fn = Collator(masked_language_model)
372 |
373 | gen = DataLoader(
374 | dataset,
375 | batch_size=self.batch_size,
376 | sampler=sampler,
377 | collate_fn=collate_fn,
378 | shuffle=False,
379 | drop_last=False,
380 | **self.config["dataloader"]
381 | )
382 |
383 | return gen
384 |
385 | def data_gen_v1(self, lmdb_path, data_root, annotation):
386 | data_gen = DataGen(
387 | data_root,
388 | annotation,
389 | self.vocab,
390 | "cpu",
391 | image_height=self.config["dataset"]["image_height"],
392 | image_min_width=self.config["dataset"]["image_min_width"],
393 | image_max_width=self.config["dataset"]["image_max_width"],
394 | )
395 |
396 | return data_gen
397 |
398 | def step(self, batch):
399 | self.model.train()
400 |
401 | batch = self.batch_to_device(batch)
402 | img, tgt_input, tgt_output, tgt_padding_mask = (
403 | batch["img"],
404 | batch["tgt_input"],
405 | batch["tgt_output"],
406 | batch["tgt_padding_mask"],
407 | )
408 |
409 | outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask)
410 | # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
411 | outputs = outputs.view(-1, outputs.size(2)) # flatten(0, 1)
412 | tgt_output = tgt_output.view(-1) # flatten()
413 |
414 | loss = self.criterion(outputs, tgt_output)
415 |
416 | self.optimizer.zero_grad()
417 |
418 | loss.backward()
419 |
420 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
421 |
422 | self.optimizer.step()
423 | self.scheduler.step()
424 |
425 | loss_item = loss.item()
426 |
427 | return loss_item
428 |
--------------------------------------------------------------------------------
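A sketch of how the Trainer above is typically driven. The paths are placeholders, and both load_config_from_name and pretrained=True assume network access to fetch the published config and weights:

from vietocr.tool.config import Cfg
from vietocr.model.trainer import Trainer

config = Cfg.load_config_from_name("vgg_transformer")
config["device"] = "cuda:0"
config["dataset"]["data_root"] = "./data"                       # placeholder
config["dataset"]["train_annotation"] = "train_annotation.txt"  # placeholder
config["dataset"]["valid_annotation"] = "valid_annotation.txt"  # placeholder
config["trainer"]["iters"] = 20000

trainer = Trainer(config, pretrained=True)
trainer.train()
acc_full_seq, acc_per_char = trainer.precision()
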
/vietocr/model/transformerocr.py:
--------------------------------------------------------------------------------
1 | from vietocr.model.backbone.cnn import CNN
2 | from vietocr.model.seqmodel.transformer import LanguageTransformer
3 | from vietocr.model.seqmodel.seq2seq import Seq2Seq
4 | from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
5 | from torch import nn
6 |
7 |
8 | class VietOCR(nn.Module):
9 | def __init__(
10 | self,
11 | vocab_size,
12 | backbone,
13 | cnn_args,
14 | transformer_args,
15 | seq_modeling="transformer",
16 | ):
17 |
18 | super(VietOCR, self).__init__()
19 |
20 | self.cnn = CNN(backbone, **cnn_args)
21 | self.seq_modeling = seq_modeling
22 |
23 | if seq_modeling == "transformer":
24 | self.transformer = LanguageTransformer(vocab_size, **transformer_args)
25 | elif seq_modeling == "seq2seq":
26 | self.transformer = Seq2Seq(vocab_size, **transformer_args)
27 | elif seq_modeling == "convseq2seq":
28 | self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)
29 | else:
30 |             raise ValueError("Unsupported seq_modeling: {}".format(seq_modeling))
31 |
32 | def forward(self, img, tgt_input, tgt_key_padding_mask):
33 | """
34 | Shape:
35 | - img: (N, C, H, W)
36 | - tgt_input: (T, N)
37 | - tgt_key_padding_mask: (N, T)
38 | - output: b t v
39 | """
40 | src = self.cnn(img)
41 |
42 | if self.seq_modeling == "transformer":
43 | outputs = self.transformer(
44 | src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask
45 | )
46 | elif self.seq_modeling == "seq2seq":
47 | outputs = self.transformer(src, tgt_input)
48 | elif self.seq_modeling == "convseq2seq":
49 | outputs = self.transformer(src, tgt_input)
50 | return outputs
51 |
--------------------------------------------------------------------------------
/vietocr/model/vocab.py:
--------------------------------------------------------------------------------
1 | class Vocab:
2 | def __init__(self, chars):
3 | self.pad = 0
4 | self.go = 1
5 | self.eos = 2
6 | self.mask_token = 3
7 |
8 | self.chars = chars
9 |
10 | self.c2i = {c: i + 4 for i, c in enumerate(chars)}
11 |
12 | self.i2c = {i + 4: c for i, c in enumerate(chars)}
13 |
14 | self.i2c[0] = ""
15 | self.i2c[1] = ""
16 | self.i2c[2] = ""
17 | self.i2c[3] = "*"
18 |
19 | def encode(self, chars):
20 | return [self.go] + [self.c2i[c] for c in chars] + [self.eos]
21 |
22 | def decode(self, ids):
23 | first = 1 if self.go in ids else 0
24 | last = ids.index(self.eos) if self.eos in ids else None
25 | sent = "".join([self.i2c[i] for i in ids[first:last]])
26 | return sent
27 |
28 | def __len__(self):
29 | return len(self.c2i) + 4
30 |
31 | def batch_decode(self, arr):
32 | texts = [self.decode(ids) for ids in arr]
33 | return texts
34 |
35 | def __str__(self):
36 | return self.chars
37 |
--------------------------------------------------------------------------------
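A quick round-trip with the Vocab above; indices 0-3 are reserved for the pad/go/eos/mask tokens, so characters start at index 4:

vocab = Vocab("abc")

ids = vocab.encode("cab")   # [1, 6, 4, 5, 2]: go, 'c', 'a', 'b', eos
text = vocab.decode(ids)    # "cab" (the go/eos markers are stripped)
size = len(vocab)           # 7 = 3 characters + 4 special tokens
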
/vietocr/optim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/optim/__init__.py
--------------------------------------------------------------------------------
/vietocr/optim/labelsmoothingloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LabelSmoothingLoss(nn.Module):
6 | def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1):
7 | super(LabelSmoothingLoss, self).__init__()
8 | self.confidence = 1.0 - smoothing
9 | self.smoothing = smoothing
10 | self.cls = classes
11 | self.dim = dim
12 | self.padding_idx = padding_idx
13 |
14 | def forward(self, pred, target):
15 | pred = pred.log_softmax(dim=self.dim)
16 | with torch.no_grad():
17 | # true_dist = pred.data.clone()
18 | true_dist = torch.zeros_like(pred)
19 | true_dist.fill_(self.smoothing / (self.cls - 2))
20 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
21 | true_dist[:, self.padding_idx] = 0
22 | mask = torch.nonzero(target.data == self.padding_idx, as_tuple=False)
23 | if mask.dim() > 0:
24 | true_dist.index_fill_(0, mask.squeeze(), 0.0)
25 |
26 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
27 |
--------------------------------------------------------------------------------
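A small usage sketch for the criterion above, with made-up sizes: predictions are raw scores of shape (N, classes), targets are class indices, and rows whose target equals padding_idx contribute nothing to the loss:

import torch

criterion = LabelSmoothingLoss(classes=5, padding_idx=0, smoothing=0.1)

logits = torch.randn(4, 5)            # (batch * time, classes), unnormalized scores
target = torch.tensor([2, 3, 0, 0])   # two real tokens, two padding positions
loss = criterion(logits, target)      # padded rows are zeroed before the mean
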
/vietocr/optim/optim.py:
--------------------------------------------------------------------------------
1 | class ScheduledOptim:
2 | """A simple wrapper class for learning rate scheduling"""
3 |
4 | def __init__(self, optimizer, d_model, init_lr, n_warmup_steps):
5 | assert n_warmup_steps > 0, "must be greater than 0"
6 |
7 | self._optimizer = optimizer
8 | self.init_lr = init_lr
9 | self.d_model = d_model
10 | self.n_warmup_steps = n_warmup_steps
11 | self.n_steps = 0
12 |
13 | def step(self):
14 | "Step with the inner optimizer"
15 | self._update_learning_rate()
16 | self._optimizer.step()
17 |
18 | def zero_grad(self):
19 | "Zero out the gradients with the inner optimizer"
20 | self._optimizer.zero_grad()
21 |
22 | def _get_lr_scale(self):
23 | d_model = self.d_model
24 | n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
25 | return (d_model**-0.5) * min(
26 | n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)
27 | )
28 |
29 | def state_dict(self):
30 | optimizer_state_dict = {
31 | "init_lr": self.init_lr,
32 | "d_model": self.d_model,
33 | "n_warmup_steps": self.n_warmup_steps,
34 | "n_steps": self.n_steps,
35 | "_optimizer": self._optimizer.state_dict(),
36 | }
37 |
38 | return optimizer_state_dict
39 |
40 | def load_state_dict(self, state_dict):
41 | self.init_lr = state_dict["init_lr"]
42 | self.d_model = state_dict["d_model"]
43 | self.n_warmup_steps = state_dict["n_warmup_steps"]
44 | self.n_steps = state_dict["n_steps"]
45 |
46 | self._optimizer.load_state_dict(state_dict["_optimizer"])
47 |
48 | def _update_learning_rate(self):
49 | """Learning rate scheduling per step"""
50 |
51 | self.n_steps += 1
52 |
53 | for param_group in self._optimizer.param_groups:
54 | lr = self.init_lr * self._get_lr_scale()
55 | self.lr = lr
56 |
57 | param_group["lr"] = lr
58 |
--------------------------------------------------------------------------------
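A usage sketch for ScheduledOptim above (illustrative values only). It applies the warmup schedule lr(step) = init_lr * d_model**-0.5 * min(step**-0.5, step * n_warmup_steps**-1.5) to every param group of the wrapped optimizer:

import torch
from torch.optim import Adam

params = [torch.nn.Parameter(torch.zeros(1))]
sched = ScheduledOptim(Adam(params), d_model=512, init_lr=2.0, n_warmup_steps=4000)

for _ in range(10):
    sched.zero_grad()
    # loss.backward() would go here in real training
    sched.step()        # updates the learning rate, then steps the wrapped Adam

print(sched.lr)         # learning rate reached after 10 warmup steps
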
/vietocr/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from PIL import Image
3 |
4 | from vietocr.tool.predictor import Predictor
5 | from vietocr.tool.config import Cfg
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 |     parser.add_argument("--img", required=True, help="path to the input image")
11 |     parser.add_argument("--config", required=True, help="path to the config yaml file")
12 |
13 | args = parser.parse_args()
14 | config = Cfg.load_config_from_file(args.config)
15 |
16 | detector = Predictor(config)
17 |
18 | img = Image.open(args.img)
19 | s = detector.predict(img)
20 |
21 | print(s)
22 |
23 |
24 | if __name__ == "__main__":
25 | main()
26 |
--------------------------------------------------------------------------------
/vietocr/requirement.txt:
--------------------------------------------------------------------------------
1 | einops==0.2.0
2 |
--------------------------------------------------------------------------------
/vietocr/tests/image/001099025107.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/001099025107.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/026301003919.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/026301003919.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/036170002830.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/036170002830.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/038078002355.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/038078002355.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/038089010274.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/038089010274.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/038144000109.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/038144000109.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/060085000115.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/060085000115.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/072183002222.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/072183002222.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/079084000809.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/079084000809.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/image/079193002341.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tests/image/079193002341.jpeg
--------------------------------------------------------------------------------
/vietocr/tests/sample.txt:
--------------------------------------------------------------------------------
1 | ./image/036170002830.jpeg HOÀNG THỊ THOI
2 | ./image/079193002341.jpeg TRỊNH THỊ THÚY HẰNG
3 | ./image/001099025107.jpeg NGUYỄN VĂN BÌNH
4 | ./image/060085000115.jpeg NGUYỄN MINH TOÀN
5 | ./image/026301003919.jpeg NGUYỄN THỊ KIỀU TRANG
6 | ./image/079084000809.jpeg LÊ NGỌC PHƯƠNG KHANH
7 | ./image/038144000109.jpeg ĐÀO THỊ TƠ
8 | ./image/072183002222.jpeg NGUYỄN THANH PHƯỚC
9 | ./image/038078002355.jpeg HÀ ĐÌNH LỢI
10 | ./image/038089010274.jpeg HÀ VĂN LUÂN
11 |
--------------------------------------------------------------------------------
/vietocr/tests/utest.py:
--------------------------------------------------------------------------------
1 | from vietocr.loader.dataloader_v1 import DataGen
2 | from vietocr.model.vocab import Vocab
3 |
4 |
5 | def test_loader():
6 | chars = "aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ "
7 |
8 | vocab = Vocab(chars)
9 | s_gen = DataGen("./vietocr/tests/", "sample.txt", vocab, "cpu", 32, 512)
10 |
11 | iterator = s_gen.gen(30)
12 | for batch in iterator:
13 | assert batch["img"].shape[1] == 3, "image must have 3 channels"
14 | assert batch["img"].shape[2] == 32, "the height must be 32"
15 | print(
16 | batch["img"].shape,
17 | batch["tgt_input"].shape,
18 | batch["tgt_output"].shape,
19 | batch["tgt_padding_mask"].shape,
20 | )
21 |
22 |
23 | if __name__ == "__main__":
24 | test_loader()
25 |
--------------------------------------------------------------------------------
/vietocr/tool/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbcquoc/vietocr/fe8c3a7fc714aec57ab81cec844eb3adf0c1636c/vietocr/tool/__init__.py
--------------------------------------------------------------------------------
/vietocr/tool/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from vietocr.tool.utils import download_config
3 |
4 | url_config = {
5 | "vgg_transformer": "vgg-transformer.yml",
6 | "resnet_transformer": "resnet_transformer.yml",
7 | "resnet_fpn_transformer": "resnet_fpn_transformer.yml",
8 | "vgg_seq2seq": "vgg-seq2seq.yml",
9 | "vgg_convseq2seq": "vgg_convseq2seq.yml",
10 | "vgg_decoderseq2seq": "vgg_decoderseq2seq.yml",
11 | "base": "base.yml",
12 | }
13 |
14 |
15 | class Cfg(dict):
16 | def __init__(self, config_dict):
17 | super(Cfg, self).__init__(**config_dict)
18 | self.__dict__ = self
19 |
20 | @staticmethod
21 | def load_config_from_file(fname):
22 | # base_config = download_config(url_config['base'])
23 | base_config = {}
24 | with open(fname, encoding="utf-8") as f:
25 | config = yaml.safe_load(f)
26 | base_config.update(config)
27 |
28 | return Cfg(base_config)
29 |
30 | @staticmethod
31 | def load_config_from_name(name):
32 | base_config = download_config(url_config["base"])
33 | config = download_config(url_config[name])
34 |
35 | base_config.update(config)
36 | return Cfg(base_config)
37 |
38 | def save(self, fname):
39 | with open(fname, "w", encoding="utf-8") as outfile:
40 | yaml.dump(dict(self), outfile, default_flow_style=False, allow_unicode=True)
41 |
--------------------------------------------------------------------------------
/vietocr/tool/create_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import lmdb # install lmdb by "pip install lmdb"
4 | import cv2
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 |
9 | def checkImageIsValid(imageBin):
10 | isvalid = True
11 | imgH = None
12 | imgW = None
13 |
14 |     imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
15 | try:
16 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
17 |
18 | imgH, imgW = img.shape[0], img.shape[1]
19 | if imgH * imgW == 0:
20 | isvalid = False
21 | except Exception as e:
22 | isvalid = False
23 |
24 | return isvalid, imgH, imgW
25 |
26 |
27 | def writeCache(env, cache):
28 | with env.begin(write=True) as txn:
29 | for k, v in cache.items():
30 | txn.put(k.encode(), v)
31 |
32 |
33 | def createDataset(outputPath, root_dir, annotation_path):
34 | """
35 | Create LMDB dataset for CRNN training.
36 | ARGS:
37 | outputPath : LMDB output path
38 | imagePathList : list of image path
39 | labelList : list of corresponding groundtruth texts
40 | lexiconList : (optional) list of lexicon lists
41 | checkValid : if true, check the validity of every image
42 | """
43 |
44 | annotation_path = os.path.join(root_dir, annotation_path)
45 | with open(annotation_path, "r", encoding="utf-8") as ann_file:
46 | lines = ann_file.readlines()
47 | annotations = [l.strip().split("\t") for l in lines]
48 |
49 | nSamples = len(annotations)
50 | env = lmdb.open(outputPath, map_size=1099511627776)
51 | cache = {}
52 | cnt = 0
53 | error = 0
54 |
55 | pbar = tqdm(range(nSamples), ncols=100, desc="Create {}".format(outputPath))
56 | for i in pbar:
57 | imageFile, label = annotations[i]
58 | imagePath = os.path.join(root_dir, imageFile)
59 |
60 | if not os.path.exists(imagePath):
61 | error += 1
62 | continue
63 |
64 | with open(imagePath, "rb") as f:
65 | imageBin = f.read()
66 | isvalid, imgH, imgW = checkImageIsValid(imageBin)
67 |
68 | if not isvalid:
69 | error += 1
70 | continue
71 |
72 | imageKey = "image-%09d" % cnt
73 | labelKey = "label-%09d" % cnt
74 | pathKey = "path-%09d" % cnt
75 | dimKey = "dim-%09d" % cnt
76 |
77 | cache[imageKey] = imageBin
78 | cache[labelKey] = label.encode()
79 | cache[pathKey] = imageFile.encode()
80 | cache[dimKey] = np.array([imgH, imgW], dtype=np.int32).tobytes()
81 |
82 | cnt += 1
83 |
84 | if cnt % 1000 == 0:
85 | writeCache(env, cache)
86 | cache = {}
87 |
88 |     nSamples = cnt  # keys are 0-based, so exactly cnt samples were written
89 | cache["num-samples"] = str(nSamples).encode()
90 | writeCache(env, cache)
91 |
92 | if error > 0:
93 |         print("Skipped {} invalid images".format(error))
94 | print("Created dataset with %d samples" % nSamples)
95 | sys.stdout.flush()
96 |
--------------------------------------------------------------------------------
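A sketch of building an LMDB dataset with createDataset above. The paths are placeholders; the annotation file lists one image path and its label per line, separated by a tab, with image paths relative to root_dir:

from vietocr.tool.create_dataset import createDataset

createDataset("./train_lmdb", root_dir="./data", annotation_path="train_annotation.txt")
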
/vietocr/tool/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class Logger:
5 | def __init__(self, fname):
6 | path, _ = os.path.split(fname)
7 | os.makedirs(path, exist_ok=True)
8 |
9 | self.logger = open(fname, "w")
10 |
11 | def log(self, string):
12 | self.logger.write(string + "\n")
13 | self.logger.flush()
14 |
15 | def close(self):
16 | self.logger.close()
17 |
--------------------------------------------------------------------------------
/vietocr/tool/predictor.py:
--------------------------------------------------------------------------------
1 | from vietocr.tool.translate import (
2 | build_model,
3 | translate,
4 | translate_beam_search,
5 | process_input,
6 | predict,
7 | )
8 | from vietocr.tool.utils import download_weights
9 |
10 | import torch
11 | from collections import defaultdict
12 |
13 |
14 | class Predictor:
15 | def __init__(self, config):
16 |
17 | device = config["device"]
18 |
19 | model, vocab = build_model(config)
20 | weights = "/tmp/weights.pth"
21 |
22 | if config["weights"].startswith("http"):
23 | weights = download_weights(config["weights"])
24 | else:
25 | weights = config["weights"]
26 |
27 | model.load_state_dict(torch.load(weights, map_location=torch.device(device)))
28 |
29 | self.config = config
30 | self.model = model
31 | self.vocab = vocab
32 | self.device = device
33 |
34 | def predict(self, img, return_prob=False):
35 | img = process_input(
36 | img,
37 | self.config["dataset"]["image_height"],
38 | self.config["dataset"]["image_min_width"],
39 | self.config["dataset"]["image_max_width"],
40 | )
41 | img = img.to(self.config["device"])
42 |
43 | if self.config["predictor"]["beamsearch"]:
44 | sent = translate_beam_search(img, self.model)
45 | s = sent
46 | prob = None
47 | else:
48 | s, prob = translate(img, self.model)
49 | s = s[0].tolist()
50 | prob = prob[0]
51 |
52 | s = self.vocab.decode(s)
53 |
54 | if return_prob:
55 | return s, prob
56 | else:
57 | return s
58 |
59 | def predict_batch(self, imgs, return_prob=False):
60 | bucket = defaultdict(list)
61 | bucket_idx = defaultdict(list)
62 | bucket_pred = {}
63 |
64 | sents, probs = [0] * len(imgs), [0] * len(imgs)
65 |
66 | for i, img in enumerate(imgs):
67 | img = process_input(
68 | img,
69 | self.config["dataset"]["image_height"],
70 | self.config["dataset"]["image_min_width"],
71 | self.config["dataset"]["image_max_width"],
72 | )
73 |
74 | bucket[img.shape[-1]].append(img)
75 | bucket_idx[img.shape[-1]].append(i)
76 |
77 | for k, batch in bucket.items():
78 | batch = torch.cat(batch, 0).to(self.device)
79 | s, prob = translate(batch, self.model)
80 | prob = prob.tolist()
81 |
82 | s = s.tolist()
83 | s = self.vocab.batch_decode(s)
84 |
85 | bucket_pred[k] = (s, prob)
86 |
87 | for k in bucket_pred:
88 | idx = bucket_idx[k]
89 | sent, prob = bucket_pred[k]
90 | for i, j in enumerate(idx):
91 | sents[j] = sent[i]
92 | probs[j] = prob[i]
93 |
94 | if return_prob:
95 | return sents, probs
96 | else:
97 | return sents
98 |
--------------------------------------------------------------------------------
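A sketch of single-image inference through the Predictor above. The image path is a placeholder, and the downloaded config is assumed to supply the weights URL plus the dataset and predictor keys the class reads:

from PIL import Image
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

config = Cfg.load_config_from_name("vgg_transformer")
config["device"] = "cpu"
config["predictor"]["beamsearch"] = False

detector = Predictor(config)
text, prob = detector.predict(Image.open("sample.png"), return_prob=True)
print(text, prob)
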
/vietocr/tool/translate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 | from PIL import Image
5 | from torch.nn.functional import log_softmax, softmax
6 |
7 | from vietocr.model.transformerocr import VietOCR
8 | from vietocr.model.vocab import Vocab
9 | from vietocr.model.beam import Beam
10 |
11 |
12 | def batch_translate_beam_search(
13 | img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2
14 | ):
15 | # img: NxCxHxW
16 | model.eval()
17 | device = img.device
18 | sents = []
19 |
20 | with torch.no_grad():
21 | src = model.cnn(img)
22 |         # print(src.shape)
23 | memories = model.transformer.forward_encoder(src)
24 | for i in range(src.size(0)):
25 | # memory = memories[:,i,:].repeat(1, beam_size, 1) # TxNxE
26 | memory = model.transformer.get_memory(memories, i)
27 | sent = beamsearch(
28 | memory,
29 | model,
30 | device,
31 | beam_size,
32 | candidates,
33 | max_seq_length,
34 | sos_token,
35 | eos_token,
36 | )
37 | sents.append(sent)
38 |
39 | sents = np.asarray(sents)
40 |
41 | return sents
42 |
43 |
44 | def translate_beam_search(
45 | img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2
46 | ):
47 | # img: 1xCxHxW
48 | model.eval()
49 | device = img.device
50 |
51 | with torch.no_grad():
52 | src = model.cnn(img)
53 | memory = model.transformer.forward_encoder(src) # TxNxE
54 | sent = beamsearch(
55 | memory,
56 | model,
57 | device,
58 | beam_size,
59 | candidates,
60 | max_seq_length,
61 | sos_token,
62 | eos_token,
63 | )
64 |
65 | return sent
66 |
67 |
68 | def beamsearch(
69 | memory,
70 | model,
71 | device,
72 | beam_size=4,
73 | candidates=1,
74 | max_seq_length=128,
75 | sos_token=1,
76 | eos_token=2,
77 | ):
78 | # memory: Tx1xE
79 | model.eval()
80 |
81 | beam = Beam(
82 | beam_size=beam_size,
83 | min_length=0,
84 | n_top=candidates,
85 | ranker=None,
86 | start_token_id=sos_token,
87 | end_token_id=eos_token,
88 | )
89 |
90 | with torch.no_grad():
91 | # memory = memory.repeat(1, beam_size, 1) # TxNxE
92 | memory = model.transformer.expand_memory(memory, beam_size)
93 |
94 | for _ in range(max_seq_length):
95 |
96 | tgt_inp = beam.get_current_state().transpose(0, 1).to(device) # TxN
97 | decoder_outputs, memory = model.transformer.forward_decoder(tgt_inp, memory)
98 |
99 | log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0), dim=-1)
100 | beam.advance(log_prob.cpu())
101 |
102 | if beam.done():
103 | break
104 |
105 | scores, ks = beam.sort_finished(minimum=1)
106 |
107 | hypothesises = []
108 | for i, (times, k) in enumerate(ks[:candidates]):
109 | hypothesis = beam.get_hypothesis(times, k)
110 | hypothesises.append(hypothesis)
111 |
112 | return [1] + [int(i) for i in hypothesises[0][:-1]]
113 |
114 |
115 | def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
116 |     "img: B x C x H x W"
117 | model.eval()
118 | device = img.device
119 |
120 | with torch.no_grad():
121 | src = model.cnn(img)
122 | memory = model.transformer.forward_encoder(src)
123 |
124 | translated_sentence = [[sos_token] * len(img)]
125 | char_probs = [[1] * len(img)]
126 |
127 | max_length = 0
128 |
129 | while max_length <= max_seq_length and not all(
130 | np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
131 | ):
132 |
133 | tgt_inp = torch.LongTensor(translated_sentence).to(device)
134 |
135 | # output = model(img, tgt_inp, tgt_key_padding_mask=None)
136 | # output = model.transformer(src, tgt_inp, tgt_key_padding_mask=None)
137 | output, memory = model.transformer.forward_decoder(tgt_inp, memory)
138 | output = softmax(output, dim=-1)
139 | output = output.to("cpu")
140 |
141 | values, indices = torch.topk(output, 5)
142 |
143 | indices = indices[:, -1, 0]
144 | indices = indices.tolist()
145 |
146 | values = values[:, -1, 0]
147 | values = values.tolist()
148 | char_probs.append(values)
149 |
150 | translated_sentence.append(indices)
151 | max_length += 1
152 |
153 | del output
154 |
155 | translated_sentence = np.asarray(translated_sentence).T
156 |
157 | char_probs = np.asarray(char_probs).T
158 | char_probs = np.multiply(char_probs, translated_sentence > 3)
159 | char_probs = np.sum(char_probs, axis=-1) / (char_probs > 0).sum(-1)
160 |
161 | return translated_sentence, char_probs
162 |
163 |
164 | def build_model(config):
165 | vocab = Vocab(config["vocab"])
166 | device = config["device"]
167 |
168 | model = VietOCR(
169 | len(vocab),
170 | config["backbone"],
171 | config["cnn"],
172 | config["transformer"],
173 | config["seq_modeling"],
174 | )
175 |
176 | model = model.to(device)
177 |
178 | return model, vocab
179 |
180 |
181 | def resize(w, h, expected_height, image_min_width, image_max_width):
182 | new_w = int(expected_height * float(w) / float(h))
183 | round_to = 10
184 | new_w = math.ceil(new_w / round_to) * round_to
185 | new_w = max(new_w, image_min_width)
186 | new_w = min(new_w, image_max_width)
187 |
188 | return new_w, expected_height
189 |
190 |
191 | def process_image(image, image_height, image_min_width, image_max_width):
192 | img = image.convert("RGB")
193 |
194 | w, h = img.size
195 | new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)
196 |
197 | img = img.resize((new_w, image_height), Image.LANCZOS)
198 |
199 | img = np.asarray(img).transpose(2, 0, 1)
200 | img = img / 255
201 | return img
202 |
203 |
204 | def process_input(image, image_height, image_min_width, image_max_width):
205 | img = process_image(image, image_height, image_min_width, image_max_width)
206 | img = img[np.newaxis, ...]
207 | img = torch.FloatTensor(img)
208 | return img
209 |
210 |
211 | def predict(filename, config):
212 | img = Image.open(filename)
213 |     img = process_input(img, config["dataset"]["image_height"],
214 |                         config["dataset"]["image_min_width"], config["dataset"]["image_max_width"])
215 | img = img.to(config["device"])
216 |
217 | model, vocab = build_model(config)
218 | s = translate(img, model)[0].tolist()
219 | s = vocab.decode(s)
220 |
221 | return s
222 |
--------------------------------------------------------------------------------
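
A minimal end-to-end sketch of how the helpers in translate.py fit together, for reference only: the config path is illustrative, the config is assumed to expose the "device" and "dataset" keys that build_model and process_input read, and build_model returns a randomly initialized network unless pretrained weights are loaded separately.

from PIL import Image
from vietocr.tool.config import Cfg
from vietocr.tool.translate import build_model, process_input, translate

config = Cfg.load_config_from_file("config/vgg-transformer.yml")  # illustrative path
config["device"] = "cpu"

model, vocab = build_model(config)        # VietOCR model + Vocab

img = Image.open("image/sample.png")      # any single-line text image
img = process_input(
    img,
    config["dataset"]["image_height"],    # assumed config layout
    config["dataset"]["image_min_width"],
    config["dataset"]["image_max_width"],
).to(config["device"])

sent, prob = translate(img, model)        # greedy decoding over the batch
print(vocab.decode(sent[0].tolist()))     # decoded text of the first image
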
/vietocr/tool/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gdown
3 | import yaml
4 | import numpy as np
5 | import uuid
6 | import requests
7 | import tempfile
8 | from tqdm import tqdm
9 |
10 |
11 | def download_weights(uri, cached=None, md5=None, quiet=False):
12 | if uri.startswith("http"):
13 | return download(url=uri, quiet=quiet)
14 | return uri
15 |
16 |
17 | def download(url, quiet=False):
18 | tmp_dir = tempfile.gettempdir()
19 | filename = url.split("/")[-1]
20 | full_path = os.path.join(tmp_dir, filename)
21 |
22 | if os.path.exists(full_path):
23 |         print("Model weights {} already exist. Skipping download.".format(full_path))
24 | return full_path
25 |
26 | with requests.get(url, stream=True) as r:
27 | r.raise_for_status()
28 | with open(full_path, "wb") as f:
29 | for chunk in tqdm(r.iter_content(chunk_size=8192)):
30 |                 # If the response is chunk-encoded, set chunk_size=None
31 |                 # and uncomment the check below.
32 |                 # if chunk:
33 |                 f.write(chunk)
34 | return full_path
35 |
36 |
37 | def download_config(id):
38 | url = "https://vocr.vn/data/vietocr/config/{}".format(id)
39 | r = requests.get(url)
40 | config = yaml.safe_load(r.text)
41 | return config
42 |
43 |
44 | def compute_accuracy(ground_truth, predictions, mode="full_sequence"):
45 |     """
46 |     Computes the average label accuracy of the predictions against the ground truth.
47 |
48 |     :param ground_truth: list of ground-truth label strings
49 |     :param predictions: list of predicted strings, aligned with ground_truth
50 |     :param mode: if 'per_char' is selected then
51 |         single_label_accuracy = correct_predicted_char_nums_of_single_sample / single_label_char_nums
52 |         avg_label_accuracy = sum(single_label_accuracy) / label_nums
53 |     if 'full_sequence' is selected then
54 |         single_label_accuracy = 1 if the prediction result is exactly the same as the label else 0
55 |         avg_label_accuracy = sum(single_label_accuracy) / label_nums
56 |     :return: avg_label_accuracy
57 |     """
58 | if mode == "per_char":
59 |
60 | accuracy = []
61 |
62 | for index, label in enumerate(ground_truth):
63 | prediction = predictions[index]
64 | total_count = len(label)
65 | correct_count = 0
66 | try:
67 | for i, tmp in enumerate(label):
68 | if tmp == prediction[i]:
69 | correct_count += 1
70 | except IndexError:
71 | continue
72 | finally:
73 | try:
74 | accuracy.append(correct_count / total_count)
75 | except ZeroDivisionError:
76 | if len(prediction) == 0:
77 | accuracy.append(1)
78 | else:
79 | accuracy.append(0)
80 | avg_accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0)
81 | elif mode == "full_sequence":
82 | try:
83 | correct_count = 0
84 | for index, label in enumerate(ground_truth):
85 | prediction = predictions[index]
86 | if prediction == label:
87 | correct_count += 1
88 | avg_accuracy = correct_count / len(ground_truth)
89 | except ZeroDivisionError:
90 | if not predictions:
91 | avg_accuracy = 1
92 | else:
93 | avg_accuracy = 0
94 | else:
95 | raise NotImplementedError(
96 | "Other accuracy compute mode has not been implemented"
97 | )
98 |
99 | return avg_accuracy
100 |
--------------------------------------------------------------------------------
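
To make the two accuracy modes of compute_accuracy concrete, a small worked example with made-up strings:

from vietocr.tool.utils import compute_accuracy

gt = ["hà nội", "việt nam"]
pred = ["hà nội", "việt man"]

# full_sequence: only the first pair matches exactly -> 1 / 2 = 0.5
print(compute_accuracy(gt, pred, mode="full_sequence"))

# per_char: 6/6 correct characters in the first pair, 6/8 in the second
# (two characters swapped) -> (1.0 + 0.75) / 2 = 0.875
print(compute_accuracy(gt, pred, mode="per_char"))
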
/vietocr/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from vietocr.model.trainer import Trainer
4 | from vietocr.tool.config import Cfg
5 |
6 |
7 | def main():
8 | parser = argparse.ArgumentParser()
9 |     parser.add_argument("--config", required=True, help="path to the yaml config file (see the sample configs in config/)")
10 |     parser.add_argument("--checkpoint", required=False, help="path to a checkpoint to resume training from")
11 |
12 | args = parser.parse_args()
13 | config = Cfg.load_config_from_file(args.config)
14 |
15 | trainer = Trainer(config)
16 |
17 | if args.checkpoint:
18 | trainer.load_checkpoint(args.checkpoint)
19 |
20 | trainer.train()
21 |
22 |
23 | if __name__ == "__main__":
24 | main()
25 |
--------------------------------------------------------------------------------
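
The entry point above can also be driven programmatically; a minimal sketch, with an illustrative config path and a hypothetical checkpoint path:

from vietocr.model.trainer import Trainer
from vietocr.tool.config import Cfg

config = Cfg.load_config_from_file("config/vgg-transformer.yml")  # illustrative path

trainer = Trainer(config)
# trainer.load_checkpoint("./weights/transformerocr.pth")  # hypothetical checkpoint path
trainer.train()
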