├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug-report.md
│   │   ├── feature-request.md
│   │   └── write-document.md
│   └── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── Attention.png
│   ├── SATRN.png
│   ├── SATRN_feedforward.png
│   ├── competition-overview.png
│   ├── demo.png
│   └── logo.png
├── checkpoint.py
├── configs
│   ├── Attention.yaml
│   └── SATRN.yaml
├── custom_augment.py
├── data_tools
│   ├── download.sh
│   ├── extract_tokens.py
│   ├── make_dataset.py
│   ├── parse_upstage.py
│   └── train_test_split.py
├── dataset.py
├── flags.py
├── inference.py
├── metrics.py
├── networks
│   ├── Attention.py
│   ├── README.md
│   ├── SATRN.py
│   ├── loss.py
│   └── spatial_transformation.py
├── pre_processing.py
├── requirements.txt
├── scheduler.py
├── train.py
├── transform.py
├── utils.py
└── vedastr_cstr
    ├── LICENSE
    ├── README.md
    ├── configs
    │   └── cstr.py
    ├── requirements.txt
    ├── tools
    │   ├── deploy
    │   │   ├── benchmark.py
    │   │   ├── export.py
    │   │   └── utils
    │   │       ├── __init__.py
    │   │       └── common.py
    │   ├── dist_test.sh
    │   ├── dist_train.sh
    │   ├── inference.py
    │   ├── test.py
    │   └── train.py
    └── vedastr
        ├── __init__.py
        ├── converter
        │   ├── __init__.py
        │   ├── attn_converter.py
        │   ├── base_convert.py
        │   ├── builder.py
        │   ├── ctc_converter.py
        │   ├── custom_converter.py
        │   ├── fc_converter.py
        │   └── registry.py
        ├── criteria
        │   ├── __init__.py
        │   ├── builder.py
        │   ├── cross_entropy_loss.py
        │   ├── ctc_loss.py
        │   ├── label_smooth_cross_entropy_loss.py
        │   └── registry.py
        ├── dataloaders
        │   ├── __init__.py
        │   ├── builder.py
        │   ├── registry.py
        │   └── samplers
        │       ├── __init__.py
        │       ├── balance_sampler.py
        │       ├── builder.py
        │       ├── default_sampler.py
        │       ├── dist_balance_sampler.py
        │       ├── dist_default_sampler.py
        │       └── registry.py
        ├── datasets
        │   ├── __init__.py
        │   ├── base.py
        │   ├── builder.py
        │   ├── concat_dataset.py
        │   ├── fold_dataset.py
        │   ├── lmdb_dataset.py
        │   ├── paste_dataset.py
        │   ├── registry.py
        │   └── txt_datasets.py
        ├── logger
        │   ├── __init__.py
        │   └── builder.py
        ├── lr_schedulers
        │   ├── __init__.py
        │   ├── base.py
        │   ├── builder.py
        │   ├── constant_lr.py
        │   ├── cosine_lr.py
        │   ├── exponential_lr.py
        │   ├── poly_lr.py
        │   ├── registry.py
        │   └── step_lr.py
        ├── metrics
        │   ├── __init__.py
        │   ├── accuracy.py
        │   ├── builder.py
        │   └── registry.py
        ├── models
        │   ├── __init__.py
        │   ├── bodies
        │   │   ├── __init__.py
        │   │   ├── body.py
        │   │   ├── builder.py
        │   │   ├── component.py
        │   │   ├── feature_extractors
        │   │   │   ├── __init__.py
        │   │   │   ├── builder.py
        │   │   │   ├── decoders
        │   │   │   │   ├── __init__.py
        │   │   │   │   ├── bricks
        │   │   │   │   │   ├── __init__.py
        │   │   │   │   │   ├── bricks.py
        │   │   │   │   │   ├── builder.py
        │   │   │   │   │   ├── pva.py
        │   │   │   │   │   └── registry.py
        │   │   │   │   ├── builder.py
        │   │   │   │   ├── gfpn.py
        │   │   │   │   └── registry.py
        │   │   │   └── encoders
        │   │   │       ├── __init__.py
        │   │   │       ├── backbones
        │   │   │       │   ├── __init__.py
        │   │   │       │   ├── builder.py
        │   │   │       │   ├── general_backbone.py
        │   │   │       │   ├── registry.py
        │   │   │       │   └── resnet.py
        │   │   │       ├── builder.py
        │   │   │       └── enhance_modules
        │   │   │           ├── __init__.py
        │   │   │           ├── aspp.py
        │   │   │           ├── builder.py
        │   │   │           ├── ppm.py
        │   │   │           └── registry.py
        │   │   ├── rectificators
        │   │   │   ├── __init__.py
        │   │   │   ├── builder.py
        │   │   │   ├── registry.py
        │   │   │   ├── spin.py
        │   │   │   ├── sspin.py
        │   │   │   └── tps_stn.py
        │   │   ├── registry.py
        │   │   └── sequences
        │   │       ├── __init__.py
        │   │       ├── builder.py
        │   │       ├── registry.py
        │   │       ├── rnn
        │   │       │   ├── __init__.py
        │   │       │   ├── decoder.py
        │   │       │   └── encoder.py
        │   │       └── transformer
        │   │           ├── __init__.py
        │   │           ├── decoder.py
        │   │           ├── encoder.py
        │   │           ├── position_encoder
        │   │           │   ├── __init__.py
        │   │           │   ├── adaptive_2d_encoder.py
        │   │           │   ├── builder.py
        │   │           │   ├── encoder.py
        │   │           │   ├── registry.py
        │   │           │   └── utils.py
        │   │           └── unit
        │   │               ├── __init__.py
        │   │               ├── attention
        │   │               │   ├── __init__.py
        │   │               │   ├── builder.py
        │   │               │   ├── multihead_attention.py
        │   │               │   └── registry.py
        │   │               ├── builder.py
        │   │               ├── decoder.py
        │   │               ├── encoder.py
        │   │               ├── feedforward
        │   │               │   ├── __init__.py
        │   │               │   ├── builder.py
        │   │               │   ├── feedforward.py
        │   │               │   └── registry.py
        │   │               └── registry.py
        │   ├── builder.py
        │   ├── heads
        │   │   ├── __init__.py
        │   │   ├── att_head.py
        │   │   ├── builder.py
        │   │   ├── conv_head.py
        │   │   ├── ctc_head.py
        │   │   ├── fc_head.py
        │   │   ├── head.py
        │   │   ├── multi_head.py
        │   │   ├── registry.py
        │   │   └── transformer_head.py
        │   ├── model.py
        │   ├── registry.py
        │   ├── utils
        │   │   ├── __init__.py
        │   │   ├── builder.py
        │   │   ├── cbam.py
        │   │   ├── conv_module.py
        │   │   ├── fc_module.py
        │   │   ├── non_local.py
        │   │   ├── norm.py
        │   │   ├── registry.py
        │   │   ├── residual_module.py
        │   │   ├── squeeze_excitation_module.py
        │   │   └── upsample.py
        │   └── weight_init.py
        ├── optimizers
        │   ├── __init__.py
        │   └── builder.py
        ├── runners
        │   ├── __init__.py
        │   ├── base.py
        │   ├── inference_runner.py
        │   ├── test_runner.py
        │   └── train_runner.py
        ├── transforms
        │   ├── __init__.py
        │   ├── builder.py
        │   ├── registry.py
        │   └── transforms.py
        └── utils
            ├── __init__.py
            ├── checkpoint.py
            ├── common.py
            ├── config.py
            ├── dist_utils.py
            ├── misc.py
            ├── path.py
            └── registry.py
/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug Report
3 | about: Report a bug
4 | title: "[BUG] "
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### 🐞 Describe the bug
11 |
12 | ### 📷 Screenshots
13 |
14 | ### 📄 Additional context
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature Request
3 | about: Suggest a new feature
4 | title: "[FEAT] "
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### 😥 Explain the Problem
11 |
12 | ### ✨ Describe A New Feature
13 |
14 | ### 📄 Additional Context
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/write-document.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Write Document
3 | about: Documentation work
4 | title: "[DOC] "
5 | labels: documentation
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### ✅ Check the Type
11 |
12 | - [ ] Comments to explain Code
13 | - [ ] Documentation to explain Project
14 |
15 | ### 📝 What to Document
16 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## ✨ Types of Changes
2 |
3 | - [ ] Bugfix
4 | - [ ] New Feature
5 | - [ ] Documentation
6 |
7 | ## ✅ Checklist
8 |
9 | - [ ] Enough documentation in Pull Request or README.md
10 | - [ ] Add comments to explain code
11 | - [ ] Follow the Python PEP 8 style guide
12 | - [ ] Check if code works well
13 |
14 | ## 📝 Proposed Changes
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/linux,python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,python
4 |
5 | ### Linux ###
6 | *~
7 |
8 | # temporary files which can be created if a process still has a handle open of a deleted file
9 | .fuse_hidden*
10 |
11 | # KDE directory preferences
12 | .directory
13 |
14 | # Linux trash folder which might appear on any partition or disk
15 | .Trash-*
16 |
17 | # .nfs files are created when an open file is removed but is still being accessed
18 | .nfs*
19 |
20 | ### Python ###
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | pip-wheel-metadata/
42 | share/python-wheels/
43 | *.egg-info/
44 | .installed.cfg
45 | *.egg
46 | MANIFEST
47 |
48 | # PyInstaller
49 | # Usually these files are written by a python script from a template
50 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
51 | *.manifest
52 | *.spec
53 |
54 | # Installer logs
55 | pip-log.txt
56 | pip-delete-this-directory.txt
57 |
58 | # Unit test / coverage reports
59 | htmlcov/
60 | .tox/
61 | .nox/
62 | .coverage
63 | .coverage.*
64 | .cache
65 | nosetests.xml
66 | coverage.xml
67 | *.cover
68 | *.py,cover
69 | .hypothesis/
70 | .pytest_cache/
71 | pytestdebug.log
72 |
73 | # Translations
74 | *.mo
75 | *.pot
76 |
77 | # Django stuff:
78 | *.log
79 | local_settings.py
80 | db.sqlite3
81 | db.sqlite3-journal
82 |
83 | # Flask stuff:
84 | instance/
85 | .webassets-cache
86 |
87 | # Scrapy stuff:
88 | .scrapy
89 |
90 | # Sphinx documentation
91 | docs/_build/
92 | doc/_build/
93 |
94 | # PyBuilder
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # IPython
101 | profile_default/
102 | ipython_config.py
103 |
104 | # pyenv
105 | .python-version
106 |
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock
113 |
114 | # poetry
115 | #poetry.lock
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | # .env
129 | .env/
130 | .venv/
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 | pythonenv*
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # operating system-related files
160 | # file properties cache/storage on macOS
161 | *.DS_Store
162 | # thumbnail cache on Windows
163 | Thumbs.db
164 |
165 | # profiling data
166 | .prof
167 |
168 |
169 | # End of https://www.toptal.com/developers/gitignore/api/linux,python
170 |
171 | #log
172 | log/
173 | wandb/
174 | .vscode/
175 | .idea/
176 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Team DKT
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/assets/Attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/Attention.png
--------------------------------------------------------------------------------
/assets/SATRN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/SATRN.png
--------------------------------------------------------------------------------
/assets/SATRN_feedforward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/SATRN_feedforward.png
--------------------------------------------------------------------------------
/assets/competition-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/competition-overview.png
--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/demo.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/logo.png
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from tensorboardX import SummaryWriter
4 |
5 | use_cuda = torch.cuda.is_available()
6 |
7 | default_checkpoint = {
8 | "epoch": 0,
9 | "train_losses": [],
10 | "train_symbol_accuracy": [],
11 | "train_sentence_accuracy": [],
12 | "train_wer": [],
13 | "train_score": [],
14 | "validation_losses": [],
15 | "validation_symbol_accuracy": [],
16 | "validation_sentence_accuracy": [],
17 | "validation_wer": [],
18 | "validation_score": [],
19 | "lr": [],
20 | "grad_norm": [],
21 | "model": {},
22 | "configs":{},
23 | "token_to_id":{},
24 | "id_to_token":{},
25 | }
26 |
27 |
28 | def save_checkpoint(checkpoint, dir="./checkpoints", prefix=""):
29 | """ Saving check point
30 |
31 | Args:
32 | checkpoint(dict) : Checkpoint to save
33 | dir(str) : Path to save the checkpoint
34 | prefix(str) : Path of location of dir
35 | """
36 | # Padded to 4 digits because of lexical sorting of numbers.
37 | # e.g. 0009.pth
38 | filename = "{num:0>4}.pth".format(num=checkpoint["epoch"])
39 | if not os.path.exists(os.path.join(prefix, dir)):
40 | os.makedirs(os.path.join(prefix, dir))
41 | torch.save(checkpoint, os.path.join(prefix, dir, filename))
42 |
43 |
44 | def load_checkpoint(path, cuda=use_cuda):
45 | """ Load check point
46 |
47 | Args:
48 | path(str) : Path checkpoint located
49 | cuda : Whether use cuda or not [Default: use_cuda]
50 | Returns
51 | Loaded checkpoints
52 | """
53 | if cuda:
54 | return torch.load(path)
55 | else:
56 | # Load GPU model on CPU
57 | return torch.load(path, map_location=lambda storage, loc: storage)
58 |
59 |
60 | def init_tensorboard(name="", base_dir="./tensorboard"):
61 | """Init tensorboard
62 | Args:
63 | name(str) : name of tensorboard
64 | base_dir(str): path of tesnorboard
65 | """
66 | return SummaryWriter(os.path.join(name, base_dir))
67 |
68 |
69 | def write_tensorboard(
70 | writer,
71 | epoch,
72 | grad_norm,
73 | train_loss,
74 | train_symbol_accuracy,
75 | train_sentence_accuracy,
76 | train_wer,
77 | train_score,
78 | validation_loss,
79 | validation_symbol_accuracy,
80 | validation_sentence_accuracy,
81 | validation_wer,
82 | validation_score,
83 | model,
84 | ):
85 | writer.add_scalar("train_loss", train_loss, epoch)
86 | writer.add_scalar("train_symbol_accuracy", train_symbol_accuracy, epoch)
87 | writer.add_scalar("train_sentence_accuracy",train_sentence_accuracy,epoch)
88 | writer.add_scalar("train_wer", train_wer, epoch)
89 | writer.add_scalar("train_score", train_score, epoch)
90 | writer.add_scalar("validation_loss", validation_loss, epoch)
91 | writer.add_scalar("validation_symbol_accuracy", validation_symbol_accuracy, epoch)
92 | writer.add_scalar("validation_sentence_accuracy",validation_sentence_accuracy,epoch)
93 | writer.add_scalar("validation_wer",validation_wer,epoch)
94 | writer.add_scalar("validation_score", validation_score, epoch)
95 | writer.add_scalar("grad_norm", grad_norm, epoch)
96 |
97 | for name, param in model.encoder.named_parameters():
98 | writer.add_histogram(
99 | "encoder/{}".format(name), param.detach().cpu().numpy(), epoch
100 | )
101 | if param.grad is not None:
102 | writer.add_histogram(
103 | "encoder/{}/grad".format(name), param.grad.detach().cpu().numpy(), epoch
104 | )
105 |
106 | for name, param in model.decoder.named_parameters():
107 | writer.add_histogram(
108 | "decoder/{}".format(name), param.detach().cpu().numpy(), epoch
109 | )
110 | if param.grad is not None:
111 | writer.add_histogram(
112 | "decoder/{}/grad".format(name), param.grad.detach().cpu().numpy(), epoch
113 | )
114 |
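
# Usage sketch (illustrative): the dummy nn.Linear below stands in for a real
# encoder/decoder model; the paths and the epoch number are examples only.
if __name__ == "__main__":
    import torch.nn as nn

    dummy_model = nn.Linear(4, 4)
    ckpt = dict(default_checkpoint)
    ckpt["epoch"] = 9
    ckpt["model"] = dummy_model.state_dict()
    save_checkpoint(ckpt, dir="checkpoints", prefix="./log/example")

    restored = load_checkpoint("./log/example/checkpoints/0009.pth", cuda=False)
    dummy_model.load_state_dict(restored["model"])
    print("resumed from epoch", restored["epoch"])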
--------------------------------------------------------------------------------
/configs/Attention.yaml:
--------------------------------------------------------------------------------
1 | network: "Attention"
2 | input_size:
3 | height: 128
4 | width: 128
5 | SATRN:
6 | encoder:
7 | hidden_dim: 300
8 | filter_dim: 600
9 | layer_num: 6
10 | head_num: 8
11 | decoder:
12 | src_dim: 300
13 | hidden_dim: 128
14 | filter_dim: 512
15 | layer_num: 3
16 | head_num: 8
17 | Attention:
18 | src_dim: 512
19 | hidden_dim: 128
20 | embedding_dim: 128
21 | layer_num: 1
22 | cell_type: "LSTM"
23 | checkpoint: ""
24 | prefix: "./log/attention_50"
25 |
26 | data:
27 | train:
28 | - "/opt/ml/input/data/train_dataset/gt.txt"
29 | test:
30 | - ""
31 | token_paths:
32 | - "/opt/ml/input/data/train_dataset/tokens.txt" # 241 tokens
33 | dataset_proportions: # proportion of data to take from train (not test)
34 | - 1.0
35 | random_split: True # if True, random split from train files
36 | test_proportions: 0.2 # only if random_split is True
37 | crop: True
38 | rgb: 1 # 3 for color, 1 for greyscale
39 |
40 | batch_size: 96
41 | num_workers: 8
42 | num_epochs: 50
43 | print_epochs: 1
44 | dropout_rate: 0.1
45 | teacher_forcing_ratio: 0.5
46 | max_grad_norm: 2.0
47 | seed: 1234
48 | optimizer:
49 | optimizer: 'Adam' # Adam, Adadelta
50 | lr: 5e-4 # 1e-4
51 | weight_decay: 1e-4
52 | is_cycle: True
53 |
54 | patience: -1 # -1 for off
55 | save_best_only: False
56 |
57 | wandb:
58 | wandb: True
59 | run_name: ""
--------------------------------------------------------------------------------
/configs/SATRN.yaml:
--------------------------------------------------------------------------------
1 | network: SATRN
2 | input_size:
3 | height: 48
4 | width: 192
5 | SATRN:
6 | encoder:
7 | hidden_dim: 300
8 | filter_dim: 1200
9 | layer_num: 6
10 | head_num: 8
11 |
12 | shallower_cnn: True # shallow CNN
13 | adaptive_gate: True # A2DPE
14 | conv_ff: True # locality-aware feedforward
15 | separable_ff: True # only if conv_ff is True
16 | decoder:
17 | src_dim: 300
18 | hidden_dim: 300
19 | filter_dim: 1200
20 | layer_num: 3
21 | head_num: 8
22 |
23 | checkpoint: ""
24 | prefix: "./log/satrn"
25 |
26 | data:
27 | train:
28 | - "/opt/ml/input/data/train_dataset/gt.txt"
29 | # - /opt/ml/input/data/train_dataset/custom_train.txt # for experiments
30 | test:
31 | -
32 | # - /opt/ml/input/data/train_dataset/custom_test.txt # for experiments
33 | token_paths:
34 | - "/opt/ml/input/data/train_dataset/tokens.txt" # 241 tokens
35 | dataset_proportions: # proportion of data to take from train (not test)
36 | - 1.0
37 | random_split: True # if True, random split from train files
38 | test_proportions: 0.2 # only if random_split is True
39 | crop: True
40 | rgb: 1 # 3 for color, 1 for greyscale
41 |
42 | batch_size: 16
43 | num_workers: 8
44 | num_epochs: 200
45 | print_epochs: 1
46 | dropout_rate: 0.1
47 | teacher_forcing_ratio: 0.5
48 | teacher_forcing_damp: 5e-3 # 0 to turn off
49 | max_grad_norm: 2.0
50 | seed: 1234
51 | optimizer:
52 | optimizer: AdamP # Adam, Adadelta
53 | lr: 5e-4 # 1e-4
54 | weight_decay: 1e-4
55 | selective_weight_decay: True
56 | is_cycle: True
57 | label_smoothing: 0.2 # 0 to off
58 |
59 | patience: 30 # -1 for off
60 | save_best_only: True
61 |
62 | fp16: True
63 |
64 | wandb:
65 | wandb: True
66 | run_name: ^____________^
--------------------------------------------------------------------------------
/data_tools/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | wget -P /opt/ml/input/data/ https://prod-aistages-public.s3.ap-northeast-2.amazonaws.com/app/Competitions/000043/data/train_dataset.zip
3 | cd /opt/ml/input/data
4 | unzip train_dataset.zip
5 | cd ~
6 |
--------------------------------------------------------------------------------
/data_tools/extract_tokens.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import sys
4 |
5 |
6 | def parse_symbols(truth):
7 | """Returns the unique tokens of the groundtrurh.
8 |
9 | Args:
10 | truth(string) : groundtruth
11 |
12 | Returns:
13 | unique_symbols(set): unique_symbols
14 | """
15 | unique_symbols = set(truth.split())
16 | return unique_symbols
17 |
18 |
19 | def create_tokens(groundtruth, output="tokens.txt"):
20 | """Save a unique tokens file for the ground trurh.
21 |
22 | Args:
23 | groundtruth (text file) : groundtruth file
24 | output (str, optional): output filename. Defaults to "tokens.txt".
25 | """
26 | with open(groundtruth, "r") as fd:
27 | data = fd.read()
28 |
29 | unique_symbols = set()
30 | data = data.split("\n")
31 | data = [x.split("\t") for x in data]
32 | for _, truth in data:
33 | truth_symbols = parse_symbols(truth)
34 | unique_symbols = unique_symbols.union(truth_symbols)
35 |
36 | symbols = list(unique_symbols)
37 | symbols.sort()
38 | with open(output, "w") as output_fd:
39 | writer = csv.writer(output_fd, delimiter="\n")
40 | writer.writerow(symbols)
41 |
42 |
43 | if __name__ == "__main__":
44 | """
45 | extract_tokens path/to/groundtruth.tsv [-o OUTPUT]
46 | """
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument(
49 | "-o",
50 | "--output",
51 | dest="output",
52 | default="tokens.txt",
53 | help="Output path of the tokens text file",
54 | )
55 | parser.add_argument("groundtruth", nargs=1, help="Ground truth TXT file")
56 | args = parser.parse_args()
57 | create_tokens(args.groundtruth[0], args.output)
58 |
--------------------------------------------------------------------------------
/data_tools/train_test_split.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import random
5 |
6 | test_percent = 0.2
7 | output_dir = "gt-split"
8 |
9 |
10 | # Split the ground truth into train, test sets
11 | def split_gt(groundtruth, test_percent=0.2, data_num=None):
12 | """Split the ground truth into train, test sets
13 |
14 | Args:
15 | groundtruth (text file) : ground truth file
16 | test_percent (float) : proportion of the dataset to include in the test split. Defaults to 0.2.
17 | data_num (list[int]) : absolute numbers of train and test samples. Defaults to None.
18 |
19 | Returns:
20 | train dataset
21 | test dataset
22 | """
23 | with open(groundtruth, "r") as fd:
24 | data = fd.read()
25 | data = data.split('\n')
26 | data = [x.split('\t') for x in data]
27 | random.shuffle(data)
28 | if data_num:
29 | assert sum(data_num) < len(data)
30 | return data[:data_num[0]], data[data_num[0]:data_num[0] + data_num[1]]
31 | test_len = round(len(data) * test_percent)
32 | return data[test_len:], data[:test_len] # train, test
33 |
34 |
35 | def write_tsv(data, path):
36 | with open(path, "w") as fd:
37 | writer = csv.writer(fd, delimiter="\t")
38 | writer.writerows(data)
39 |
40 |
41 | def parse_args():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument(
44 | "-p",
45 | "--test-percent",
46 | dest="test_percent",
47 | default=test_percent,
48 | type=float,
49 | help="Percent of data to use for test [Default: {}]".format(test_percent)
50 | )
51 | parser.add_argument(
52 | "-n",
53 | "--data_num",
54 | nargs=2,
55 | type=int,
56 | help="Number of train data and test data",
57 | )
58 | parser.add_argument(
59 | "-i",
60 | "--input",
61 | dest="input",
62 | required=True,
63 | type=str,
64 | help="Path to input ground truth file",
65 | )
66 | parser.add_argument(
67 | "-o",
68 | "--output-dir",
69 | dest="output_dir",
70 | default=output_dir,
71 | type=str,
72 | help="Directory to save the split ground truth files",
73 | )
74 | return parser.parse_args()
75 |
76 |
77 | if __name__ == "__main__":
78 | options = parse_args()
79 | train_gt, test_gt = split_gt(options.input, options.test_percent, options.data_num)
80 | if not os.path.exists(options.output_dir):
81 | os.makedirs(options.output_dir)
82 | write_tsv(train_gt, os.path.join(options.output_dir, "train.txt"))
83 | write_tsv(test_gt, os.path.join(options.output_dir, "test.txt"))
84 |
--------------------------------------------------------------------------------
/flags.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code from clovaai/SATRN
3 | """
4 | import os
5 | import yaml
6 | import collections
7 |
8 |
9 | def dict_to_namedtuple(d):
10 | """ Convert dictionary to named tuple.
11 | """
12 | FLAGSTuple = collections.namedtuple('FLAGS', sorted(d.keys()))
13 |
14 | for k, v in d.items():
15 |
16 | if k == 'prefix':
17 | v = os.path.join('./', v)
18 |
19 | if type(v) is dict:
20 | d[k] = dict_to_namedtuple(v)
21 |
22 | elif type(v) is str:
23 | try:
24 | d[k] = eval(v)
25 | except:
26 | d[k] = v
27 |
28 | nt = FLAGSTuple(**d)
29 |
30 | return nt
31 |
32 |
33 | class Flags:
34 | """ Flags object.
35 | """
36 |
37 | def __init__(self, config_file):
38 | try:
39 | with open(config_file, 'r') as f:
40 | d = yaml.safe_load(f)
41 | except:
42 | d = config_file
43 |
44 |
45 | self.flags = dict_to_namedtuple(d)
46 |
47 | def get(self):
48 | return self.flags
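
# Usage sketch: Flags accepts either a YAML path (e.g. configs/SATRN.yaml) or an
# already-parsed dict, as inference.py does with checkpoint["configs"]. Nested
# dicts become nested namedtuples, so values are read with attribute access.
if __name__ == "__main__":
    options = Flags({
        "network": "SATRN",
        "input_size": {"height": 48, "width": 192},
        "prefix": "./log/satrn",
    }).get()
    print(options.network, options.input_size.height, options.input_size.width)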
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from train import id_to_string
4 | from checkpoint import load_checkpoint
5 | from torchvision import transforms
6 | from dataset import LoadEvalDataset, collate_eval_batch
7 | from flags import Flags
8 | from utils import get_network
9 | import csv
10 | from torch.utils.data import DataLoader
11 | import argparse
12 | import random
13 | from tqdm import tqdm
14 |
15 |
16 | def main(parser):
17 | """Inference code
18 | """
19 | is_cuda = torch.cuda.is_available()
20 | # load pretrained model checkpoint
21 | checkpoint = load_checkpoint(parser.checkpoint, cuda=is_cuda)
22 | options = Flags(checkpoint["configs"]).get()
23 | torch.manual_seed(options.seed)
24 | random.seed(options.seed)
25 | torch.backends.cudnn.deterministic = True
26 | torch.backends.cudnn.benchmark = False
27 |
28 | hardware = "cuda" if is_cuda else "cpu"
29 | device = torch.device(hardware)
30 | print("--------------------------------")
31 | print("Running {} on device {}\n".format(options.network, device))
32 |
33 | model_checkpoint = checkpoint["model"]
34 | if model_checkpoint:
35 | print(
36 | "[+] Checkpoint\n",
37 | "Resuming from epoch : {}\n".format(checkpoint["epoch"]),
38 | )
39 | print(options.input_size.height)
40 |
41 | # transform to be applied on a sample.
42 | transformed = transforms.Compose(
43 | [
44 | transforms.Resize((options.input_size.height, options.input_size.width)),
45 | transforms.ToTensor(),
46 | ]
47 | )
48 |
49 | dummy_gt = "\sin " * parser.max_sequence # set maximum inference sequence
50 | # make dataset from test folder
51 | root = os.path.join(os.path.dirname(parser.file_path), "images")
52 | with open(parser.file_path, "r") as fd:
53 | reader = csv.reader(fd, delimiter="\t")
54 | data = list(reader)
55 | test_data = [[os.path.join(root, x[0]), x[0], dummy_gt] for x in data]
56 | test_dataset = LoadEvalDataset(
57 | test_data, checkpoint["token_to_id"], checkpoint["id_to_token"], crop=False, transform=transformed,
58 | rgb=options.data.rgb
59 | )
60 | test_data_loader = DataLoader(
61 | test_dataset,
62 | batch_size=parser.batch_size,
63 | shuffle=False,
64 | num_workers=options.num_workers,
65 | collate_fn=collate_eval_batch,
66 | )
67 |
68 | print(
69 | "[+] Data\n",
70 | "The number of test samples : {}\n".format(len(test_dataset)),
71 | )
72 |
73 | model = get_network(
74 | options.network,
75 | options,
76 | model_checkpoint,
77 | device,
78 | test_dataset,
79 | )
80 | model.eval()
81 | results = []
82 | for d in tqdm(test_data_loader):
83 | input = d["image"].to(device)
84 | expected = d["truth"]["encoded"].to(device)
85 |
86 | output = model(input, expected, False, 0.0)
87 | decoded_values = output.transpose(1, 2)
88 | _, sequence = torch.topk(decoded_values, 1, dim=1)
89 | sequence = sequence.squeeze(1)
90 | sequence_str = id_to_string(sequence, test_data_loader, do_eval=1)
91 | for path, predicted in zip(d["file_path"], sequence_str):
92 | results.append((path, predicted))
93 | # save inference results as csv file
94 | os.makedirs(parser.output_dir, exist_ok=True)
95 | with open(os.path.join(parser.output_dir, "output.csv"), "w") as w:
96 | for path, predicted in results:
97 | w.write(path + "\t" + predicted + "\n")
98 |
99 |
100 | if __name__ == "__main__":
101 | parser = argparse.ArgumentParser()
102 | parser.add_argument(
103 | "--checkpoint",
104 | dest="checkpoint",
105 | default="./log/satrn/checkpoints/0015.pth",
106 | type=str,
107 | help="Path of checkpoint file",
108 | )
109 | parser.add_argument(
110 | "--max_sequence",
111 | dest="max_sequence",
112 | default=230,
113 | type=int,
114 | help="maximum sequence length for inference",
115 | )
116 | parser.add_argument(
117 | "--batch_size",
118 | dest="batch_size",
119 | default=8,
120 | type=int,
121 | help="batch size when doing inference",
122 | )
123 |
124 | eval_dir = os.environ.get('SM_CHANNEL_EVAL', '/opt/ml/input/data/')
125 | file_path = os.path.join(eval_dir, 'eval_dataset/input.txt')
126 | parser.add_argument(
127 | "--file_path",
128 | dest="file_path",
129 | default=file_path,
130 | type=str,
131 | help="file path when doing inference",
132 | )
133 |
134 | output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', 'submit')
135 | parser.add_argument(
136 | "--output_dir",
137 | dest="output_dir",
138 | default=output_dir,
139 | type=str,
140 | help="output directory",
141 | )
142 |
143 | parser = parser.parse_args()
144 | main(parser)
145 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import editdistance
2 | import numpy as np
3 |
4 | def word_error_rate(predicted_outputs, ground_truths):
5 | """ Estimate Word_error_rate.
6 | Args:
7 | predicted_outputs(list) : result of model prediction
8 | ground_truths(list) : ground truth
9 |
10 | Returns:
11 | word error rate(float) : Word error rate estimated by edit distance.
12 | """
13 | sum_wer=0.0
14 | for output,ground_truth in zip(predicted_outputs,ground_truths):
15 | output=output.split(" ")
16 | ground_truth=ground_truth.split(" ")
17 | distance = editdistance.eval(output, ground_truth)
18 | length = max(len(output),len(ground_truth))
19 | sum_wer+=(distance/length)
20 | return sum_wer/len(predicted_outputs)
21 |
22 |
23 | def sentence_acc(predicted_outputs, ground_truths):
24 | """ Estimate sentence_acc.
25 | Args:
26 | predicted_outputs(list) : result of model prediction
27 | ground_truths(list) : ground truth
28 |
29 | Returns:
30 | sentence_acc(float) : Accuracy between predicted_outputs and ground_truths
31 | """
32 | correct_sentences=0
33 | for output,ground_truth in zip(predicted_outputs,ground_truths):
34 | if np.array_equal(output,ground_truth):
35 | correct_sentences+=1
36 | return correct_sentences/len(predicted_outputs)
37 |
38 |
39 | def get_worst_wer_img_path(img_path_list, predicted_outputs, ground_truths):
40 | """ Return Information of max word error rate Image
41 | Args:
42 | img_path_list(list) : list of image path
43 | predicted_outputs(list) : result of model prediction
44 | ground_truths(list) : ground truth
45 |
46 | Returns:
47 | image path(str) : Image path of worst error rate
48 | word error rate(float) : max word error rate
49 | ground truth(str) : Ground truth of max word error rate image
50 | predicted_output(str) : Prediction of model
51 | """
52 | max_wer_ind = 0
53 | max_wer = 0
54 |
55 | i = 0
56 | for output, ground_truth in zip(predicted_outputs,ground_truths):
57 | output=output.split(" ")
58 | ground_truth=ground_truth.split(" ")
59 |
60 | distance = editdistance.eval(output, ground_truth)
61 | length = max(len(output), len(ground_truth))
62 | cur_wer = (distance / length)
63 | if max_wer < cur_wer:
64 | max_wer = cur_wer
65 | max_wer_ind = i
66 | i+=1
67 |
68 | return img_path_list[max_wer_ind], max_wer, ground_truths[max_wer_ind], predicted_outputs[max_wer_ind]
69 |
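
# Hand-checked example of the two metrics above (strings are illustrative).
if __name__ == "__main__":
    predictions = [r"\sin x + 1", "a b c"]
    ground_truths = [r"\sin x + 2", "a b c"]
    # per-sample WER: 1 edit / 4 tokens = 0.25 and 0 / 3 = 0.0 -> mean 0.125
    print(word_error_rate(predictions, ground_truths))
    # exact matches: only the second sample -> 1 / 2 = 0.5
    print(sentence_acc(predictions, ground_truths))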
--------------------------------------------------------------------------------
/networks/README.md:
--------------------------------------------------------------------------------
1 | # Supported Models
2 |
3 | ## SATRN
4 | A model that adapts the encoder-decoder structure of the [Transformer](https://arxiv.org/abs/1706.03762) to the STR task. It was proposed in [On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention](https://arxiv.org/abs/1910.04396), and its main features are as follows.
5 |
6 |
7 | *(figure: SATRN architecture, cf. assets/SATRN.png)*
8 |
9 |
10 | ### Shallow CNN
11 | By greatly reducing the depth of the CNN that produces the 2D feature map, the self-attention encoder blocks can capture spatial dependencies better. It is implemented as two stacked `ShallowConvLayer`s and is enabled by setting `SATRN.encoder.shallower_cnn` to `True` in the config file (`DeepCNN300` is used when it is `False`).
12 |
13 | ### Adaptive 2D positional encoding
14 | A scheme for extending the Transformer's positional encoding to 2D. As proposed in the paper above, it is implemented as a weighted sum of sinusoidal positional encodings along the two axes. The 'weights' of this sum come from an adaptive gate that adjusts, per image, how much each direction matters; the gate has a `Linear-ReLU-Linear-sigmoid` structure. In the code, the adaptive gate is implemented as `AdaptiveGate` and the positional encoding as `AdaptivePositionalEncoding2D`. It is enabled by setting `SATRN.encoder.adaptive_gate` to `True` in the config file (`PositionalEncoding2D`, a non-adaptive concatenation, is used when it is `False`).
15 |
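A rough PyTorch sketch of the idea (illustrative names and shapes only; the repository's actual modules are `AdaptiveGate` and `AdaptivePositionalEncoding2D`):

```python
import torch
import torch.nn as nn


def sinusoid_table(length, dim):
    """Standard 1D sinusoidal positional encoding, shape (length, dim)."""
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    i = torch.arange(dim, dtype=torch.float32).unsqueeze(0)
    angle = pos / torch.pow(10000, 2 * (i // 2) / dim)
    table = torch.zeros(length, dim)
    table[:, 0::2] = torch.sin(angle[:, 0::2])
    table[:, 1::2] = torch.cos(angle[:, 1::2])
    return table


class AdaptiveGateSketch(nn.Module):
    """Linear-ReLU-Linear-sigmoid gate: two per-image scales (height, width)."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(inplace=True),
            nn.Linear(dim, 2), nn.Sigmoid(),
        )

    def forward(self, feat):            # feat: (B, C, H, W)
        pooled = feat.mean(dim=(2, 3))  # global average pooling -> (B, C)
        return self.fc(pooled)          # (B, 2): alpha_h, alpha_w


class A2DPESketch(nn.Module):
    """Weighted sum of height- and width-direction sinusoidal encodings."""

    def __init__(self, dim, max_h=64, max_w=256):
        super().__init__()
        self.register_buffer("pe_h", sinusoid_table(max_h, dim))  # (max_h, dim)
        self.register_buffer("pe_w", sinusoid_table(max_w, dim))  # (max_w, dim)
        self.gate = AdaptiveGateSketch(dim)

    def forward(self, feat):                               # feat: (B, C, H, W)
        b, c, h, w = feat.shape
        alpha = self.gate(feat)                            # (B, 2)
        pe_h = self.pe_h[:h].T.unsqueeze(0).unsqueeze(-1)  # (1, C, H, 1)
        pe_w = self.pe_w[:w].T.unsqueeze(0).unsqueeze(2)   # (1, C, 1, W)
        pe = (alpha[:, 0].reshape(b, 1, 1, 1) * pe_h
              + alpha[:, 1].reshape(b, 1, 1, 1) * pe_w)
        return feat + pe
```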
16 | ### Locality-aware feedforward layer
17 | Replaces the Transformer's point-wise feedforward layer with a structure built around 3x3 convolutions so that short-term dependencies are captured more effectively. Setting `SATRN.encoder.conv_ff` to `True` in the config file switches the feedforward layer of each self-attention block to `LocalityAwareFeedforward` (`Feedforward` when `False`). When `SATRN.encoder.separable_ff` is `True` a depthwise-separable convolution is used, and when `False` a full convolution (see the figure below).
18 |
19 | *(figure: locality-aware feedforward variants, cf. assets/SATRN_feedforward.png)*
20 |
21 |
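A rough sketch of the two feedforward variants, assuming the encoder keeps its features as a `(B, C, H, W)` map (the repository's actual module is `LocalityAwareFeedforward`):

```python
import torch.nn as nn


class LocalityAwareFeedforwardSketch(nn.Module):
    """1x1 expand -> 3x3 (full or depthwise-separable) -> 1x1 project."""

    def __init__(self, dim, hidden_dim, separable=False):
        super().__init__()
        if separable:
            # depthwise 3x3 followed by a pointwise 1x1 convolution
            mid = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim),
                nn.Conv2d(hidden_dim, hidden_dim, 1),
            )
        else:
            mid = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)
        self.layers = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1), nn.ReLU(inplace=True),
            mid, nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, dim, 1),
        )

    def forward(self, x):  # x: (B, C, H, W)
        return self.layers(x)
```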
22 | ## Attention
23 | A structure obtained by removing the Bi-LSTM from ASTER, proposed in [ASTER: An Attentional Scene Text Recognizer with Flexible Rectification](https://ieeexplore.ieee.org/document/8395027). It consists of a CNN encoder and an RNN+attention decoder, and the decoder's RNN can be set to either LSTM or GRU (`Attention.cell_type` in the config file).
24 |
25 |
26 | *(figure: Attention model structure, cf. assets/Attention.png)*
27 |
28 |
--------------------------------------------------------------------------------
/networks/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class LabelSmoothingCrossEntropy(nn.Module):
6 | """
7 | Label Smoothing Cross Entropy
8 |
9 | A Cross Entropy with Label Smooothing
10 |
11 | """
12 |
13 | def __init__(self, eps: float = 0.1, reduction='mean', ignore_index=-100):
14 | """
15 | Args:
16 | eps(float) :Rate of Label Smoothing
17 | reduction(str) : The way of reduction [mean, sum]
18 | ignore_index(int) : Index wants to ignore
19 | """
20 |
21 | super(LabelSmoothingCrossEntropy, self).__init__()
22 | self.eps, self.reduction = eps, reduction
23 | self.ignore_index = ignore_index
24 |
25 | def forward(self, output, target, *args):
26 |
27 | output = output.transpose(1,2)
28 |
29 | pred = output.contiguous().view(-1, output.shape[-1])
30 | target = target.to(pred.device).contiguous().view(-1)
31 | c = pred.size()[-1]
32 |
33 | log_preds = F.log_softmax(pred, dim=-1)
34 | ignore_target = target != self.ignore_index
35 | log_preds = log_preds * ignore_target[:, None]
36 |
37 | if self.reduction == 'sum':
38 | loss = -log_preds.sum()
39 | else:
40 | loss = -log_preds.sum(dim=-1)
41 | if self.reduction == 'mean':
42 | loss = loss.mean()
43 |
44 | return (
45 | loss * self.eps / c + (1 - self.eps) *
46 | F.nll_loss(
47 | log_preds,
48 | target,
49 | reduction=self.reduction,
50 | ignore_index=self.ignore_index,
51 | )
52 | )
53 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit_image==0.14.1
2 | opencv_python==3.4.4.19
3 | tqdm==4.28.1
4 | --find-links=https://download.pytorch.org/whl/torch_stable.html
5 | torch==1.7.1+cu101
6 | torchvision==0.8.2+cu101
7 | scipy==1.2.0
8 | numpy==1.15.4
9 | pillow==8.2.0
10 | tensorboardX==1.5
11 | editdistance==0.5.3
12 | python-dotenv==0.17.1
13 | wandb==0.10.30
14 | adamp==0.3.0
--------------------------------------------------------------------------------
/transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms.functional as F
4 |
5 |
6 | class RotateByDistribution(nn.Module):
7 | """ Image Rotation
8 | Rotate Image angle between (-34 ~ 34) or input distribution
9 | """
10 | def __init__(self, distribution=None):
11 | super(RotateByDistribution, self).__init__()
12 | if distribution is None:
13 | self.distribution = torch.distributions.normal.Normal(0, 34)
14 | else:
15 | self.distribution = distribution
16 |
17 | def forward(self, img):
18 | degree = self.distribution.sample().item()
19 | return F.rotate(img, degree)
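

# Usage sketch: compose with torchvision transforms; any distribution object
# exposing .sample() can replace the default Normal(mean 0, std 34).
if __name__ == "__main__":
    import torchvision.transforms as T

    augment = T.Compose([
        T.ToTensor(),              # PIL image -> tensor
        RotateByDistribution(),    # rotation angle in degrees ~ Normal(0, 34)
    ])
    print(augment)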
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from copy import deepcopy
3 | import torch.optim as optim
4 | from adamp import AdamP
5 |
6 | from networks.Attention import Attention
7 | from networks.SATRN import SATRN
8 |
9 |
10 | def get_network(
11 | model_type,
12 | FLAGS,
13 | model_checkpoint,
14 | device,
15 | train_dataset,
16 | ):
17 | """Get network
18 |
19 | Args:
20 | model_type (str): Model name that wants to use.
21 | FLAGS (Flag): Configs of model.
22 | model_checkpoint (dict): model checkpoint.
23 | device (torch.device): Device type to use.
24 | train_dataset (list): train_dataset
25 |
26 | Returns:
27 | model : model
28 | """
29 | model = None
30 |
31 | if model_type == "SATRN":
32 | model = SATRN(FLAGS, train_dataset, model_checkpoint).to(device)
33 | elif model_type == "CRNN":
34 | model = CRNN()
35 | elif model_type == "Attention":
36 | model = Attention(FLAGS, train_dataset, model_checkpoint).to(device)
37 | else:
38 | raise NotImplementedError
39 |
40 | return model
41 |
42 |
43 | def get_optimizer(optimizer, params, lr, weight_decay=None):
44 | """Get Optimizer
45 |
46 | Args:
47 | optimizer (optimizer): optimizer.
48 | params (optimizer.params): optimizer.params
49 | lr (optimizer.lr): optimizer LR
50 | weight_decay (float, optional): weight decay (L2 penalty). Defaults to None.
51 |
52 | Returns:
53 | optimizer: optimizer
54 | """
55 | if optimizer == "AdamP":
56 | optimizer = AdamP(params, lr=lr)
57 | elif optimizer == "Adam":
58 | optimizer = optim.Adam(params, lr=lr)
59 | elif optimizer == "Adadelta":
60 | optimizer = optim.Adadelta(params, lr=lr, weight_decay=weight_decay)
61 | else:
62 | raise NotImplementedError
63 | return optimizer
64 |
65 |
66 | def get_wandb_config(config_file):
67 | """Get Wandb config from config_file
68 |
69 | Args:
70 | config_file (str): config_file path
71 | Returns:
72 | config (dict): original config
73 | """
74 | # load config file
75 | with open(config_file, 'r') as f:
76 | option = yaml.safe_load(f)
77 | config = deepcopy(option)
78 |
79 | # remove all except network
80 | keys = ["checkpoint", "input_size", "data", "optimizer", "wandb", "prefix"]
81 | for key in keys:
82 | del config[key]
83 |
84 | # modify some config key-value
85 | new_config = {
86 | "log_path": option['prefix'],
87 | "dataset_proportions": option['data']['dataset_proportions'],
88 | "test_proportions": option['data']['test_proportions'],
89 | "crop": option['data']['crop'],
90 | "rgb": "grayscale" if option['data']['rgb']==1 else "color",
91 | "input_size": (option['input_size']['height'], option['input_size']['width']),
92 | "optimizer": option['optimizer']['optimizer'],
93 | "learning_rate": option['optimizer']['lr'],
94 | "weight_decay": option['optimizer']['weight_decay'],
95 | "is_cycle": option['optimizer']['is_cycle'],
96 | }
97 |
98 | # merge
99 | config.update(new_config)
100 |
101 | # print log
102 | print("wandb save configs below:\n", list(config.keys()))
103 |
104 | return config
--------------------------------------------------------------------------------
/vedastr_cstr/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | addict
4 | six
5 | torch>=1.6.0
6 | torchvision>=0.7.0
7 | Pillow>=8.2.0
8 | lmdb
9 | nltk
10 | terminaltables
11 | albumentations
12 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/deploy/benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
6 |
7 | import cv2 # noqa 402
8 | from volksdep.benchmark import benchmark # noqa 402
9 |
10 | from tools.deploy.utils import CALIBRATORS, CalibDataset, Metric, MetricDataset # noqa 402
11 | from vedastr.runners import TestRunner # noqa 402
12 | from vedastr.utils import Config # noqa 402
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='Inference')
17 | parser.add_argument('config', type=str, help='config file path')
18 | parser.add_argument('checkpoint', type=str, help='checkpoint file path')
19 | parser.add_argument('image', type=str, help='sample image path')
20 | parser.add_argument(
21 | '--dtypes',
22 | default=('fp32', 'fp16', 'int8'),
23 | nargs='+',
24 | type=str,
25 | choices=['fp32', 'fp16', 'int8'],
26 | help='dtypes for benchmark')
27 | parser.add_argument(
28 | '--iters', default=100, type=int, help='iters for benchmark')
29 | parser.add_argument(
30 | '--calibration_images',
31 | default=None,
32 | type=str,
33 | help='images dir used when int8 in dtypes')
34 | parser.add_argument(
35 | '--calibration_modes',
36 | nargs='+',
37 | default=['entropy', 'entropy_2', 'minmax'],
38 | type=str,
39 | choices=['entropy_2', 'entropy', 'minmax'],
40 | help='calibration modes for benchmark')
41 | args = parser.parse_args()
42 |
43 | return args
44 |
45 |
46 | def main():
47 | args = parse_args()
48 |
49 | cfg_path = args.config
50 | cfg = Config.fromfile(cfg_path)
51 |
52 | test_cfg = cfg['test']
53 | infer_cfg = cfg['inference']
54 | common_cfg = cfg['common']
55 |
56 | runner = TestRunner(test_cfg, infer_cfg, common_cfg)
57 | assert runner.use_gpu, 'Please use gpu for benchmark.'
58 | runner.load_checkpoint(args.checkpoint)
59 |
60 | image = cv2.imread(args.image)
61 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
62 | aug = runner.transform(image=image, label='')
63 | image, dummy_label = aug['image'], aug['label'] # noqa 841
64 | image = image.unsqueeze(0)
65 | input_len = runner.converter.test_encode(1)[0]
66 | model = runner.model
67 | need_text = runner.need_text
68 | if need_text:
69 | shape = tuple(image.shape), tuple(input_len.shape)
70 | else:
71 | shape = tuple(image.shape)
72 |
73 | dtypes = args.dtypes
74 | iters = args.iters
75 | int8_calibrator = None
76 | if args.calibration_images:
77 | calib_dataset = CalibDataset(args.calibration_images, runner.converter,
78 | runner.transform, need_text)
79 | int8_calibrator = [
80 | CALIBRATORS[mode](dataset=calib_dataset)
81 | for mode in args.calibration_modes
82 | ]
83 | dataset = runner.test_dataloader.dataset
84 | dataset = MetricDataset(dataset, runner.converter, need_text)
85 | metric = Metric(runner.metric, runner.converter)
86 | benchmark(
87 | model,
88 | shape,
89 | dtypes=dtypes,
90 | iters=iters,
91 | int8_calibrator=int8_calibrator,
92 | dataset=dataset,
93 | metric=metric)
94 |
95 |
96 | if __name__ == '__main__':
97 | main()
98 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/deploy/export.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
6 |
7 | import cv2 # noqa 402
8 | from volksdep.converters import save, torch2onnx, torch2trt # noqa 402
9 |
10 | from tools.deploy.utils import CALIBRATORS, CalibDataset # noqa 402
11 | from vedastr.runners import InferenceRunner # noqa 402
12 | from vedastr.utils import Config # noqa 402
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='Inference')
17 | parser.add_argument('config', type=str, help='config file path')
18 | parser.add_argument('checkpoint', type=str, help='checkpoint file path')
19 | parser.add_argument('image', type=str, help='sample image path')
20 | parser.add_argument('out', type=str, help='output model file name')
21 | parser.add_argument(
22 | '--onnx', default=False, action='store_true', help='convert to onnx')
23 | parser.add_argument(
24 | '--max_batch_size',
25 | default=1,
26 | type=int,
27 | help='max batch size for trt engine execution')
28 | parser.add_argument(
29 | '--max_workspace_size',
30 | default=1,
31 | type=int,
32 | help='max workspace size for building trt engine')
33 | parser.add_argument(
34 | '--fp16',
35 | default=False,
36 | action='store_true',
37 | help='convert to trt engine with fp16 mode')
38 | parser.add_argument(
39 | '--int8',
40 | default=False,
41 | action='store_true',
42 | help='convert to trt engine with int8 mode')
43 | parser.add_argument(
44 | '--calibration_mode',
45 | default='entropy_2',
46 | type=str,
47 | choices=['entropy_2', 'entropy', 'minmax'])
48 | parser.add_argument(
49 | '--calibration_images',
50 | default=None,
51 | type=str,
52 | help='images dir used when int8 mode is True')
53 | args = parser.parse_args()
54 |
55 | return args
56 |
57 |
58 | def main():
59 | args = parse_args()
60 | out_name = args.out
61 |
62 | cfg_path = args.config
63 | cfg = Config.fromfile(cfg_path)
64 |
65 | infer_cfg = cfg['inference']
66 | common_cfg = cfg.get('common')
67 |
68 | runner = InferenceRunner(infer_cfg, common_cfg)
69 | assert runner.use_gpu, 'Please use valid gpu to export model.'
70 | runner.load_checkpoint(args.checkpoint)
71 |
72 | image = cv2.imread(args.image)
73 |
74 | aug = runner.transform(image=image, label='')
75 | image, label = aug['image'], aug['label'] # noqa 841
76 | image = image.unsqueeze(0).cuda()
77 | dummy_input = (image, runner.converter.test_encode([''])[0])
78 | model = runner.model.cuda().eval()
79 | need_text = runner.need_text
80 | if not need_text:
81 | dummy_input = dummy_input[0]
82 |
83 | if args.onnx:
84 | runner.logger.info('Convert to onnx model')
85 | torch2onnx(model, dummy_input, out_name)
86 | else:
87 | max_batch_size = args.max_batch_size
88 | max_workspace_size = args.max_workspace_size
89 | fp16_mode = args.fp16
90 | int8_mode = args.int8
91 | int8_calibrator = None
92 | if int8_mode:
93 | runner.logger.info('Convert to trt engine with int8')
94 | if args.calibration_images:
95 | runner.logger.info(
96 | 'Use calibration with mode {} and data {}'.format(
97 | args.calibration_mode, args.calibration_images))
98 | dataset = CalibDataset(args.calibration_images,
99 | runner.converter, runner.transform,
100 | need_text)
101 | int8_calibrator = CALIBRATORS[args.calibration_mode](
102 | dataset=dataset)
103 | else:
104 | runner.logger.info('Use default calibration mode and data')
105 | elif fp16_mode:
106 | runner.logger.info('Convert to trt engine with fp16')
107 | else:
108 | runner.logger.info('Convert to trt engine with fp32')
109 | trt_model = torch2trt(
110 | model,
111 | dummy_input,
112 | max_batch_size=max_batch_size,
113 | max_workspace_size=max_workspace_size,
114 | fp16_mode=fp16_mode,
115 | int8_mode=int8_mode,
116 | int8_calibrator=int8_calibrator)
117 | save(trt_model, out_name)
118 | runner.logger.info(
119 | 'Convert successfully, save model to {}'.format(out_name))
120 |
121 |
122 | if __name__ == '__main__':
123 | main()
124 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/deploy/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import CALIBRATORS, CalibDataset, Metric, MetricDataset # noqa 401
2 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/deploy/utils/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | from volksdep.calibrators import (EntropyCalibrator, EntropyCalibrator2,
5 | MinMaxCalibrator)
6 | from volksdep.datasets import Dataset
7 | from volksdep.metrics import Metric as BaseMetric
8 |
9 | CALIBRATORS = {
10 | 'entropy': EntropyCalibrator,
11 | 'entropy_2': EntropyCalibrator2,
12 | 'minmax': MinMaxCalibrator,
13 | }
14 |
15 |
16 | class CalibDataset(Dataset):
17 |
18 | def __init__(self, images_dir, converter, transform=None, need_text=False):
19 | super(CalibDataset, self).__init__()
20 |
21 | self.root = images_dir
22 | self.samples = os.listdir(images_dir)
23 | self.converter = converter
24 | self.transform = transform
25 | self.need_text = need_text
26 |
27 | def __getitem__(self, idx):
28 | image_file = os.path.join(self.root, self.samples[idx])
29 | image = Image.open(image_file)
30 | if self.transform:
31 | image, _ = self.transform(image=image, label='')
32 | label = self.converter.test_encode(1)
33 | if self.need_text:
34 | return image, label
35 | else:
36 | return image
37 |
38 | def __len__(self):
39 | return len(self.samples)
40 |
41 |
42 | class MetricDataset(Dataset):
43 |
44 | def __init__(self, dataset, converter, need_text):
45 | super(MetricDataset, self).__init__()
46 | self.dataset = dataset
47 | self.converter = converter
48 | self.need_text = need_text
49 |
50 | def __getitem__(self, idx):
51 | image, label = self.dataset[idx]
52 | label_input, _, _ = self.converter.test_encode(1)
53 | _, _, label_target = self.converter.train_encode([label])
54 |
55 | if self.need_text:
56 | return (image, label_input[0]), label
57 | else:
58 | return image, label
59 |
60 | def __len__(self):
61 | return len(self.dataset)
62 |
63 |
64 | class Metric(BaseMetric):
65 |
66 | def __init__(self, metric, converter):
67 | self.metric = metric
68 | self.converter = converter
69 |
70 | def decode(self, preds):
71 | indexes = np.argmax(preds, 2)
72 | pred_str = self.converter.decode(indexes)
73 |
74 | return pred_str
75 |
76 | def __call__(self, preds, targets):
77 | self.metric.reset()
78 | pred_str = self.decode(preds)
79 |
80 | self.metric.measure(pred_str, None, targets)
81 | res = self.metric.result
82 |
83 | return ', '.join(['{}: {:.4f}'.format(k, v) for k, v in res.items()])
84 |
85 | def __str__(self):
86 | return self.metric.__class__.__name__.lower()
87 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | if (($# < 3)); then
4 | echo "Uasage: bash tools/dist_test.sh config_file checkpoint gpus_to_use"
5 | exit 1
6 | fi
7 |
8 | CONFIG="$1"
9 | CHECKPOINT="$2"
10 | GPUS="$3"
11 |
12 | IFS=', ' read -r -a gpus <<<"${GPUS}"
13 | NGPUS="${#gpus[@]}"
14 | PORT="$((29400 + RANDOM % 100))"
15 |
16 | export CUDA_VISIBLE_DEVICES=${GPUS}
17 |
18 | PYTHONPATH="$(dirname "$0")/..":${PYTHONPATH} \
19 | python -m torch.distributed.launch \
20 | --nproc_per_node="${NGPUS}" \
21 | --master_port=${PORT} \
22 | "$(dirname "$0")"/test.py \
23 | "$CONFIG" \
24 | "$CHECKPOINT" \
25 | --distribute \
26 | "${@:4}"
27 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | if (($# < 2)); then
4 | echo "Uasage: bash tools/dist_train.sh config_file gpus_to_use"
5 | exit 1
6 | fi
7 | CONFIG="$1"
8 | GPUS="$2"
9 |
10 | IFS=', ' read -r -a gpus <<<"${GPUS}"
11 | NGPUS="${#gpus[@]}"
12 | PORT="$((29400 + RANDOM % 100))"
13 |
14 | export CUDA_VISIBLE_DEVICES=${GPUS}
15 |
16 | PYTHONPATH="$(dirname "$0")/..":${PYTHONPATH} \
17 | python -m torch.distributed.launch \
18 | --nproc_per_node="${NGPUS}" \
19 | --master_port=${PORT} \
20 | "$(dirname "$0")"/train.py \
21 | "$CONFIG" \
22 | --distribute \
23 | "${@:3}"
24 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
6 |
7 | import cv2 # noqa 402
8 |
9 | from vedastr.runners import InferenceRunner # noqa 402
10 | from vedastr.utils import Config # noqa 402
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description='Inference')
15 | parser.add_argument('config', type=str, help='Config file path')
16 | parser.add_argument('checkpoint', type=str, help='Checkpoint file path')
17 | parser.add_argument('image', type=str, help='input image path')
18 | args = parser.parse_args()
19 |
20 | return args
21 |
22 |
23 | def main():
24 | args = parse_args()
25 |
26 | cfg_path = args.config
27 | cfg = Config.fromfile(cfg_path)
28 |
29 | inference_cfg = cfg['inference']
30 | common_cfg = cfg.get('common')
31 |
32 | runner = InferenceRunner(inference_cfg, common_cfg)
33 | runner.load_checkpoint(args.checkpoint)
34 | if os.path.isfile(args.image):
35 | images = [args.image]
36 | else:
37 | images = [
38 | os.path.join(args.image, name) for name in os.listdir(args.image)
39 | ]
40 | for img in images:
41 | image = cv2.imread(img)
42 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
43 | pred_str, probs = runner(image)
44 | runner.logger.info('Text in {} is:\t {} '.format(img, pred_str))
45 |
46 |
47 | if __name__ == '__main__':
48 | main()
49 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
6 |
7 | from vedastr.runners import TestRunner # noqa 402
8 | from vedastr.utils import Config # noqa 402
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(description='Test.')
13 | parser.add_argument('config', type=str, help='Config file path')
14 | parser.add_argument('checkpoint', type=str, help='Checkpoint file path')
15 | parser.add_argument('--distribute', default=False, action='store_true')
16 | parser.add_argument('--local_rank', type=int, default=0)
17 | args = parser.parse_args()
18 | if 'LOCAL_RANK' not in os.environ:
19 | os.environ['LOCAL_RANK'] = str(args.local_rank)
20 | return args
21 |
22 |
23 | def main():
24 | args = parse_args()
25 |
26 | cfg_path = args.config
27 | cfg = Config.fromfile(cfg_path)
28 |
29 | _, fullname = os.path.split(cfg_path)
30 | fname, ext = os.path.splitext(fullname)
31 |
32 | root_workdir = cfg.pop('root_workdir')
33 | workdir = os.path.join(root_workdir, fname)
34 | os.makedirs(workdir, exist_ok=True)
35 |
36 | test_cfg = cfg['test']
37 | inference_cfg = cfg['inference']
38 | common_cfg = cfg['common']
39 | common_cfg['workdir'] = workdir
40 | common_cfg['distribute'] = args.distribute
41 |
42 | runner = TestRunner(test_cfg, inference_cfg, common_cfg)
43 | runner.load_checkpoint(args.checkpoint)
44 | runner()
45 |
46 |
47 | if __name__ == '__main__':
48 | main()
49 |
--------------------------------------------------------------------------------
/vedastr_cstr/tools/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import sys
5 |
6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
7 |
8 | from vedastr.runners import TrainRunner # noqa 402
9 | from vedastr.utils import Config # noqa 402
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(description="Train.")
14 | parser.add_argument("config", type=str, help="config file path")
15 | parser.add_argument("--distribute", default=False, action="store_true")
16 | parser.add_argument("--local_rank", type=int, default=0)
17 | args = parser.parse_args()
18 | if "LOCAL_RANK" not in os.environ:
19 | os.environ["LOCAL_RANK"] = str(args.local_rank)
20 | return args
21 |
22 |
23 | def main():
24 | args = parse_args()
25 |
26 | cfg_path = args.config
27 | cfg = Config.fromfile(cfg_path)
28 |
29 | _, fullname = os.path.split(cfg_path)
30 | fname, ext = os.path.splitext(fullname)
31 |
32 | root_workdir = cfg.pop("root_workdir")
33 | workdir = os.path.join(root_workdir, fname)
34 | os.makedirs(workdir, exist_ok=True)
35 | # copy corresponding cfg to workdir
36 | shutil.copy(cfg_path, os.path.join(workdir, os.path.basename(cfg_path)))
37 |
38 | train_cfg = cfg["train"]
39 | inference_cfg = cfg["inference"]
40 | common_cfg = cfg["common"]
41 | common_cfg["workdir"] = workdir
42 | common_cfg["distribute"] = args.distribute
43 |
44 | runner = TrainRunner(train_cfg, inference_cfg, common_cfg)
45 | runner()
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/vedastr_cstr/vedastr/__init__.py
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/converter/__init__.py:
--------------------------------------------------------------------------------
1 | from .attn_converter import AttnConverter # noqa 401
2 | from .builder import build_converter # noqa 401
3 | from .ctc_converter import CTCConverter # noqa 401
4 | from .fc_converter import FCConverter # noqa 401
5 | from .custom_converter import CustomConverter
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/converter/attn_converter.py:
--------------------------------------------------------------------------------
1 | # modify from clovaai
2 |
3 | import torch
4 |
5 | from .base_convert import BaseConverter
6 | from .registry import CONVERTERS
7 |
8 |
9 | @CONVERTERS.register_module
10 | class AttnConverter(BaseConverter):
11 |
12 | def __init__(self, character, batch_max_length, go_last=False):
13 | list_character = list(character)
14 | self.batch_max_length = batch_max_length + 1
15 | if go_last:
16 | list_token = ['[s]', '[GO]']
17 | character = list_character + list_token
18 | else:
19 | list_token = ['[GO]', '[s]']
20 | character = list_token + list_character
21 | super(AttnConverter, self).__init__(character=character)
22 | self.ignore_index = self.dict['[GO]']
23 |
24 | def train_encode(self, text):
25 | length = [len(s) + 1 for s in text]
26 | batch_text = torch.LongTensor(len(text), self.batch_max_length + 1).fill_(self.ignore_index) # noqa 501
27 | for idx, t in enumerate(text):
28 | text = list(t)
29 | text.append('[s]')
30 | text = [self.dict[char] for char in text]
31 | batch_text[idx][1:1 + len(text)] = torch.LongTensor(text)
32 | batch_text_input = batch_text[:, :-1]
33 | batch_text_target = batch_text[:, 1:]
34 |
35 | return batch_text_input, torch.IntTensor(length), batch_text_target
36 |
37 | def decode(self, text_index):
38 | texts = []
39 | batch_size = text_index.shape[0]
40 | for index in range(batch_size):
41 | text = ''.join([self.character[i] for i in text_index[index, :]])
42 | text = text[:text.find('[s]')]
43 | texts.append(text)
44 |
45 | return texts
46 |
--------------------------------------------------------------------------------
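A small worked example of the encode/shift behaviour implemented by the AttnConverter above; the two-character vocabulary and the asserted tensors are illustrative only, not values from this repo's configs.

from vedastr.converter import AttnConverter

# With go_last=False the vocabulary is ['[GO]', '[s]', 'a', 'b'],
# so [GO]=0, [s]=1, 'a'=2, 'b'=3, and padding reuses the [GO] id.
converter = AttnConverter(character='ab', batch_max_length=3)
inp, length, target = converter.train_encode(['ab'])

# full row layout: [GO] a b [s] pad  ->  ids [0, 2, 3, 1, 0]
assert inp.tolist() == [[0, 2, 3, 1]]     # decoder input keeps the leading [GO]
assert target.tolist() == [[2, 3, 1, 0]]  # target is the same row shifted left by one
assert length.tolist() == [3]             # text length + 1 for the appended [s]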
/vedastr_cstr/vedastr/converter/base_convert.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import torch
4 |
5 | from .registry import CONVERTERS
6 |
7 |
8 | @CONVERTERS.register_module
9 | class BaseConverter(object):
10 | def __init__(self, character):
11 | self.character = list(character)
12 | self.dict = {}
13 | for i, char in enumerate(self.character):
14 | self.dict[char] = i
15 | self.ignore_index = None
16 |
17 | @abc.abstractmethod
18 | def train_encode(self, *args, **kwargs):
19 | '''encode text in train phase'''
20 |
21 | def test_encode(self, text):
22 | if isinstance(text, (list, tuple)):
23 | num = len(text)
24 | elif isinstance(text, int):
25 | num = text
26 | else:
27 | raise TypeError(
28 | f'Type of text should be in (list, tuple, int) '
29 | f'but got {type(text)}'
30 | )
31 | ignore_index = self.ignore_index
32 | if ignore_index is None:
33 | ignore_index = 0
34 | batch_text = torch.LongTensor(num, 1).fill_(ignore_index)
35 | length = [1 for i in range(num)]
36 |
37 | return batch_text, torch.IntTensor(length), batch_text
38 |
39 | @abc.abstractmethod
40 | def decode(self, *args, **kwargs):
41 | '''decode label to text in train and test phase'''
42 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/converter/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import CONVERTERS
3 |
4 |
5 | def build_converter(cfg, default_args=None):
6 | converter = build_from_cfg(cfg, CONVERTERS, default_args=default_args)
7 |
8 | return converter
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/converter/ctc_converter.py:
--------------------------------------------------------------------------------
1 | # modify from clovaai
2 |
3 | import torch
4 |
5 | from .base_convert import BaseConverter
6 | from .registry import CONVERTERS
7 |
8 |
9 | @CONVERTERS.register_module
10 | class CTCConverter(BaseConverter):
11 | def __init__(self, character: str, batch_max_length: int):
12 | list_token = ['[blank]']
13 | list_character = list(character)
14 | self.batch_max_length = batch_max_length
15 | super(CTCConverter,
16 | self).__init__(character=list_token + list_character)
17 |
18 | def train_encode(self, text: list):
19 | length = [len(s) for s in text]
20 | batch_text = torch.LongTensor(len(text),
21 | self.batch_max_length).fill_(0)
22 | for i, t in enumerate(text):
23 | text = list(t)
24 | text = [self.dict[char] for char in text]
25 | batch_text[i][:len(text)] = torch.LongTensor(text)
26 |
27 | return batch_text, torch.IntTensor(length), batch_text
28 |
29 | def decode(self, text_index):
30 | texts = []
31 | batch_size = text_index.shape[0]
32 | length = text_index.shape[1]
33 | for i in range(batch_size):
34 | t = text_index[i]
35 | char_list = []
36 | for idx in range(length):
37 | if t[idx] != 0 and (not (idx > 0 and t[idx - 1] == t[idx])):
38 | char_list.append(self.character[t[idx]])
39 | text = ''.join(char_list)
40 | texts.append(text)
41 | return texts
42 |
--------------------------------------------------------------------------------
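A worked example of the CTC decode rule used above (collapse consecutive repeats, then drop the blank at index 0); the index row and the tiny vocabulary are made up for illustration.

chars = ['[blank]', 'a', 'b']   # index 0 is the CTC blank, as in CTCConverter
row = [1, 1, 0, 1, 2, 2, 0]     # hypothetical argmax indices over time steps
decoded = ''.join(
    chars[idx] for pos, idx in enumerate(row)
    if idx != 0 and not (pos > 0 and row[pos - 1] == idx))
assert decoded == 'aab'         # repeats collapsed, blanks removed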
/vedastr_cstr/vedastr/converter/custom_converter.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .registry import CONVERTERS
4 | from .base_convert import BaseConverter
5 |
6 | # special token definitions
7 | START = "<SOS>" # start-of-sequence token (string assumed)
8 | END = "<EOS>" # end-of-sequence token (string assumed)
9 | PAD = "<PAD>" # padding token (string assumed)
10 | SPECIAL_TOKENS = [START, END, PAD]
11 |
12 | def load_vocab(tokens_paths):
13 | '''
14 | Load token txt files and build the token-id dictionaries.
15 | - `tokens_paths`: paths of the token txt files
16 | '''
17 | tokens = [] # token storage
18 | tokens.extend(SPECIAL_TOKENS) # add the special tokens first
19 |
20 | # read each file and collect its tokens
21 | for tokens_file in tokens_paths:
22 | with open(tokens_file, "r") as fd:
23 | reader = fd.read()
24 | for token in reader.split("\n"):
25 | if token not in tokens:
26 | tokens.append(token)
27 | tokens.remove('') # drop the empty string left by the trailing newline
28 | token_to_id = {tok: i for i, tok in enumerate(tokens)} # token to id
29 | id_to_token = {i: tok for i, tok in enumerate(tokens)} # id to token
30 | return tokens, token_to_id, id_to_token
31 |
32 | @CONVERTERS.register_module
33 | class CustomConverter(BaseConverter):
34 |
35 | def __init__(self, character, batch_max_length=25):
36 | token_path = ["/opt/ml/input/data/train_dataset/tokens.txt"] # token file path
37 | self.character, self.dict, self.id_to_token = load_vocab(token_path) # parse the tokens
38 | self.batch_max_length = batch_max_length # maximum sequence length
39 | self.ignore_index = self.dict[PAD] # token id to ignore (padding)
40 |
41 |
42 | def train_encode(self, text_list):
43 | '''
44 | Encode a batch of sequences.
45 | - `text_list`: list of sequences in the batch
46 | '''
47 | length = [len(s) + 1 for s in text_list] # lengths of the ground truths in the batch
48 | batch_text = torch.LongTensor(len(text_list), self.batch_max_length+1).fill_(self.ignore_index) # encoding matrix of shape (b, max_len+1)
49 | for i, t in enumerate(text_list):
50 | text = t.split(' ')
51 | text.append(END) # append the EOS token
52 | text = [self.dict[char] for char in text] # map tokens to ids
53 | batch_text[i][:len(text)] = torch.LongTensor(text) # write the row into the matrix
54 | batch_text_input = batch_text
55 | batch_text_target = batch_text
56 |
57 | return batch_text_input, torch.IntTensor(length), batch_text_target
58 |
59 |
60 | def decode(self, text_index):
61 | '''
62 | Decode a batch of id sequences back into token sequences.
63 | - `text_index`: batch of id sequences
64 | '''
65 | texts = []
66 | batch_size = text_index.shape[0]
67 | for index in range(batch_size):
68 | text = ' '.join([self.id_to_token[int(i)] for i in text_index[index, :]])
69 | text = text[:text.find(END)] # keep only the part before EOS
70 | texts.append(text)
71 |
72 | return texts
73 |
--------------------------------------------------------------------------------
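A minimal usage sketch for the CustomConverter above. It assumes the hard-coded token file exists and contains one token per line (e.g. LaTeX-style tokens); the example sequence is made up.

from vedastr.converter import CustomConverter

# character is unused here: the vocabulary comes from the hard-coded tokens.txt.
converter = CustomConverter(character=None, batch_max_length=10)

# Each ground truth is a space-separated token sequence; an EOS token is appended
# and the row is padded with the PAD id up to batch_max_length + 1.
inputs, lengths, targets = converter.train_encode([r'\frac { x } { 2 }'])

# decode() maps id rows back to tokens and truncates at the first EOS.
print(converter.decode(targets))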
/vedastr_cstr/vedastr/converter/fc_converter.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base_convert import BaseConverter
4 | from .registry import CONVERTERS
5 |
6 |
7 | @CONVERTERS.register_module
8 | class FCConverter(BaseConverter):
9 |
10 | def __init__(self, character, batch_max_length=25):
11 |
12 | list_token = ['[s]']
13 | ignore_token = ['[ignore]']
14 | list_character = list(character)
15 | self.batch_max_length = batch_max_length + 1
16 | super(FCConverter, self).__init__(character=list_token + list_character + ignore_token) # noqa 501
17 | self.ignore_index = self.dict[ignore_token[0]]
18 |
19 | def train_encode(self, text):
20 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence.
21 | batch_text = torch.LongTensor(len(text), self.batch_max_length).fill_(self.ignore_index) # noqa 501
22 | for i, t in enumerate(text):
23 | text = list(t)
24 | text.append('[s]')
25 | text = [self.dict[char] for char in text]
26 | batch_text[i][:len(text)] = torch.LongTensor(text)
27 | batch_text_input = batch_text
28 | batch_text_target = batch_text
29 |
30 | return batch_text_input, torch.IntTensor(length), batch_text_target
31 |
32 | def decode(self, text_index):
33 | texts = []
34 | batch_size = text_index.shape[0]
35 | for index in range(batch_size):
36 | text = ''.join([self.character[i] for i in text_index[index, :]])
37 | text = text[:text.find('[s]')]
38 | texts.append(text)
39 |
40 | return texts
41 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/converter/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | CONVERTERS = Registry('convert')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/criteria/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_criterion # noqa 401
2 | from .cross_entropy_loss import CrossEntropyLoss # noqa 401
3 | from .ctc_loss import CTCLoss # noqa 401
4 | from .label_smooth_cross_entropy_loss import LabelSmoothingCrossEntropy # noqa 401
5 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/criteria/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import CRITERIA
3 |
4 |
5 | def build_criterion(cfg):
6 | criterion = build_from_cfg(cfg, CRITERIA, src='registry')
7 |
8 | return criterion
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/criteria/cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .registry import CRITERIA
4 |
5 |
6 | @CRITERIA.register_module
7 | class CrossEntropyLoss(nn.Module):
8 |
9 | def __init__(self,
10 | weight=None,
11 | size_average=None,
12 | ignore_index=-100,
13 | reduce=None,
14 | reduction='mean'):
15 | super(CrossEntropyLoss, self).__init__()
16 | self.criterion = nn.CrossEntropyLoss(
17 | weight=weight,
18 | size_average=size_average,
19 | ignore_index=ignore_index,
20 | reduce=reduce,
21 | reduction=reduction)
22 |
23 | def forward(self, pred, target, *args):
24 | return self.criterion(pred.contiguous().view(-1, pred.shape[-1]),
25 | target.to(pred.device).contiguous().view(-1))
26 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/criteria/ctc_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .registry import CRITERIA
5 |
6 |
7 | @CRITERIA.register_module
8 | class CTCLoss(nn.Module):
9 |
10 | def __init__(self, zero_infinity=True, blank=0, reduction='mean'):
11 | super(CTCLoss, self).__init__()
12 | self.criterion = nn.CTCLoss(
13 | zero_infinity=zero_infinity, blank=blank, reduction=reduction)
14 |
15 | def forward(self, pred, target, target_length, batch_size):
16 | pred = pred.log_softmax(2)
17 | input_lengths = torch.full(
18 | size=(batch_size, ), fill_value=pred.size(1), dtype=torch.long)
19 | pred_ = pred.permute(1, 0, 2)
20 | cost = self.criterion(
21 | log_probs=pred_,
22 | targets=target.to(pred.device),
23 | input_lengths=input_lengths.to(pred.device),
24 | target_lengths=target_length.to(pred.device))
25 | return cost
26 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/criteria/label_smooth_cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | # modify from fast.ai
2 | # https://github.com/fastai/fastai/blob/8013797e05f0ae0d771d60ecf7cf524da591503c/fastai/layers.py#L300
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from collections import Counter
6 | from .registry import CRITERIA
7 |
8 |
9 | @CRITERIA.register_module
10 | class LabelSmoothingCrossEntropy(nn.Module):
11 |
12 | def __init__(self, eps: float = 0.1, reduction='mean', ignore_index=-100):
13 | super(LabelSmoothingCrossEntropy, self).__init__()
14 | self.eps, self.reduction = eps, reduction
15 | self.ignore_index = ignore_index
16 |
17 | def forward(self, output, target, *args):
18 | pred = output.contiguous().view(-1, output.shape[-1])
19 | target = target.to(pred.device).contiguous().view(-1)
20 |
21 | c = pred.size()[-1]
22 | log_preds = F.log_softmax(pred, dim=-1)
23 |
24 | # ignore index for smooth label
25 | ignore_target = target != self.ignore_index
26 | log_preds = log_preds * ignore_target[:, None]
27 |
28 | if self.reduction == 'sum':
29 | loss = -log_preds.sum()
30 | else:
31 | loss = -log_preds.sum(dim=-1)
32 | if self.reduction == 'mean':
33 | loss = loss.mean()
34 | return loss * self.eps / c + (1 - self.eps) * \
35 | F.nll_loss(log_preds, target, reduction=self.reduction,
36 | ignore_index=self.ignore_index)
37 |
--------------------------------------------------------------------------------
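A small smoke test (not from the repo) for the loss above: with eps=0 and no ignored targets it should reduce to plain cross entropy. Shapes and values are made up.

import torch
import torch.nn.functional as F

from vedastr.criteria import LabelSmoothingCrossEntropy

logits = torch.randn(4, 7)           # (N, C) raw predictions
target = torch.randint(0, 7, (4,))   # (N,) class labels

criterion = LabelSmoothingCrossEntropy(eps=0.0)
assert torch.allclose(criterion(logits, target),
                      F.cross_entropy(logits, target), atol=1e-6)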
/vedastr_cstr/vedastr/criteria/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | CRITERIA = Registry('criterion')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_dataloader # noqa 401
2 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/dataloaders/builder.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as tud
2 |
3 | from vedastr.utils import WorkerInit, build_from_cfg, get_dist_info
4 | from .registry import DATALOADERS
5 |
6 |
7 | def build_dataloader(distributed,
8 | num_gpus,
9 | cfg,
10 | default_args: dict = None,
11 | seed=None):
12 | cfg_ = cfg.copy()
13 |
14 | samples_per_gpu = cfg_.pop('samples_per_gpu')
15 | workers_per_gpu = cfg_.pop('workers_per_gpu')
16 |
17 | if distributed:
18 | batch_size = samples_per_gpu
19 | num_workers = workers_per_gpu
20 | else:
21 | batch_size = num_gpus * samples_per_gpu
22 | num_workers = num_gpus * workers_per_gpu
23 |
24 | cfg_.update({'batch_size': batch_size, 'num_workers': num_workers})
25 |
26 | dataloaders = {}
27 |
28 | # TODO, other implementations
29 | if DATALOADERS.get(cfg['type']):
30 | packages = DATALOADERS
31 | src = 'registry'
32 | else:
33 | packages = tud
34 | src = 'module'
35 |
36 | # build different dataloaders for different datasets
37 | if isinstance(default_args.get('dataset'), list):
38 | for idx, ds in enumerate(default_args['dataset']):
39 | assert isinstance(ds, tud.Dataset)
40 | if default_args.get('sampler'):
41 | sp = default_args['sampler'][idx]
42 | else:
43 | sp = None
44 | dataloader = build_from_cfg(
45 | cfg_, packages, dict(dataset=ds, sampler=sp), src=src)
46 | if hasattr(ds, 'root'):
47 | name = getattr(ds, 'root')
48 | else:
49 | name = str(idx)
50 | dataloaders[name] = dataloader
51 | else:
52 | rank, _ = get_dist_info()
53 | worker_init_fn = WorkerInit(
54 | num_workers=num_workers, rank=rank, seed=seed, epoch=0)
55 | default_args['worker_init_fn'] = worker_init_fn
56 | dataloaders = build_from_cfg(
57 | cfg_, packages, default_args, src=src) # build a single dataloader
58 |
59 | return dataloaders
60 |
--------------------------------------------------------------------------------
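A standalone sketch of the per-process batch-size rule implemented in build_dataloader above; the helper name and the numbers are illustrative only.

def effective_batch_size(distributed: bool, num_gpus: int, samples_per_gpu: int) -> int:
    # mirrors the branch in build_dataloader: per-rank batch when distributed,
    # otherwise one process feeds all GPUs
    return samples_per_gpu if distributed else num_gpus * samples_per_gpu

assert effective_batch_size(True, 4, 64) == 64     # DDP: 64 samples per rank
assert effective_batch_size(False, 2, 64) == 128   # single process driving 2 GPUs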
/vedastr_cstr/vedastr/dataloaders/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | DATALOADERS = Registry('dataloader')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/dataloaders/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .balance_sampler import BalanceSampler # noqa 401
2 | from .builder import build_sampler # noqa 401
3 | from .default_sampler import DefaultSampler # noqa 401
4 | from .dist_balance_sampler import BalanceSampler # noqa 401, 811
5 | from .dist_default_sampler import DefaultSampler # noqa 401, 811
6 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/dataloaders/samplers/builder.py:
--------------------------------------------------------------------------------
1 | from .registry import DISTSAMPLER, SAMPLER
2 | from ...utils import build_from_cfg
3 |
4 |
5 | def build_sampler(distributed, cfg, default_args=None):
6 | if distributed:
7 | sampler = build_from_cfg(cfg, DISTSAMPLER, default_args)
8 | else:
9 | sampler = build_from_cfg(cfg, SAMPLER, default_args)
10 |
11 | return sampler
12 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/dataloaders/samplers/default_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Sampler
3 |
4 | from .registry import SAMPLER
5 |
6 |
7 | @SAMPLER.register_module
8 | class DefaultSampler(Sampler):
9 | """Default non-distributed sampler."""
10 |
11 | def __init__(self,
12 | dataset,
13 | shuffle: bool = True,
14 | **kwargs
15 | ):
16 | self.dataset = dataset
17 | self.shuffle = shuffle
18 |
19 | def __iter__(self):
20 | if self.shuffle:
21 | return iter(torch.randperm(len(self.dataset)).tolist())
22 | else:
23 | return iter(range(len(self.dataset)))
24 |
25 | def __len__(self):
26 | return len(self.dataset)
27 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/dataloaders/samplers/dist_default_sampler.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DistributedSampler
2 |
3 | from .registry import DISTSAMPLER
4 | from ...utils import get_dist_info
5 |
6 |
7 | @DISTSAMPLER.register_module
8 | class DefaultSampler(DistributedSampler):
9 | """Default distributed sampler."""
10 |
11 | def __init__(self,
12 | dataset,
13 | shuffle: bool = True,
14 | seed=0,
15 | drop_last=False,
16 | **kwargs
17 | ):
18 | if seed is None:
19 | seed = 0
20 | rank, num_replicas = get_dist_info()
21 | super().__init__(dataset=dataset,
22 | num_replicas=num_replicas,
23 | rank=rank,
24 | shuffle=shuffle,
25 | seed=seed,
26 | drop_last=drop_last,
27 | )
28 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/dataloaders/samplers/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | SAMPLER = Registry('sampler')
4 | DISTSAMPLER = Registry('dist_sampler')
5 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_datasets # noqa 401
2 | from .concat_dataset import ConcatDatasets # noqa 401
3 | from .fold_dataset import FolderDataset # noqa 401
4 | from .lmdb_dataset import LmdbDataset # noqa 401
5 | from .paste_dataset import PasteDataset # noqa 401
6 | from .txt_datasets import TxtDataset # noqa 401
7 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/datasets/base.py:
--------------------------------------------------------------------------------
1 | # modify from clovaai
2 |
3 | import logging
4 | import os
5 | import re
6 |
7 | import cv2
8 | from torch.utils.data import Dataset
9 |
10 |
11 | class BaseDataset(Dataset):
12 | """
13 | Args:
14 | root (str): The dir path of image files.
15 | transform: Transformation for images, which will be passed
16 | automatically if you set transform cfg correctly in
17 | configure file.
18 | character (str): The character set to be used. Samples will be
19 | filtered based on these characters.
20 | batch_max_length (int): The max allowed length of the text
21 | after filtering.
22 | data_filter (bool): If True, samples will be filtered based on the
23 | character set. Otherwise no filtering is applied.
24 |
25 | """
26 |
27 | def __init__(self,
28 | root: str,
29 | transform=None,
30 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789',
31 | batch_max_length: int = 100000,
32 | data_filter: bool = True):
33 |
34 | assert type(
35 | root
36 | ) == str, f'The type of root should be str but got {type(root)}'
37 |
38 | self.root = os.path.abspath(root)
39 | self.character = character
40 | self.batch_max_length = batch_max_length
41 | self.data_filter = data_filter
42 |
43 | if transform is not None:
44 | self.transforms = transform
45 | self.samples = 0
46 | self.img_names = []
47 | self.gt_texts = []
48 | self.get_name_list()
49 |
50 | self.logger = logging.getLogger()
51 | self.logger.info(
52 | f'current dataset length is {self.samples} in {self.root}')
53 |
54 | def get_name_list(self):
55 | raise NotImplementedError
56 |
57 | def filter(self, label, return_len=False):
58 | if not self.data_filter:
59 | if not return_len:
60 | return False
61 | return False, len(label)
62 | """We will filter those samples whose length is larger
63 | than defined max_length by default."""
64 | character = "".join(sorted(self.character, key=lambda x: ord(x)))
65 | out_of_char = f'[^{character}]'
66 | # replace those character not in self.character with ''
67 | label = re.sub(out_of_char, '', label.lower())
68 | # filter whose label larger than batch_max_length
69 | if len(label) > self.batch_max_length:
70 | if not return_len:
71 | return True
72 | return True, len(label)
73 | if not return_len:
74 | return False
75 | return False, len(label)
76 |
77 | def __getitem__(self, index):
78 | # default img channel is rgb
79 | img = cv2.imread(self.img_names[index])
80 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
81 | label = self.gt_texts[index]
82 |
83 | if self.transforms:
84 | aug = self.transforms(image=img, label=label)
85 | img, label = aug['image'], aug['label']
86 |
87 | return img, label
88 |
89 | def __len__(self):
90 | return self.samples
91 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/datasets/builder.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from vedastr.utils import build_from_cfg
4 | from .registry import DATASETS
5 |
6 | logger = logging.getLogger()
7 |
8 |
9 | def build_datasets(cfg, default_args=None):
10 | if isinstance(cfg, list):
11 | datasets = []
12 | for icfg in cfg:
13 | ds = build_from_cfg(icfg, DATASETS, default_args)
14 | datasets.append(ds)
15 | else:
16 | datasets = build_from_cfg(cfg, DATASETS, default_args)
17 |
18 | return datasets
19 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/datasets/concat_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import ConcatDataset as _ConcatDataset
2 |
3 | from .builder import build_datasets
4 | from .registry import DATASETS
5 |
6 |
7 | @DATASETS.register_module
8 | class ConcatDatasets(_ConcatDataset):
9 | """ Concat different datasets.
10 |
11 | Args:
12 | datasets (list[dict]): A list in which each element is a dataset cfg.
13 | batch_ratio (list[float]): Ratio of each corresponding dataset used
14 | when constructing a batch. It takes effect only with the balance
15 | sampler.
16 | **kwargs:
17 |
18 | """
19 |
20 | def __init__(self,
21 | datasets: list,
22 | batch_ratio: list = None,
23 | **kwargs):
24 | assert isinstance(datasets, list)
25 | datasets = build_datasets(datasets, default_args=kwargs)
26 | self.root = ''.join([ds.root for ds in datasets])
27 | data_range = [len(dataset) for dataset in datasets]
28 | self.data_range = [
29 | sum(data_range[:i]) for i in range(1,
30 | len(data_range) + 1)
31 | ]
32 | self.batch_ratio = batch_ratio
33 | if self.batch_ratio is not None:
34 | assert len(self.batch_ratio) == len(datasets), \
35 | 'The length of batch_ratio and datasets should be equal. ' \
36 | f'But got {len(self.batch_ratio)} batch_ratio and ' \
37 | f'{len(datasets)} datasets.'
38 | super(ConcatDatasets, self).__init__(datasets=datasets)
39 |
--------------------------------------------------------------------------------
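A hypothetical config fragment showing how the ConcatDatasets above is typically assembled; the dataset types, roots and ratios are placeholders, not values taken from this repo.

concat_cfg = dict(
    type='ConcatDatasets',
    datasets=[
        dict(type='LmdbDataset', root='data/train_lmdb_1'),
        dict(type='LmdbDataset', root='data/train_lmdb_2'),
    ],
    batch_ratio=[0.5, 0.5],  # half of each batch from each dataset (balance sampler only)
)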
/vedastr_cstr/vedastr/datasets/fold_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .base import BaseDataset
4 | from .registry import DATASETS
5 |
6 |
7 | @DATASETS.register_module
8 | class FolderDataset(BaseDataset):
9 | """ Read images in a folder. The format of image filename should be
10 | same as follows:
11 | 'name_gt.extension', where name represents arbitrary string,
12 | gt represents the ground-truth of the image, and extension
13 | represents the postfix (png, jpg, etc.).
14 |
15 | """
16 |
17 | def __init__(self,
18 | root: str,
19 | transform=None,
20 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789',
21 | batch_max_length: int = 100000,
22 | data_filter: bool = True,
23 | extension_names: tuple = ('.jpg', '.png', '.bmp', '.jpeg')):
24 | super(FolderDataset, self).__init__(
25 | root=root,
26 | transform=transform,
27 | character=character,
28 | batch_max_length=batch_max_length,
29 | data_filter=data_filter)
30 | self.extension_names = extension_names
31 |
32 | @staticmethod
33 | def parse_filename(text):
34 | return text.split('_')[-1]
35 |
36 | def get_name_list(self):
37 | for item in os.listdir(self.root):
38 | file_name, file_extension = os.path.splitext(item)
39 | if file_extension in self.extension_names:
40 | label = self.parse_filename(file_name)
41 | if self.filter(label):
42 | continue
43 | else:
44 | self.img_names.append(os.path.join(self.root, item))
45 | self.gt_texts.append(label)
46 | self.samples = len(self.gt_texts)
47 |
--------------------------------------------------------------------------------
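An illustration of the 'name_gt.extension' filename convention parsed by the FolderDataset above; the filenames are made up.

from vedastr.datasets import FolderDataset

# 'IMG_0001_hello.jpg' -> label 'hello'; 'scan-3_2048.png' -> label '2048'
assert FolderDataset.parse_filename('IMG_0001_hello') == 'hello'
assert FolderDataset.parse_filename('scan-3_2048') == '2048'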
/vedastr_cstr/vedastr/datasets/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | # modify from clovaai
2 |
3 | import logging
4 |
5 | import lmdb
6 | import numpy as np
7 | import six
8 | from PIL import Image
9 |
10 | from .base import BaseDataset
11 | from .registry import DATASETS
12 |
13 | logger = logging.getLogger()
14 |
15 |
16 | @DATASETS.register_module
17 | class LmdbDataset(BaseDataset):
18 | """ Read the data of lmdb format.
19 | Please refer to https://github.com/Media-Smart/vedastr/issues/27#issuecomment-691793593 # noqa 501
20 | if you have problems with creating lmdb format file.
21 |
22 | """
23 |
24 | def __init__(self,
25 | root: str,
26 | transform=None,
27 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789',
28 | batch_max_length: int = 100000,
29 | data_filter: bool = True):
30 | self.index_list = []
31 | super(LmdbDataset, self).__init__(
32 | root=root,
33 | transform=transform,
34 | character=character,
35 | batch_max_length=batch_max_length,
36 | data_filter=data_filter)
37 |
38 | def get_name_list(self):
39 | self.env = lmdb.open(
40 | self.root,
41 | max_readers=32,
42 | readonly=True,
43 | lock=False,
44 | readahead=False,
45 | meminit=False)
46 | with self.env.begin(write=False) as txn:
47 | n_samples = int(txn.get('num-samples'.encode()))
48 | for index in range(n_samples):
49 | idx = index + 1 # lmdb starts with 1
50 | label_key = 'label-%09d'.encode() % idx
51 | label = txn.get(label_key).decode('utf-8')
52 | if self.filter(
53 | label
54 | ): # if the label is longer than max_len, drop this sample
55 | continue
56 | else:
57 | self.index_list.append(idx)
58 | self.samples = len(self.index_list)
59 |
60 | def read_data(self, index):
61 | assert index <= len(self), 'index range error'
62 | index = self.index_list[index]
63 | with self.env.begin(write=False) as txn:
64 | label_key = 'label-%09d'.encode() % index
65 | label = txn.get(label_key).decode('utf-8')
66 | img_key = 'image-%09d'.encode() % index
67 | imgbuf = txn.get(img_key)
68 |
69 | buf = six.BytesIO()
70 | buf.write(imgbuf)
71 | buf.seek(0)
72 | img = Image.open(buf).convert('RGB') # for color image
73 | img = np.array(img)
74 |
75 | return img, label
76 |
77 | def __getitem__(self, index):
78 |
79 | img, label = self.read_data(index)
80 | if self.transforms:
81 | aug = self.transforms(image=img, label=label)
82 | img, label = aug['image'], aug['label']
83 |
84 | return img, label
85 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/datasets/paste_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 | import lmdb
5 |
6 | from .lmdb_dataset import LmdbDataset
7 | from .registry import DATASETS
8 |
9 |
10 | @DATASETS.register_module
11 | class PasteDataset(LmdbDataset):
12 | """ Concat two images and the combined label should satisfy the constrains.
13 |
14 | Args:
15 | p: The probability of pasting operation.
16 |
17 | Warning: A new transform operation will be created to
18 | replace this dataset.
19 | """
20 |
21 | def __init__(self, p: float = 0.1, *args, **kwargs):
22 | self.len_sample = dict()
23 | self.len_lists = list()
24 | self.p = p
25 | super(PasteDataset, self).__init__(*args, **kwargs)
26 |
27 | def get_name_list(self):
28 | self.env = lmdb.open(
29 | self.root,
30 | max_readers=32,
31 | readonly=True,
32 | lock=False,
33 | readahead=False,
34 | meminit=False)
35 | with self.env.begin(write=False) as txn:
36 | n_samples = int(txn.get('num-samples'.encode()))
37 | for index in range(n_samples):
38 | idx = index + 1 # lmdb starts with 1
39 | label_key = 'label-%09d'.encode() % idx
40 | label = txn.get(label_key).decode('utf-8')
41 | flag, length = self.filter(label, return_len=True)
42 | if flag:
43 | continue
44 | else:
45 | self.index_list.append(idx)
46 | self.len_lists.append(length)
47 | if length not in self.len_sample:
48 | self.len_sample[length] = [len(self.index_list) - 1]
49 | else:
50 | self.len_sample[length].append(
51 | len(self.index_list) - 1)
52 |
53 | self.samples = len(self.index_list)
54 |
55 | def __getitem__(self, index):
56 | img, label = self.read_data(index)
57 | th, tw = img.shape[:2]
58 | c_len = self.len_lists[index]
59 | max_need_len = self.batch_max_length - c_len
60 | if max_need_len > 2 and random.random() < self.p:
61 | p_len = random.randint(1, max_need_len - 1)
62 | p_idx = random.choice(self.len_sample[p_len])
63 | p_img, p_label = self.read_data(p_idx)
64 | p_img = cv2.resize(p_img, (tw, th))
65 | img = cv2.hconcat([img, p_img])
66 | label = label + p_label
67 |
68 | if self.transforms:
69 | aug = self.transforms(image=img, label=label)
70 | img, label = aug['image'], aug['label']
71 |
72 | return img, label
73 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/datasets/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | DATASETS = Registry('datasets')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/datasets/txt_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .base import BaseDataset
4 | from .registry import DATASETS
5 |
6 |
7 | @DATASETS.register_module
8 | class TxtDataset(BaseDataset):
9 | """ Read images based on the txt file.
10 | Each line in the txt file should have the following format:
11 | image_path label
12 |
13 | The image path and label should be separated by '\t'.
14 |
15 | """
16 |
17 | def __init__(
18 | self,
19 | root: str,
20 | gt_txt: str,
21 | transform=None,
22 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789',
23 | batch_max_length: int = 25,
24 | data_filter: bool = True,
25 | ):
26 | # check that the ground-truth file exists
27 | if gt_txt is not None:
28 | assert os.path.isfile(gt_txt)
29 | self.gt_txt = gt_txt
30 | super(TxtDataset, self).__init__(
31 | root=root,
32 | transform=transform,
33 | character=character,
34 | batch_max_length=batch_max_length,
35 | data_filter=data_filter,
36 | )
37 |
38 | def get_name_list(self):
39 | with open(self.gt_txt, 'r') as gt:
40 | for line in gt.readlines():
41 | img_name, label = line.strip().split('\t')
42 | if self.filter(label):
43 | continue
44 | else:
45 | self.img_names.append(os.path.join(self.root, img_name))
46 | self.gt_texts.append(label)
47 |
48 | self.samples = len(self.gt_texts)
49 |
--------------------------------------------------------------------------------
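A hypothetical config fragment for the TxtDataset above, together with the expected gt file layout; paths and labels are placeholders.

# gt.txt (tab-separated):
#   images/000001.jpg\thello
#   images/000002.jpg\tworld
txt_dataset_cfg = dict(
    type='TxtDataset',
    root='data/',          # image paths from gt.txt are joined onto this root
    gt_txt='data/gt.txt',
    batch_max_length=25,
)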
/vedastr_cstr/vedastr/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_logger # noqa 401
2 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/logger/builder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import time
5 |
6 | import torch.distributed as dist
7 |
8 |
9 | def build_logger(cfg, default_args):
10 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
11 | format_ = '%(asctime)s - %(levelname)s - %(message)s'
12 |
13 | formatter = logging.Formatter(format_)
14 | logger = logging.getLogger()
15 | logger.setLevel(logging.DEBUG)
16 | if logger.parent is not None:
17 | logger.parent.handlers.clear()
18 | else:
19 | logger.handlers.clear()
20 | if dist.is_available() and dist.is_initialized():
21 | rank = dist.get_rank()
22 | else:
23 | rank = 0
24 | for handler in cfg['handlers']:
25 | if handler['type'] == 'StreamHandler':
26 | instance = logging.StreamHandler(sys.stdout)
27 | elif handler['type'] == 'FileHandler':
28 | # only rank 0 will add a FileHandler
29 | if default_args.get('workdir') and rank == 0:
30 | fp = os.path.join(default_args['workdir'],
31 | '%s.log' % timestamp)
32 | instance = logging.FileHandler(fp, 'w')
33 | else:
34 | continue
35 | else:
36 | instance = logging.StreamHandler(sys.stdout)
37 |
38 | level = getattr(logging, handler['level'])
39 |
40 | instance.setFormatter(formatter)
41 | if rank == 0:
42 | instance.setLevel(level)
43 | else:
44 | logger.setLevel(logging.ERROR)
45 |
46 | logger.addHandler(instance)
47 |
48 | return logger
49 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/lr_schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_lr_scheduler # noqa 401
2 | from .constant_lr import ConstantLR # noqa 401
3 | from .cosine_lr import CosineLR # noqa 401
4 | from .exponential_lr import ExponentialLR # noqa 401
5 | from .poly_lr import PolyLR # noqa 401
6 | from .step_lr import StepLR # noqa 401
7 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/lr_schedulers/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import LR_SCHEDULERS
3 |
4 |
5 | def build_lr_scheduler(cfg, default_args=None):
6 | scheduler = build_from_cfg(cfg, LR_SCHEDULERS, default_args, 'registry')
7 | return scheduler
8 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/lr_schedulers/constant_lr.py:
--------------------------------------------------------------------------------
1 | from .base import _Iter_LRScheduler
2 | from .registry import LR_SCHEDULERS
3 |
4 |
5 | @LR_SCHEDULERS.register_module
6 | class ConstantLR(_Iter_LRScheduler):
7 | """ConstantLR
8 | """
9 |
10 | def __init__(self,
11 | optimizer,
12 | niter_per_epoch,
13 | last_iter=-1,
14 | warmup_epochs=0,
15 | iter_based=True,
16 | **kwargs):
17 | self.warmup_iters = niter_per_epoch * warmup_epochs
18 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based)
19 |
20 | def get_lr(self):
21 | if self.last_iter < self.warmup_iters:
22 | multiplier = self.last_iter / float(self.warmup_iters)
23 | else:
24 | multiplier = 1.0
25 | return [base_lr * multiplier for base_lr in self.base_lrs]
26 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/lr_schedulers/cosine_lr.py:
--------------------------------------------------------------------------------
1 | from math import cos, pi
2 |
3 | from .base import _Iter_LRScheduler
4 | from .registry import LR_SCHEDULERS
5 |
6 |
7 | @LR_SCHEDULERS.register_module
8 | class CosineLR(_Iter_LRScheduler):
9 | """CosineLR
10 | """
11 |
12 | def __init__(self,
13 | optimizer,
14 | niter_per_epoch,
15 | max_epochs,
16 | last_iter=-1,
17 | warmup_epochs=0,
18 | iter_based=True):
19 | self.max_iters = niter_per_epoch * max_epochs
20 | self.warmup_iters = niter_per_epoch * warmup_epochs
21 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based)
22 |
23 | def get_lr(self):
24 | if self.last_iter < self.warmup_iters:
25 | multiplier = 0.5 * (
26 | 1 - cos(pi * (self.last_iter / float(self.warmup_iters))))
27 | else:
28 | multiplier = 0.5 * (1 + cos(pi * (
29 | (self.last_iter - self.warmup_iters) /
30 | float(self.max_iters - self.warmup_iters))))
31 | return [base_lr * multiplier for base_lr in self.base_lrs]
32 |
--------------------------------------------------------------------------------
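A standalone sketch of the learning-rate multiplier computed by CosineLR above (cosine warmup followed by cosine annealing); the function name and iteration counts are illustrative only.

from math import cos, pi

def cosine_multiplier(last_iter: int, warmup_iters: int, max_iters: int) -> float:
    # mirrors CosineLR.get_lr: ramp up during warmup, then anneal to zero
    if last_iter < warmup_iters:
        return 0.5 * (1 - cos(pi * last_iter / warmup_iters))
    return 0.5 * (1 + cos(pi * (last_iter - warmup_iters) / (max_iters - warmup_iters)))

assert abs(cosine_multiplier(100, 100, 1000) - 1.0) < 1e-9   # end of warmup: full base lr
assert abs(cosine_multiplier(1000, 100, 1000) - 0.0) < 1e-9  # final iteration: lr decayed to 0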
/vedastr_cstr/vedastr/lr_schedulers/exponential_lr.py:
--------------------------------------------------------------------------------
1 | from .base import _Iter_LRScheduler
2 | from .registry import LR_SCHEDULERS
3 |
4 |
5 | @LR_SCHEDULERS.register_module
6 | class ExponentialLR(_Iter_LRScheduler):
7 | """ExponentialLR
8 | """
9 |
10 | def __init__(self,
11 | optimizer,
12 | niter_per_epoch,
13 | max_epochs,
14 | gamma,
15 | step,
16 | last_iter=-1,
17 | warmup_epochs=0,
18 | iter_based=True):
19 | self.max_iters = niter_per_epoch * max_epochs
20 | self.gamma = gamma
21 | self.step_iters = niter_per_epoch * step
22 | self.warmup_iters = int(niter_per_epoch * warmup_epochs)
23 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based)
24 |
25 | def get_lr(self):
26 | if self.last_iter < self.warmup_iters:
27 | multiplier = self.last_iter / float(self.warmup_iters)
28 | else:
29 | multiplier = self.gamma**((self.last_iter - self.warmup_iters) /
30 | float(self.step_iters))
31 | return [base_lr * multiplier for base_lr in self.base_lrs]
32 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/lr_schedulers/poly_lr.py:
--------------------------------------------------------------------------------
1 | from .base import _Iter_LRScheduler
2 | from .registry import LR_SCHEDULERS
3 |
4 |
5 | @LR_SCHEDULERS.register_module
6 | class PolyLR(_Iter_LRScheduler):
7 | """PolyLR
8 | """
9 |
10 | def __init__(self,
11 | optimizer,
12 | niter_per_epoch,
13 | max_epochs,
14 | power=0.9,
15 | last_iter=-1,
16 | warmup_epochs=0,
17 | iter_based=True):
18 | self.max_iters = niter_per_epoch * max_epochs
19 | self.power = power
20 | self.warmup_iters = niter_per_epoch * warmup_epochs
21 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based)
22 |
23 | def get_lr(self):
24 | if self.last_iter < self.warmup_iters:
25 | multiplier = (self.last_iter /
26 | float(self.warmup_iters))**self.power
27 | else:
28 | multiplier = (
29 | 1 - (self.last_iter - self.warmup_iters) /
30 | float(self.max_iters - self.warmup_iters))**self.power
31 | return [base_lr * multiplier for base_lr in self.base_lrs]
32 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/lr_schedulers/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | LR_SCHEDULERS = Registry('lr_scheduler')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/lr_schedulers/step_lr.py:
--------------------------------------------------------------------------------
1 | from .base import _Iter_LRScheduler
2 | from .registry import LR_SCHEDULERS
3 |
4 |
5 | @LR_SCHEDULERS.register_module
6 | class StepLR(_Iter_LRScheduler):
7 |
8 | def __init__(self,
9 | optimizer,
10 | niter_per_epoch,
11 | max_epochs,
12 | milestones,
13 | gamma=0.1,
14 | last_iter=-1,
15 | warmup_epochs=0,
16 | iter_based=True):
17 | self.max_iters = niter_per_epoch * max_epochs
18 | self.milestones = milestones
19 | self.count = 0
20 | self.gamma = gamma
21 | self.warmup_iters = int(niter_per_epoch * warmup_epochs)
22 | super(StepLR, self).__init__(optimizer, niter_per_epoch, last_iter,
23 | iter_based)
24 |
25 | def get_lr(self):
26 | if self._iter_based and self.last_iter in self.milestones:
27 | self.count += 1
28 | elif not self._iter_based and self.last_epoch in self.milestones:
29 | self.count += 1
30 |
31 | if self.last_iter < self.warmup_iters:
32 | multiplier = self.last_iter / float(self.warmup_iters)
33 | else:
34 | multiplier = self.gamma**self.count
35 | return [base_lr * multiplier for base_lr in self.base_lrs]
36 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .accuracy import Accuracy # noqa 401
2 | from .builder import build_metric # noqa 401
3 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | # modify from clovaai
2 | import random
3 |
4 | import torch
5 | from nltk.metrics.distance import edit_distance
6 |
7 | from .registry import METRICS
8 | from ..utils import gather_tensor, get_dist_info
9 |
10 |
11 | @METRICS.register_module
12 | class Accuracy(object):
13 |
14 | def __init__(self):
15 | self.reset()
16 | self.predict_example_log = None
17 |
18 | @property
19 | def result(self):
20 | res = {
21 | 'acc': self.avg['acc']['true'],
22 | 'edit_distance': self.avg['edit'],
23 | }
24 | return res
25 |
26 | def measure(self, preds, preds_prob, gts, exclude_num=0):
27 | batch_size = len(gts)
28 | true_nums = []
29 | norm_EDs = []
30 | r, w = get_dist_info()
31 | for pstr, gstr in zip(preds, gts):
32 | if pstr == gstr:
33 | true_nums.append(1.)
34 | else:
35 | true_nums.append(0.)
36 | if len(pstr) == 0 or len(gstr) == 0:
37 | norm_EDs.append(0)
38 | elif len(gstr) > len(pstr):
39 | norm_EDs.append(1 - edit_distance(pstr, gstr) / len(gstr))
40 | else:
41 | norm_EDs.append(1 - edit_distance(pstr, gstr) / len(pstr))
42 | # gather batch_size, true_num, norm_ED from different workers
43 | batch_sizes = gather_tensor(torch.tensor(batch_size)[None].cuda())
44 | true_nums = gather_tensor(
45 | torch.tensor(true_nums)[None].cuda()).flatten()
46 | norm_EDs = gather_tensor(torch.tensor(norm_EDs)[None].cuda()).flatten()
47 |
48 | # remove exclude data
49 | if exclude_num != 0:
50 | batch_size = torch.sum(batch_sizes).cpu().numpy() - exclude_num
51 | true_nums = list(true_nums.split(true_nums.shape[0] // w))
52 | for i in range(1, exclude_num + 1):
53 | true_nums[-i] = true_nums[-i][:-1]
54 | true_nums = torch.cat(true_nums)
55 | # true_nums = true_nums.flatten()[:-exclude_num]
56 | norm_EDs = list(norm_EDs.split(true_nums.shape[0] // w))
57 | for i in range(1, exclude_num + 1):
58 | norm_EDs[-i] = norm_EDs[-i][:-1]
59 | norm_EDs = torch.cat(norm_EDs)
60 | # norm_EDs = norm_EDs.flatten()[:-exclude_num]
61 | else:
62 | batch_size = torch.sum(batch_sizes).cpu().numpy()
63 |
64 | true_num = torch.sum(true_nums).cpu().numpy()
65 | norm_ED = torch.sum(norm_EDs).cpu().numpy()
66 |
67 | if preds_prob is not None:
68 | self.show_example(preds, preds_prob, gts)
69 | self.all['acc']['true'] += true_num
70 | self.all['acc']['false'] += (batch_size - true_num)
71 | self.all['edit'] += norm_ED
72 | self.count += batch_size
73 | for key, value in self.all['acc'].items():
74 | self.avg['acc'][key] = self.all['acc'][key] / self.count
75 | self.avg['edit'] = self.all['edit'] / self.count
76 |
77 | def reset(self):
78 | self.all = dict(acc=dict(true=0, false=0), edit=0)
79 | self.avg = dict(acc=dict(true=0, false=0), edit=0)
80 | self.count = 0
81 |
82 | def show_example(self, preds, preds_prob, gts):
83 | count = 0
84 | self.predict_example_log = None
85 | dashed_line = '-' * 80
86 |
87 | output = dashed_line + '\n'
88 | show_inds = list(range(len(gts)))
89 | random.shuffle(show_inds)
90 | show_inds = show_inds[:5]
91 | show_gts = [gts[i] for i in show_inds]
92 | show_preds = [preds[i] for i in show_inds]
93 | show_prob = [preds_prob[i] for i in show_inds]
94 | for gt, pred, prob in zip(show_gts, show_preds, show_prob):
95 | output += f'{"Ground Truth : ":18s} {gt:40s} \n' \
96 | f'{"Prediction : ":18s} {pred:40s} \n' \
97 | f'{"Confidence : ":18s} {prob:0.4f} \n' \
98 | f'{"Success(T/F) : ":18s} {str(pred == gt)} \n' \
99 | f'{dashed_line}\n'
100 | count += 1
101 | if count > 4:
102 | break
103 |
104 | output += dashed_line
105 |
106 | self.predict_example_log = output
107 |
108 | return self.predict_example_log
109 |
--------------------------------------------------------------------------------
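A worked example of the normalised edit distance accumulated by the Accuracy metric above; the prediction/ground-truth pair is made up.

from nltk.metrics.distance import edit_distance

pred, gt = 'he1lo', 'hello'       # one substitution error
norm_ed = 1 - edit_distance(pred, gt) / max(len(pred), len(gt))
assert abs(norm_ed - 0.8) < 1e-9  # 1 - 1/5, normalised by the longer string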
/vedastr_cstr/vedastr/metrics/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import METRICS
3 |
4 |
5 | def build_metric(cfg, default_args=None):
6 | metric = build_from_cfg(cfg, METRICS, default_args)
7 |
8 | return metric
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/metrics/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | METRICS = Registry('metrics')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_model # noqa 401
2 | from .model import GModel # noqa 401
3 | from .registry import MODELS # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/__init__.py:
--------------------------------------------------------------------------------
1 | from .body import GBody # noqa 401
2 | from .builder import build_body, build_component # noqa 401
3 | from .feature_extractors import build_brick, build_feature_extractor # noqa 401
4 | from .rectificators import build_rectificator # noqa 401
5 | from .sequences import build_sequence_decoder, build_sequence_encoder # noqa 401
6 | from .component import (BrickComponent, FeatureExtractorComponent, # noqa 401
7 | PlugComponent, RectificatorComponent, # noqa 401
8 | SequenceEncoderComponent) # noqa 401
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/body.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .builder import build_component
4 | from .feature_extractors import build_brick
5 | from .registry import BODIES
6 |
7 |
8 | @BODIES.register_module
9 | class GBody(nn.Module):
10 |
11 | def __init__(self, pipelines, collect=None):
12 | super(GBody, self).__init__()
13 |
14 | self.input_to_layer = 'input'
15 | self.components = nn.ModuleList(
16 | [build_component(component) for component in pipelines])
17 |
18 | if collect is not None:
19 | self.collect = build_brick(collect)
20 |
21 | @property
22 | def with_collect(self):
23 | return hasattr(self, 'collect') and self.collect is not None
24 |
25 | def forward(self, x):
26 | feats = {self.input_to_layer: x}
27 |
28 | for component in self.components:
29 | component_from = component.from_layer
30 | component_to = component.to_layer
31 |
32 | if isinstance(component_from, list):
33 | inp = {key: feats[key] for key in component_from}
34 | out = component(**inp)
35 | else:
36 | inp = feats[component_from]
37 | out = component(inp)
38 | feats[component_to] = out
39 |
40 | if self.with_collect:
41 | return self.collect(feats)
42 | else:
43 | return feats
44 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import BODIES, COMPONENT
3 |
4 |
5 | def build_component(cfg, default_args=None):
6 | component = build_from_cfg(cfg, COMPONENT, default_args)
7 |
8 | return component
9 |
10 |
11 | def build_body(cfg, default_args=None):
12 | body = build_from_cfg(cfg, BODIES, default_args)
13 |
14 | return body
15 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/component.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .feature_extractors import build_brick, build_feature_extractor
4 | from .rectificators import build_rectificator
5 | from .registry import COMPONENT
6 | from .sequences import build_sequence_encoder
7 | from ..utils import build_module
8 |
9 |
10 | class BaseComponent(nn.Module):
11 |
12 | def __init__(self, from_layer, to_layer, component):
13 | super(BaseComponent, self).__init__()
14 |
15 | self.from_layer = from_layer
16 | self.to_layer = to_layer
17 | self.component = component
18 |
19 | def forward(self, x):
20 | return self.component(x)
21 |
22 |
23 | @COMPONENT.register_module
24 | class FeatureExtractorComponent(BaseComponent):
25 |
26 | def __init__(self, from_layer, to_layer, arch):
27 | super(FeatureExtractorComponent,
28 | self).__init__(from_layer, to_layer,
29 | build_feature_extractor(arch))
30 |
31 |
32 | @COMPONENT.register_module
33 | class RectificatorComponent(BaseComponent):
34 |
35 | def __init__(self, from_layer, to_layer, arch):
36 | super(RectificatorComponent, self).__init__(from_layer, to_layer,
37 | build_rectificator(arch))
38 |
39 |
40 | @COMPONENT.register_module
41 | class SequenceEncoderComponent(BaseComponent):
42 |
43 | def __init__(self, from_layer, to_layer, arch):
44 | super(SequenceEncoderComponent,
45 | self).__init__(from_layer, to_layer,
46 | build_sequence_encoder(arch))
47 |
48 |
49 | @COMPONENT.register_module
50 | class BrickComponent(BaseComponent):
51 |
52 | def __init__(self, from_layer, to_layer, arch):
53 | super(BrickComponent, self).__init__(from_layer, to_layer,
54 | build_brick(arch))
55 |
56 |
57 | @COMPONENT.register_module
58 | class PlugComponent(BaseComponent):
59 |
60 | def __init__(self, from_layer, to_layer, arch):
61 | super(PlugComponent, self).__init__(from_layer, to_layer,
62 | build_module(arch))
63 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_feature_extractor # noqa 401
2 | from .decoders import build_brick, build_bricks, build_decoder # noqa 401
3 | from .encoders import build_backbone, build_encoder, build_enhance_module # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/builder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .decoders import build_brick, build_decoder
4 | from .encoders import build_encoder
5 |
6 |
7 | def build_feature_extractor(cfg):
8 | encoder = build_encoder(cfg.get('encoder'))
9 |
10 | if cfg.get('decoder'):
11 | middle = build_decoder(cfg.get('decoder'))
12 | if 'collect' in cfg:
13 | final = build_brick(cfg.get('collect'))
14 | feature_extractor = nn.Sequential(encoder, middle, final)
15 | else:
16 | feature_extractor = nn.Sequential(encoder, middle)
17 | # assert 'collect' not in cfg
18 | else:
19 | assert 'collect' in cfg
20 | middle = build_brick(cfg.get('collect'))
21 | feature_extractor = nn.Sequential(encoder, middle)
22 |
23 | return feature_extractor
24 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | from .bricks import build_brick, build_bricks # noqa 401
2 | from .builder import build_decoder # noqa 401
3 | from .gfpn import GFPN # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/__init__.py:
--------------------------------------------------------------------------------
1 | from .bricks import (CellAttentionBlock, CollectBlock, FusionBlock, # noqa 401
2 | JunctionBlock) # noqa 401
3 | from .builder import build_brick, build_bricks # noqa 401
4 | from .pva import PVABlock # noqa 401
5 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/builder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from vedastr.utils import build_from_cfg
4 | from .registry import BRICKS
5 |
6 |
7 | def build_brick(cfg, default_args=None):
8 | brick = build_from_cfg(cfg, BRICKS, default_args)
9 | return brick
10 |
11 |
12 | def build_bricks(cfgs):
13 | bricks = nn.ModuleList()
14 | for brick_cfg in cfgs:
15 | bricks.append(build_brick(brick_cfg))
16 | return bricks
17 |
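Every builder in this repo follows the same convention: a cfg dict whose 'type' key names a class registered on a Registry, with the remaining keys forwarded as constructor kwargs. Below is a self-contained toy that mimics that convention; it does not use vedastr.utils itself (whose internals are not shown here), and every name in it is illustrative only.

# Toy stand-in for the Registry / build_from_cfg pattern behind build_brick and
# the other builders. Not the vedastr implementation.
class ToyRegistry:

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, cls):  # used bare as a decorator, like BRICKS.register_module
        self._modules[cls.__name__] = cls
        return cls


def toy_build_from_cfg(cfg, registry, default_args=None):
    args = dict(cfg)
    cls = registry._modules[args.pop('type')]   # 'type' selects the registered class
    args.update(default_args or {})
    return cls(**args)                          # everything else becomes constructor kwargs


TOY_BRICKS = ToyRegistry('brick')


@TOY_BRICKS.register_module
class AddBrick:

    def __init__(self, offset):
        self.offset = offset


brick = toy_build_from_cfg(dict(type='AddBrick', offset=3), TOY_BRICKS)
print(brick.offset)  # 3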
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/pva.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from vedastr.models.weight_init import init_weights
5 | from .registry import BRICKS
6 |
7 |
8 | @BRICKS.register_module
9 | class PVABlock(nn.Module):
10 |
11 | def __init__(self,
12 | num_steps,
13 | in_channels,
14 | embedding_channels=512,
15 | inner_channels=512):
16 | super(PVABlock, self).__init__()
17 |
18 | self.num_steps = num_steps
19 | self.in_channels = in_channels
20 | self.inner_channels = inner_channels
21 | self.embedding_channels = embedding_channels
22 |
23 | self.order_embeddings = nn.Parameter(
24 | torch.randn(self.num_steps, self.embedding_channels),
25 | requires_grad=True)
26 |
27 | self.v_linear = nn.Linear(
28 | self.in_channels, self.inner_channels, bias=False)
29 | self.o_linear = nn.Linear(
30 | self.embedding_channels, self.inner_channels, bias=False)
31 | self.e_linear = nn.Linear(self.inner_channels, 1, bias=False)
32 |
33 | init_weights(self.modules())
34 |
35 | def forward(self, x):
36 | b, c, h, w = x.size()
37 |
38 | x = x.reshape(b, c, h * w).permute(0, 2, 1)
39 |
40 | o_out = self.o_linear(self.order_embeddings).view(
41 | 1, self.num_steps, 1, self.inner_channels)
42 | v_out = self.v_linear(x).unsqueeze(1)
43 | att = self.e_linear(torch.tanh(o_out + v_out)).squeeze(3)
44 | att = torch.softmax(att, dim=2)
45 |
46 | out = torch.bmm(att, x)
47 |
48 | return out
49 |
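A quick shape check of the parallel visual attention above, assuming the vedastr package in this repo is importable (the tensor sizes are arbitrary):

import torch

from vedastr.models.bodies.feature_extractors.decoders.bricks import PVABlock

block = PVABlock(num_steps=25, in_channels=512)
x = torch.randn(2, 512, 8, 32)   # (B, C, H, W) feature map
out = block(x)                   # one attention map over the H*W positions per reading step
print(out.shape)                 # torch.Size([2, 25, 512]), i.e. (B, num_steps, in_channels)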
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | BRICKS = Registry('brick')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import DECODERS
3 |
4 |
5 | def build_decoder(cfg, default_args=None):
6 | decoder = build_from_cfg(cfg, DECODERS, default_args)
7 | return decoder
8 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/gfpn.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.weight_init import init_weights
6 | from .bricks import build_brick, build_bricks
7 | from .registry import DECODERS
8 |
9 | logger = logging.getLogger()
10 |
11 |
12 | @DECODERS.register_module
13 | class GFPN(nn.Module):
14 |
15 | def __init__(self, neck: list, fusion: dict = None):
16 | super().__init__()
17 | self.neck = build_bricks(neck)
18 | if fusion:
19 | self.fusion = build_brick(fusion)
20 | else:
21 | self.fusion = None
22 | logger.info('GFPN init weights')
23 | init_weights(self.modules())
24 |
25 | def forward(self, bottom_up):
26 |
27 | x = None
28 | feats = {}
29 | for ii, layer in enumerate(self.neck):
30 | top_down_from_layer = layer.from_layer.get('top_down')
31 | lateral_from_layer = layer.from_layer.get('lateral')
32 |
33 | if lateral_from_layer:
34 | ll = bottom_up[lateral_from_layer]
35 | else:
36 | ll = None
37 | if top_down_from_layer is None:
38 | td = None
39 | elif 'c' in top_down_from_layer:
40 | td = bottom_up[top_down_from_layer]
41 | elif 'p' in top_down_from_layer:
42 | td = feats[top_down_from_layer]
43 | else:
44 | raise ValueError('Key error')
45 |
46 | x = layer(td, ll)
47 | feats[layer.to_layer] = x
48 | bottom_up[layer.to_layer] = x
49 |
50 | if self.fusion:
51 | x = self.fusion(feats)
52 | bottom_up['fusion'] = x
53 | return bottom_up
54 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | DECODERS = Registry('decoder')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbones import build_backbone # noqa 401
2 | from .builder import build_encoder # noqa 401
3 | from .enhance_modules import build_enhance_module # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_backbone # noqa 401
2 | from .general_backbone import GBackbone # noqa 401
3 | from .resnet import GResNet, ResNet # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import BACKBONES
3 |
4 |
5 | def build_backbone(cfg, default_args=None):
6 | backbone = build_from_cfg(cfg, BACKBONES, default_args)
7 |
8 | return backbone
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/general_backbone.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.utils import build_module, build_torch_nn
6 | from vedastr.models.weight_init import init_weights
7 | from .registry import BACKBONES
8 |
9 | logger = logging.getLogger()
10 |
11 |
12 | @BACKBONES.register_module
13 | class GBackbone(nn.Module):
14 |
15 | def __init__(
16 | self,
17 | layers: list,
18 | ):
19 | super(GBackbone, self).__init__()
20 |
21 | self.layers = nn.ModuleList()
22 | stage_layers = []
23 | for layer_cfg in layers:
24 | type_name = layer_cfg['type']
25 | if hasattr(nn, type_name):
26 | layer = build_torch_nn(layer_cfg)
27 | else:
28 | layer = build_module(layer_cfg)
29 | stride = layer_cfg.get('stride', 1)
30 | max_stride = stride if isinstance(stride, int) else max(stride)
31 | if max_stride > 1:
32 | self.layers.append(nn.Sequential(*stage_layers))
33 | stage_layers = []
34 | stage_layers.append(layer)
35 | self.layers.append(nn.Sequential(*stage_layers))
36 | logger.info('GBackbone init weights')
37 | init_weights(self.modules())
38 |
39 | def forward(self, x):
40 | feats = {}
41 |
42 | for i, layer in enumerate(self.layers):
43 | x = layer(x)
44 | feats['c{}'.format(i)] = x
45 |
46 | return feats
47 |
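GBackbone opens a new stage whenever a layer config has stride > 1, and forward returns a dict keyed 'c0', 'c1', ... that downstream modules (GFPN, ASPP, PPM) index by name. A minimal sketch, assuming build_torch_nn follows the usual convention of popping 'type' and forwarding the remaining keys to the named torch.nn class:

import torch

from vedastr.models.bodies.feature_extractors.encoders.backbones import GBackbone

# Two conv layers; the second has stride 2, so it opens stage 'c1'.
layers_cfg = [
    dict(type='Conv2d', in_channels=3, out_channels=8, kernel_size=3, padding=1),
    dict(type='Conv2d', in_channels=8, out_channels=16, kernel_size=3, padding=1, stride=2),
]
backbone = GBackbone(layers=layers_cfg)
feats = backbone(torch.randn(1, 3, 32, 64))
print(sorted(feats))   # ['c0', 'c1'] -- the layer names referenced elsewhere via from_layer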
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | BACKBONES = Registry('backbone')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/builder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .backbones import build_backbone
4 | from .enhance_modules import build_enhance_module
5 |
6 |
7 | def build_encoder(cfg, default_args=None):
8 | backbone = build_backbone(cfg['backbone'])
9 |
10 | enhance_cfg = cfg.get('enhance')
11 | if enhance_cfg:
12 | enhance_module = build_enhance_module(enhance_cfg)
13 | encoder = nn.Sequential(backbone, enhance_module)
14 | else:
15 | encoder = backbone
16 |
17 | return encoder
18 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .aspp import ASPP # noqa 401
2 | from .builder import build_enhance_module # noqa 401
3 | from .ppm import PPM # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/aspp.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/pytorch/vision/tree/master/torchvision/models/segmentation/deeplabv3.py # noqa 501
2 |
3 | import logging
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from vedastr.models.weight_init import init_weights
10 | from .registry import ENHANCE_MODULES
11 |
12 | logger = logging.getLogger()
13 |
14 |
15 | class ASPPConv(nn.Sequential):
16 |
17 | def __init__(self, in_channels, out_channels, dilation):
18 | modules = [
19 | nn.Conv2d(
20 | in_channels,
21 | out_channels,
22 | 3,
23 | padding=dilation,
24 | dilation=dilation,
25 | bias=False),
26 | nn.BatchNorm2d(out_channels),
27 | nn.ReLU(inplace=True)
28 | ]
29 | super(ASPPConv, self).__init__(*modules)
30 |
31 |
32 | class ASPPPooling(nn.Sequential):
33 |
34 | def __init__(self, in_channels, out_channels):
35 | super(ASPPPooling, self).__init__(
36 | nn.AdaptiveAvgPool2d(1),
37 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
38 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))
39 |
40 | def forward(self, x):
41 | size = x.shape[-2:]
42 | x = super(ASPPPooling, self).forward(x)
43 | return F.interpolate(x, size=size, mode='nearest')
44 |
45 |
46 | @ENHANCE_MODULES.register_module
47 | class ASPP(nn.Module):
48 |
49 | def __init__(self,
50 | in_channels: int,
51 | out_channels: int,
52 | atrous_rates: tuple,
53 | from_layer: str,
54 | to_layer: str,
55 | dropout=None):
56 | super(ASPP, self).__init__()
57 | self.from_layer = from_layer
58 | self.to_layer = to_layer
59 |
60 | modules = []
61 | modules.append(
62 | nn.Sequential(
63 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
64 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)))
65 |
66 | rate1, rate2, rate3 = tuple(atrous_rates)
67 | modules.append(ASPPConv(in_channels, out_channels, rate1))
68 | modules.append(ASPPConv(in_channels, out_channels, rate2))
69 | modules.append(ASPPConv(in_channels, out_channels, rate3))
70 | modules.append(ASPPPooling(in_channels, out_channels))
71 |
72 | self.convs = nn.ModuleList(modules)
73 |
74 | self.project = nn.Sequential(
75 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
76 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))
77 | self.with_dropout = dropout is not None
78 | if self.with_dropout:
79 | self.dropout = nn.Dropout(p=dropout)
80 |
81 | logger.info('ASPP init weights')
82 | init_weights(self.modules())
83 |
84 | def forward(self, feats):
85 | feats_ = feats.copy()
86 | x = feats_[self.from_layer]
87 | res = []
88 | for conv in self.convs:
89 | res.append(conv(x))
90 | res = torch.cat(res, dim=1)
91 | res = self.project(res)
92 | if self.with_dropout:
93 | res = self.dropout(res)
94 | feats_[self.to_layer] = res
95 | return feats_
96 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import ENHANCE_MODULES
3 |
4 |
5 | def build_enhance_module(cfg, default_args=None):
6 | enhance_module = build_from_cfg(cfg, ENHANCE_MODULES, default_args)
7 |
8 | return enhance_module
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/ppm.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/hszhao/semseg/blob/master/model/pspnet.py
2 |
3 | import logging
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from vedastr.models.weight_init import init_weights
10 | from .registry import ENHANCE_MODULES
11 |
12 | logger = logging.getLogger()
13 |
14 |
15 | @ENHANCE_MODULES.register_module
16 | class PPM(nn.Module):
17 |
18 | def __init__(self, in_channels, out_channels, bins, from_layer, to_layer):
19 | super(PPM, self).__init__()
20 | self.from_layer = from_layer
21 | self.to_layer = to_layer
22 |
23 | self.blocks = nn.ModuleList()
24 | for bin_ in bins:
25 | self.blocks.append(
26 | nn.Sequential(
27 | nn.AdaptiveAvgPool2d(bin_),
28 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
29 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)))
30 | logger.info('PPM init weights')
31 | init_weights(self.modules())
32 |
33 | def forward(self, feats):
34 | feats_ = feats.copy()
35 | x = feats_[self.from_layer]
36 | h, w = x.shape[2:]
37 | out = [x]
38 | for block in self.blocks:
39 | feat = F.interpolate(
40 | block(x), (h, w), mode='bilinear', align_corners=True)
41 | out.append(feat)
42 | out = torch.cat(out, 1)
43 | feats_[self.to_layer] = out
44 |
45 | return feats_
46 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | ENHANCE_MODULES = Registry('enhance_module')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/rectificators/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_rectificator # noqa 401
2 | from .spin import SPIN # noqa 401
3 | from .tps_stn import TPS_STN # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/rectificators/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import RECTIFICATORS
3 |
4 |
5 | def build_rectificator(cfg, default_args=None):
6 | rectificator = build_from_cfg(cfg, RECTIFICATORS, default_args)
7 |
8 | return rectificator
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/rectificators/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | RECTIFICATORS = Registry('Rectificator')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/rectificators/spin.py:
--------------------------------------------------------------------------------
1 | # [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117) # noqa 501
2 | # Not fully implemented yet. SPN has been tested successfully.
3 | import copy
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | from vedastr.models.bodies.feature_extractors import build_feature_extractor
10 | from vedastr.models.utils import build_module, build_torch_nn
11 | from vedastr.models.weight_init import init_weights
12 | from .registry import RECTIFICATORS
13 |
14 |
15 | class SPN(nn.Module):
16 |
17 | def __init__(self, cfg):
18 | super(SPN, self).__init__()
19 | self.body = build_feature_extractor(cfg['feature_extractor'])
20 | self.pool = build_torch_nn(cfg['pool'])
21 | heads = []
22 | for head in cfg['head']:
23 | heads.append(build_module(head))
24 | self.head = nn.Sequential(*heads)
25 |
26 | def forward(self, x):
27 | batch_size = x.size(0)
28 | x = self.body(x)
29 | x = self.pool(x).view(batch_size, -1)
30 | x = self.head(x)
31 | return x
32 |
33 |
34 | class AIN(nn.Module):
35 |
36 | def __init__(self, cfg):
37 | super(AIN, self).__init__()
38 | self.body = build_feature_extractor(cfg['feature_extractor'])
39 |
40 | def forward(self, x):
41 | x = self.body(x)
42 |
43 | return x
44 |
45 |
46 | @RECTIFICATORS.register_module
47 | class SPIN(nn.Module):
48 |
49 | def __init__(self, spin: dict, k: int):
50 | super(SPIN, self).__init__()
51 | self.body = build_feature_extractor(spin['feature_extractor'])
52 | self.spn = SPN(spin['spn'])
53 | self.betas = generate_beta(k)
54 | init_weights(self.modules())
55 |
56 | def forward(self, x):
57 | b, c, h, w = x.size()
58 | init_img = copy.copy(x)
59 | # shared parameters
60 | x = self.body(x)
61 |
62 | spn_out = self.spn(x) # 2k+2
63 | omega = spn_out[:, :-1]
64 | g_out = init_img.requires_grad_(True)
65 |
66 | gamma_out = [g_out**beta for beta in self.betas]
67 |         gamma_out = torch.stack(gamma_out, dim=1).requires_grad_(True)
68 |
69 | fusion_img = omega[:, :, None, None, None] * gamma_out
70 | fusion_img = torch.sigmoid(fusion_img.sum(dim=1))
71 | return fusion_img
72 |
73 |
74 | def generate_beta(k):
75 | betas = []
76 | for i in range(1, k + 2):
77 | p = i / (2 * (k + 1))
78 | beta = round(np.log(1 - p) / np.log(p), 2)
79 | betas.append(beta)
80 | for i in range(k + 2, 2 * k + 2):
81 | beta = round(1 / betas[(i - (k + 1))], 2)
82 | betas.append(beta)
83 |
84 | return betas
85 |
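For intuition, the exponents generate_beta produces for k = 1, computed by hand from the code above:

# First loop (i = 1, 2): p = 1/4 and 2/4
#   round(ln(0.75) / ln(0.25), 2) = 0.21
#   round(ln(0.50) / ln(0.50), 2) = 1.0
# Second loop (i = 3): reciprocal of betas[1]
#   round(1 / 1.0, 2)             = 1.0
# So generate_beta(1) -> [0.21, 1.0, 1.0]: 2k + 1 exponents, matching the 2k + 1
# weights kept from the SPN output (omega = spn_out[:, :-1]).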
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/rectificators/sspin.py:
--------------------------------------------------------------------------------
1 | # We implement a new module that has, to some extent, the same properties as SPIN.
2 | # We believe this approach could replace GA-SPIN by enlarging the output features
3 | # of the SE layer, but we did not run further experiments.
4 |
5 | import torch.nn as nn
6 |
7 | from vedastr.models.bodies.feature_extractors import build_feature_extractor
8 | from vedastr.models.utils import SE
9 | from vedastr.models.weight_init import init_weights
10 | from .registry import RECTIFICATORS
11 |
12 |
13 | @RECTIFICATORS.register_module
14 | class SSPIN(nn.Module):
15 |
16 | def __init__(self, feature_cfg, se_cfgs):
17 | super(SSPIN, self).__init__()
18 | self.body = build_feature_extractor(feature_cfg)
19 | self.se = SE(**se_cfgs)
20 | init_weights(self.modules())
21 |
22 | def forward(self, x):
23 | x = self.body(x)
24 | x = self.se(x)
25 |
26 | return x
27 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | COMPONENT = Registry('component')
4 | BODIES = Registry('body')
5 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_sequence_decoder, build_sequence_encoder # noqa 401
2 | from .rnn import RNN, GRUCell # noqa 401
3 | from .transformer import TransformerEncoder # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import SEQUENCE_DECODERS, SEQUENCE_ENCODERS
3 |
4 |
5 | def build_sequence_encoder(cfg, default_args=None):
6 | sequence_encoder = build_from_cfg(cfg, SEQUENCE_ENCODERS, default_args)
7 |
8 | return sequence_encoder
9 |
10 |
11 | def build_sequence_decoder(cfg, default_args=None):
12 | sequence_encoder = build_from_cfg(cfg, SEQUENCE_DECODERS, default_args)
13 |
14 | return sequence_encoder
15 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | SEQUENCE_ENCODERS = Registry('sequence_encoder')
4 | SEQUENCE_DECODERS = Registry('sequence_decoder')
5 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/rnn/__init__.py:
--------------------------------------------------------------------------------
1 | from .decoder import GRUCell, LSTMCell # noqa 401
2 | from .encoder import RNN # noqa 401
3 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/rnn/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from vedastr.models.weight_init import init_weights
5 | from ..registry import SEQUENCE_DECODERS
6 |
7 |
8 | class BaseCell(nn.Module):
9 |
10 | def __init__(self,
11 | basic_cell,
12 | input_size,
13 | hidden_size,
14 | bias=True,
15 | num_layers=1):
16 | super(BaseCell, self).__init__()
17 |
18 | self.input_size = input_size
19 | self.hidden_size = hidden_size
20 | self.bias = bias
21 | self.num_layers = num_layers
22 |
23 | self.cells = nn.ModuleList()
24 | for i in range(num_layers):
25 | if i == 0:
26 | self.cells.append(
27 | basic_cell(
28 | input_size=input_size,
29 | hidden_size=hidden_size,
30 | bias=bias))
31 | else:
32 | self.cells.append(
33 | basic_cell(
34 | input_size=hidden_size,
35 | hidden_size=hidden_size,
36 | bias=bias))
37 | init_weights(self.modules())
38 |
39 | def init_hidden(self, batch_size, device=None, value=0):
40 | raise NotImplementedError()
41 |
42 | def get_output(self, hiddens):
43 | raise NotImplementedError()
44 |
45 | def get_hidden_state(self, hidden):
46 | raise NotImplementedError()
47 |
48 | def forward(self, x, pre_hiddens):
49 | next_hiddens = []
50 |
51 | hidden = None
52 | for i, cell in enumerate(self.cells):
53 | if i == 0:
54 | hidden = cell(x, pre_hiddens[i])
55 | else:
56 | hidden = cell(self.get_hidden_state(hidden), pre_hiddens[i])
57 |
58 | next_hiddens.append(hidden)
59 |
60 | return next_hiddens
61 |
62 |
63 | @SEQUENCE_DECODERS.register_module
64 | class LSTMCell(BaseCell):
65 |
66 | def __init__(self, input_size, hidden_size, bias=True, num_layers=1):
67 | super(LSTMCell, self).__init__(nn.LSTMCell, input_size, hidden_size,
68 | bias, num_layers)
69 |
70 | def init_hidden(self, batch_size, device=None, value=0):
71 | hiddens = []
72 | for _ in range(self.num_layers):
73 | hidden = (
74 | torch.FloatTensor(batch_size,
75 | self.hidden_size).fill_(value).to(device),
76 | torch.FloatTensor(batch_size,
77 | self.hidden_size).fill_(value).to(device),
78 | )
79 | hiddens.append(hidden)
80 |
81 | return hiddens
82 |
83 | def get_output(self, hiddens):
84 | return hiddens[-1][0]
85 |
86 | def get_hidden_state(self, hidden):
87 | return hidden[0]
88 |
89 |
90 | @SEQUENCE_DECODERS.register_module
91 | class GRUCell(BaseCell):
92 |
93 | def __init__(self, input_size, hidden_size, bias=True, num_layers=1):
94 | super(GRUCell, self).__init__(nn.GRUCell, input_size, hidden_size,
95 | bias, num_layers)
96 |
97 | def init_hidden(self, batch_size, device=None, value=0):
98 | hiddens = []
99 | for i in range(self.num_layers):
100 | hidden = torch.FloatTensor(
101 | batch_size, self.hidden_size).fill_(value).to(device)
102 | hiddens.append(hidden)
103 |
104 | return hiddens
105 |
106 | def get_output(self, hiddens):
107 | return hiddens[-1]
108 |
109 | def get_hidden_state(self, hidden):
110 | return hidden
111 |
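A one-step usage sketch of the decoder cells above, assuming the vedastr package in this repo is importable; LSTMCell is used the same way but keeps an (h, c) tuple per layer:

import torch

from vedastr.models.bodies.sequences.rnn import GRUCell

cell = GRUCell(input_size=256, hidden_size=256, num_layers=2)
hidden = cell.init_hidden(batch_size=4, device=torch.device('cpu'))
hidden = cell(torch.randn(4, 256), hidden)   # one decoding step through both layers
step_out = cell.get_output(hidden)           # (4, 256): hidden state of the last layer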
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/rnn/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from vedastr.models.utils import build_torch_nn
4 | from vedastr.models.weight_init import init_weights
5 | from ..registry import SEQUENCE_ENCODERS
6 |
7 |
8 | @SEQUENCE_ENCODERS.register_module
9 | class RNN(nn.Module):
10 |
11 | def __init__(self, input_pool, layers, keep_order=False):
12 | super(RNN, self).__init__()
13 | self.keep_order = keep_order
14 |
15 | if input_pool:
16 | self.input_pool = build_torch_nn(input_pool)
17 |
18 | self.layers = nn.ModuleList()
19 | for i, (layer_name, layer_cfg) in enumerate(layers):
20 | if layer_name in ['rnn', 'fc']:
21 | self.layers.add_module('{}_{}'.format(i, layer_name),
22 | build_torch_nn(layer_cfg))
23 | else:
24 | raise ValueError('Unknown layer name {}'.format(layer_name))
25 |
26 | init_weights(self.modules())
27 |
28 | @property
29 | def with_input_pool(self):
30 | return hasattr(self, 'input_pool') and self.input_pool
31 |
32 | def forward(self, x):
33 | if self.with_input_pool:
34 | out = self.input_pool(x).squeeze(2)
35 | else:
36 | out = x
37 | # input order (B, C, T) -> (B, T, C)
38 | out = out.permute(0, 2, 1)
39 | for layer_name, layer in self.layers.named_children():
40 |
41 | if 'rnn' in layer_name:
42 | layer.flatten_parameters()
43 | out, _ = layer(out)
44 | else:
45 | out = layer(out)
46 | if not self.keep_order:
47 | out = out.permute(0, 2, 1).unsqueeze(2)
48 |
49 | return out.contiguous()
50 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .decoder import TransformerDecoder # noqa 401
2 | from .encoder import TransformerEncoder # noqa 401
3 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/decoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.weight_init import init_weights
6 | from .position_encoder import build_position_encoder
7 | from .unit import build_decoder_layer
8 | from ..registry import SEQUENCE_DECODERS
9 |
10 | logger = logging.getLogger()
11 |
12 |
13 | @SEQUENCE_DECODERS.register_module
14 | class TransformerDecoder(nn.Module):
15 |
16 | def __init__(self,
17 | decoder_layer: dict,
18 | num_layers: int,
19 | position_encoder: dict = None):
20 | super(TransformerDecoder, self).__init__()
21 |
22 | if position_encoder is not None:
23 | self.pos_encoder = build_position_encoder(position_encoder)
24 |
25 | self.layers = nn.ModuleList(
26 | [build_decoder_layer(decoder_layer) for _ in range(num_layers)])
27 |
28 | logger.info('TransformerDecoder init weights')
29 | init_weights(self.modules())
30 |
31 | @property
32 | def with_position_encoder(self):
33 | return hasattr(self, 'pos_encoder') and self.pos_encoder is not None
34 |
35 | def forward(self, tgt, src, tgt_mask=None, src_mask=None):
36 | if self.with_position_encoder:
37 | tgt = self.pos_encoder(tgt)
38 |
39 | for layer in self.layers:
40 | tgt = layer(tgt, src, tgt_mask, src_mask)
41 |
42 | return tgt
43 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/encoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.weight_init import init_weights
6 | from .position_encoder import build_position_encoder
7 | from .unit import build_encoder_layer
8 | from ..registry import SEQUENCE_ENCODERS
9 |
10 | logger = logging.getLogger()
11 |
12 |
13 | @SEQUENCE_ENCODERS.register_module
14 | class TransformerEncoder(nn.Module):
15 |
16 | def __init__(self,
17 | encoder_layer: dict,
18 | num_layers: int,
19 | position_encoder: dict = None):
20 | super(TransformerEncoder, self).__init__()
21 |
22 | if position_encoder is not None:
23 | self.pos_encoder = build_position_encoder(position_encoder)
24 |
25 | self.layers = nn.ModuleList(
26 | [build_encoder_layer(encoder_layer) for _ in range(num_layers)])
27 |
28 | logger.info('TransformerEncoder init weights')
29 | init_weights(self.modules())
30 |
31 | @property
32 | def with_position_encoder(self):
33 | return hasattr(self, 'pos_encoder') and self.pos_encoder is not None
34 |
35 | def forward(self, src, src_mask=None):
36 | if self.with_position_encoder:
37 | src = self.pos_encoder(src)
38 |
39 | for layer in self.layers:
40 | src = layer(src, src_mask)
41 |
42 | return src
43 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptive_2d_encoder import Adaptive2DPositionEncoder # noqa 401
2 | from .builder import build_position_encoder # noqa 401
3 | from .encoder import PositionEncoder1D # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/adaptive_2d_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .registry import POSITION_ENCODERS
4 | from .utils import generate_encoder
5 |
6 |
7 | @POSITION_ENCODERS.register_module
8 | class Adaptive2DPositionEncoder(nn.Module):
9 |
10 | def __init__(self, in_channels, max_h=200, max_w=200, dropout=0.1):
11 | super(Adaptive2DPositionEncoder, self).__init__()
12 |
13 | h_position_encoder = generate_encoder(in_channels, max_h)
14 | h_position_encoder = h_position_encoder.transpose(0, 1).view(
15 | 1, in_channels, max_h, 1)
16 |
17 | w_position_encoder = generate_encoder(in_channels, max_w)
18 | w_position_encoder = w_position_encoder.transpose(0, 1).view(
19 | 1, in_channels, 1, max_w)
20 |
21 | self.register_buffer('h_position_encoder', h_position_encoder)
22 | self.register_buffer('w_position_encoder', w_position_encoder)
23 |
24 | self.h_scale = self.scale_factor_generate(in_channels)
25 | self.w_scale = self.scale_factor_generate(in_channels)
26 | self.pool = nn.AdaptiveAvgPool2d(1)
27 | self.dropout = nn.Dropout(p=dropout)
28 |
29 | def scale_factor_generate(self, in_channels):
30 | scale_factor = nn.Sequential(
31 | nn.Conv2d(in_channels, in_channels, kernel_size=1),
32 | nn.ReLU(inplace=True),
33 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.Sigmoid())
34 |
35 | return scale_factor
36 |
37 | def forward(self, x):
38 | b, c, h, w = x.size()
39 |
40 | avg_pool = self.pool(x)
41 |
42 | h_pos_encoding = self.h_scale(
43 | avg_pool) * self.h_position_encoder[:, :, :h, :]
44 | w_pos_encoding = self.w_scale(
45 | avg_pool) * self.w_position_encoder[:, :, :, :w]
46 |
47 | out = x + h_pos_encoding + w_pos_encoding
48 |
49 | out = self.dropout(out)
50 |
51 | return out
52 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import POSITION_ENCODERS
3 |
4 |
5 | def build_position_encoder(cfg, default_args=None):
6 | position_encoder = build_from_cfg(cfg, POSITION_ENCODERS, default_args)
7 |
8 | return position_encoder
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .registry import POSITION_ENCODERS
4 | from .utils import generate_encoder
5 |
6 |
7 | @POSITION_ENCODERS.register_module
8 | class PositionEncoder1D(nn.Module):
9 |
10 | def __init__(self, in_channels, max_len=2000, dropout=0.1):
11 | super(PositionEncoder1D, self).__init__()
12 |
13 | position_encoder = generate_encoder(in_channels, max_len)
14 | position_encoder = position_encoder.unsqueeze(0)
15 | self.register_buffer('position_encoder', position_encoder)
16 | self.dropout = nn.Dropout(p=dropout)
17 |
18 | def forward(self, x):
19 | out = x + self.position_encoder[:, :x.size(1), :]
20 | out = self.dropout(out)
21 |
22 | return out
23 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | POSITION_ENCODERS = Registry('position_encoder')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def generate_encoder(in_channels, max_len):
5 | pos = torch.arange(max_len).float().unsqueeze(1)
6 |
7 | i = torch.arange(in_channels).float().unsqueeze(0)
8 | angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / in_channels)
9 |
10 | position_encoder = pos * angle_rates
11 | position_encoder[:, 0::2] = torch.sin(position_encoder[:, 0::2])
12 | position_encoder[:, 1::2] = torch.cos(position_encoder[:, 1::2])
13 |
14 | return position_encoder
15 |
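The table built by generate_encoder corresponds to the standard sinusoidal encoding of the Transformer, with $d$ = in_channels:

$$\mathrm{PE}(pos,\,2i) = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad \mathrm{PE}(pos,\,2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$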
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_decoder_layer, build_encoder_layer # noqa 401
2 | from .decoder import TransformerDecoderLayer1D # noqa 401
3 | from .encoder import TransformerEncoderLayer1D, TransformerEncoderLayer2D # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_attention # noqa 401
2 | from .multihead_attention import MultiHeadAttention # noqa 401
3 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import TRANSFORMER_ATTENTIONS
3 |
4 |
5 | def build_attention(cfg, default_args=None):
6 | attention = build_from_cfg(cfg, TRANSFORMER_ATTENTIONS, default_args)
7 |
8 | return attention
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/multihead_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .registry import TRANSFORMER_ATTENTIONS
5 |
6 |
7 | class ScaledDotProductAttention(nn.Module):
8 |
9 | def __init__(self, temperature, dropout=0.1):
10 | super(ScaledDotProductAttention, self).__init__()
11 |
12 | self.temperature = temperature
13 | self.dropout = nn.Dropout(p=dropout)
14 |
15 | def forward(self, q, k, v, mask=None):
16 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
17 |
18 | if mask is not None:
19 | attn = attn.masked_fill(mask=mask, value=float('-inf'))
20 |
21 | attn = torch.softmax(attn, dim=-1)
22 | attn = self.dropout(attn)
23 |
24 | out = torch.matmul(attn, v)
25 |
26 | return out, attn
27 |
28 |
29 | @TRANSFORMER_ATTENTIONS.register_module
30 | class MultiHeadAttention(nn.Module):
31 |
32 | def __init__(self,
33 | in_channels: int,
34 | k_channels: int,
35 | v_channels: int,
36 | n_head: int = 8,
37 | dropout: float = 0.1):
38 | super(MultiHeadAttention, self).__init__()
39 |
40 | self.in_channels = in_channels
41 | self.k_channels = k_channels
42 | self.v_channels = v_channels
43 | self.n_head = n_head
44 |
45 | self.q_linear = nn.Linear(in_channels, n_head * k_channels)
46 | self.k_linear = nn.Linear(in_channels, n_head * k_channels)
47 | self.v_linear = nn.Linear(in_channels, n_head * v_channels)
48 | self.attention = ScaledDotProductAttention(
49 | temperature=k_channels**0.5, dropout=dropout)
50 | self.out_linear = nn.Linear(n_head * v_channels, in_channels)
51 |
52 | self.dropout = nn.Dropout(p=dropout)
53 |
54 | def forward(self, q, k, v, mask=None):
55 | b, q_len, k_len, v_len = q.size(0), q.size(1), k.size(1), v.size(1)
56 |
57 | q = self.q_linear(q).view(b, q_len, self.n_head,
58 | self.k_channels).transpose(1, 2)
59 | k = self.k_linear(k).view(b, k_len, self.n_head,
60 | self.k_channels).transpose(1, 2)
61 | v = self.v_linear(v).view(b, v_len, self.n_head,
62 | self.v_channels).transpose(1, 2)
63 |
64 | if mask is not None:
65 | mask = mask.unsqueeze(1)
66 |
67 | out, attn = self.attention(q, k, v, mask=mask)
68 |
69 | out = out.transpose(1,
70 | 2).contiguous().view(b, q_len,
71 | self.n_head * self.v_channels)
72 | out = self.out_linear(out)
73 | out = self.dropout(out)
74 |
75 | return out, attn
76 |
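A shape check of the attention module above, assuming the vedastr package in this repo is importable; per head, queries and keys use k_channels and values use v_channels:

import torch

from vedastr.models.bodies.sequences.transformer.unit.attention import MultiHeadAttention

attn = MultiHeadAttention(in_channels=512, k_channels=64, v_channels=64, n_head=8)
q = torch.randn(2, 26, 512)      # (B, T_q, in_channels)
kv = torch.randn(2, 40, 512)     # (B, T_kv, in_channels)
out, weights = attn(q, kv, kv)
print(out.shape, weights.shape)  # (2, 26, 512) and (2, 8, 26, 40)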
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | TRANSFORMER_ATTENTIONS = Registry('transformer_attention')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import TRANSFORMER_DECODER_LAYERS, TRANSFORMER_ENCODER_LAYERS
3 |
4 |
5 | def build_encoder_layer(cfg, default_args=None):
6 | encoder_layer = build_from_cfg(cfg, TRANSFORMER_ENCODER_LAYERS,
7 | default_args)
8 |
9 | return encoder_layer
10 |
11 |
12 | def build_decoder_layer(cfg, default_args=None):
13 | decoder_layer = build_from_cfg(cfg, TRANSFORMER_DECODER_LAYERS,
14 | default_args)
15 |
16 | return decoder_layer
17 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from vedastr.models.utils import build_torch_nn
4 | from .attention import build_attention
5 | from .feedforward import build_feedforward
6 | from .registry import TRANSFORMER_DECODER_LAYERS
7 |
8 |
9 | @TRANSFORMER_DECODER_LAYERS.register_module
10 | class TransformerDecoderLayer1D(nn.Module):
11 |
12 | def __init__(self, self_attention, self_attention_norm, attention,
13 | attention_norm, feedforward, feedforward_norm):
14 | super(TransformerDecoderLayer1D, self).__init__()
15 |
16 | self.self_attention = build_attention(self_attention)
17 | self.self_attention_norm = build_torch_nn(self_attention_norm)
18 |
19 | self.attention = build_attention(attention)
20 | self.attention_norm = build_torch_nn(attention_norm)
21 |
22 | self.feedforward = build_feedforward(feedforward)
23 | self.feedforward_norm = build_torch_nn(feedforward_norm)
24 |
25 | def forward(self, tgt, src, tgt_mask=None, src_mask=None):
26 | attn1, _ = self.self_attention(tgt, tgt, tgt, tgt_mask)
27 | out1 = self.self_attention_norm(tgt + attn1)
28 |
29 | size = src.size()
30 | if len(size) == 4:
31 | b, c, h, w = size
32 | src = src.view(b, c, h * w).transpose(1, 2)
33 | if src_mask is not None:
34 | src_mask = src_mask.view(b, 1, h * w)
35 |
36 | attn2, _ = self.attention(out1, src, src, src_mask)
37 | out2 = self.attention_norm(out1 + attn2)
38 |
39 | ffn_out = self.feedforward(out2)
40 | out3 = self.feedforward_norm(out2 + ffn_out)
41 |
42 | return out3
43 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from vedastr.models.utils import build_torch_nn
4 | from .attention import build_attention
5 | from .feedforward import build_feedforward
6 | from .registry import TRANSFORMER_ENCODER_LAYERS
7 |
8 |
9 | class _TransformerEncoderLayer(nn.Module):
10 |
11 | def __init__(self, attention, attention_norm, feedforward,
12 | feedforward_norm):
13 | super(_TransformerEncoderLayer, self).__init__()
14 | self.attention = build_attention(attention)
15 | self.attention_norm = build_torch_nn(attention_norm)
16 |
17 | self.feedforward = build_feedforward(feedforward)
18 | self.feedforward_norm = build_torch_nn(feedforward_norm)
19 |
20 |
21 | @TRANSFORMER_ENCODER_LAYERS.register_module
22 | class TransformerEncoderLayer1D(_TransformerEncoderLayer):
23 |
24 | def __init__(self, attention, attention_norm, feedforward,
25 | feedforward_norm):
26 | super(TransformerEncoderLayer1D,
27 | self).__init__(attention, attention_norm, feedforward,
28 | feedforward_norm)
29 |
30 | def forward(self, src, src_mask=None):
31 | attn_out, _ = self.attention(src, src, src, src_mask)
32 | out1 = self.attention_norm(src + attn_out)
33 |
34 | ffn_out = self.feedforward(out1)
35 | out2 = self.feedforward_norm(out1 + ffn_out)
36 |
37 | return out2
38 |
39 |
40 | @TRANSFORMER_ENCODER_LAYERS.register_module
41 | class TransformerEncoderLayer2D(_TransformerEncoderLayer):
42 |
43 | def __init__(self, attention, attention_norm, feedforward,
44 | feedforward_norm):
45 | super(TransformerEncoderLayer2D,
46 | self).__init__(attention, attention_norm, feedforward,
47 | feedforward_norm)
48 |
49 | def norm(self, norm_layer, x):
50 | b, c, h, w = x.size()
51 |
52 | if isinstance(norm_layer, nn.LayerNorm):
53 | out = x.view(b, c, h * w).transpose(1, 2)
54 | out = norm_layer(out)
55 | out = out.transpose(1, 2).contiguous().view(b, c, h, w)
56 | else:
57 | out = norm_layer(x)
58 |
59 | return out
60 |
61 | def forward(self, src, src_mask=None):
62 | b, c, h, w = src.size()
63 |
64 | src = src.view(b, c, h * w).transpose(1, 2)
65 | if src_mask is not None:
66 | src_mask = src_mask.view(b, 1, h * w)
67 |
68 | attn_out, _ = self.attention(src, src, src, src_mask)
69 | out1 = src + attn_out
70 | out1 = out1.transpose(1, 2).contiguous().view(b, c, h, w)
71 | out1 = self.norm(self.attention_norm, out1)
72 |
73 | ffn_out = self.feedforward(out1)
74 | out2 = self.norm(self.feedforward_norm, out1 + ffn_out)
75 |
76 | return out2
77 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_feedforward # noqa 401
2 | from .feedforward import Feedforward # noqa 401
3 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import TRANSFORMER_FEEDFORWARDS
3 |
4 |
5 | def build_feedforward(cfg, default_args=None):
6 | feedforward = build_from_cfg(cfg, TRANSFORMER_FEEDFORWARDS, default_args)
7 |
8 | return feedforward
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/feedforward.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from vedastr.models.utils import build_module
4 | from .registry import TRANSFORMER_FEEDFORWARDS
5 |
6 |
7 | @TRANSFORMER_FEEDFORWARDS.register_module
8 | class Feedforward(nn.Module):
9 |
10 |     def __init__(self, layers: list):
11 | super(Feedforward, self).__init__()
12 |
13 | self.layers = [build_module(layer) for layer in layers]
14 | self.layers = nn.Sequential(*self.layers)
15 |
16 | def forward(self, x):
17 | out = self.layers(x)
18 |
19 | return out
20 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | TRANSFORMER_FEEDFORWARDS = Registry('transformer_feedforward')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | TRANSFORMER_ENCODER_LAYERS = Registry('transformer_encoder_layer')
4 | TRANSFORMER_DECODER_LAYERS = Registry('transformer_decoder_layer')
5 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/builder.py:
--------------------------------------------------------------------------------
1 | from ..utils import build_from_cfg
2 | from .registry import MODELS
3 |
4 |
5 | def build_model(cfg, default_args=None):
6 | model = build_from_cfg(cfg, MODELS, default_args)
7 |
8 | return model
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .att_head import AttHead # noqa 401
2 | from .builder import build_head # noqa 401
3 | from .conv_head import ConvHead # noqa 401
4 | from .ctc_head import CTCHead # noqa 401
5 | from .fc_head import FCHead # noqa 401
6 | from .head import Head # noqa 401
7 | from .multi_head import MultiHead # noqa 401
8 | from .transformer_head import TransformerHead # noqa 401
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/att_head.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from vedastr.models.bodies import build_brick, build_sequence_decoder
7 | from vedastr.models.utils import build_torch_nn
8 | from vedastr.models.weight_init import init_weights
9 | from .registry import HEADS
10 |
11 | logger = logging.getLogger()
12 |
13 |
14 | @HEADS.register_module
15 | class AttHead(nn.Module):
16 |
17 | def __init__(self,
18 | cell,
19 | generator,
20 | num_steps,
21 | num_class,
22 | input_attention_block=None,
23 | output_attention_block=None,
24 | text_transform=None,
25 | holistic_input_from=None):
26 | super(AttHead, self).__init__()
27 |
28 | if input_attention_block is not None:
29 | self.input_attention_block = build_brick(input_attention_block)
30 |
31 | self.cell = build_sequence_decoder(cell)
32 | self.generator = build_torch_nn(generator)
33 | self.num_steps = num_steps
34 | self.num_class = num_class
35 |
36 | if output_attention_block is not None:
37 | self.output_attention_block = build_brick(output_attention_block)
38 |
39 | if text_transform is not None:
40 | self.text_transform = build_torch_nn(text_transform)
41 |
42 | if holistic_input_from is not None:
43 | self.holistic_input_from = holistic_input_from
44 |
45 | self.register_buffer('embeddings',
46 | torch.diag(torch.ones(self.num_class)))
47 | logger.info('AttHead init weights')
48 | init_weights(self.modules())
49 |
50 | @property
51 | def with_holistic_input(self):
52 | return hasattr(self,
53 | 'holistic_input_from') and self.holistic_input_from
54 |
55 | @property
56 | def with_input_attention(self):
57 | return hasattr(
58 | self,
59 | 'input_attention_block') and self.input_attention_block is not None
60 |
61 | @property
62 | def with_output_attention(self):
63 | return hasattr(self, 'output_attention_block'
64 | ) and self.output_attention_block is not None
65 |
66 | @property
67 | def with_text_transform(self):
68 | return hasattr(self, 'text_transform') and self.text_transform
69 |
70 | def forward(self, feats, texts):
71 | batch_size = texts.size(0)
72 |
73 | hidden = self.cell.init_hidden(batch_size, device=texts.device)
74 | if self.with_holistic_input:
75 | holistic_input = feats[self.holistic_input_from][:, :, 0, -1]
76 | hidden = self.cell(holistic_input, hidden)
77 |
78 | out = []
79 |
80 | if self.training:
81 | use_gt = True
82 | assert self.num_steps == texts.size(1)
83 | else:
84 | use_gt = False
85 | assert texts.size(1) == 1
86 |
87 | for i in range(self.num_steps):
88 | if i == 0:
89 | indexes = texts[:, i]
90 | else:
91 | if use_gt:
92 | indexes = texts[:, i]
93 | else:
94 | _, indexes = out[-1].max(1)
95 | text_feat = self.embeddings.index_select(0, indexes)
96 |
97 | if self.with_text_transform:
98 | text_feat = self.text_transform(text_feat)
99 |
100 | if self.with_input_attention:
101 | attention_feat = self.input_attention_block(
102 | feats,
103 | self.cell.get_output(hidden).unsqueeze(-1).unsqueeze(-1))
104 | cell_input = torch.cat([attention_feat, text_feat], dim=1)
105 | else:
106 | cell_input = text_feat
107 | hidden = self.cell(cell_input, hidden)
108 | out_feat = self.cell.get_output(hidden)
109 |
110 | if self.with_output_attention:
111 | attention_feat = self.output_attention_block(
112 | feats,
113 | self.cell.get_output(hidden).unsqueeze(-1).unsqueeze(-1))
114 | out_feat = torch.cat(
115 | [self.cell.get_output(hidden), attention_feat], dim=1)
116 |
117 | out.append(self.generator(out_feat))
118 |
119 | out = torch.stack(out, dim=1)
120 |
121 | return out
122 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/builder.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import build_from_cfg
2 | from .registry import HEADS
3 |
4 |
5 | def build_head(cfg, default_args=None):
6 | head = build_from_cfg(cfg, HEADS, default_args)
7 | return head
8 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/conv_head.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.utils import ConvModules
6 | from vedastr.models.weight_init import init_weights
7 | from .registry import HEADS
8 |
9 | logger = logging.getLogger()
10 |
11 |
12 | @HEADS.register_module
13 | class ConvHead(nn.Module):
14 |     """ConvHead
15 |
16 | Args:
17 | """
18 |
19 | def __init__(self,
20 | in_channels,
21 | num_class,
22 | from_layer,
23 | num_convs=0,
24 | inner_channels=None,
25 | kernel_size=1,
26 | padding=0,
27 | **kwargs):
28 | super(ConvHead, self).__init__()
29 |
30 | self.from_layer = from_layer
31 |
32 | self.conv = []
33 | if num_convs > 0:
34 | out_channels = inner_channels
35 | self.conv.append(
36 | ConvModules(
37 | in_channels,
38 | out_channels,
39 | num_convs=num_convs,
40 | kernel_size=kernel_size,
41 | padding=padding,
42 | **kwargs))
43 | else:
44 | out_channels = in_channels
45 | self.conv.append(
46 | nn.Conv2d(
47 | out_channels,
48 | num_class,
49 | kernel_size=kernel_size,
50 | padding=padding))
51 | self.conv = nn.Sequential(*self.conv)
52 |
53 | logger.info('ConvHead init weights')
54 | init_weights(self.modules())
55 |
56 | def forward(self, x_input):
57 | x = x_input[self.from_layer]
58 | assert x.size(2) == 1
59 |
60 | out = self.conv(x).mean(2).permute(0, 2, 1)
61 |
62 | return out
63 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/ctc_head.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.utils import build_torch_nn
6 | from vedastr.models.weight_init import init_weights
7 | from .registry import HEADS
8 |
9 | logger = logging.getLogger()
10 |
11 |
12 | @HEADS.register_module
13 | class CTCHead(nn.Module):
14 | """CTCHead
15 |
16 | """
17 |
18 | def __init__(
19 | self,
20 | in_channels,
21 | num_class,
22 | from_layer,
23 | pool=None,
24 | export=False,
25 | ):
26 | super(CTCHead, self).__init__()
27 |
28 | self.num_class = num_class
29 | self.from_layer = from_layer
30 | fc = nn.Linear(in_channels, num_class)
31 | if pool is not None:
32 | self.pool = build_torch_nn(pool)
33 | self.fc = fc
34 | self.export = export
35 |
36 | logger.info('CTCHead init weights')
37 | init_weights(self.modules())
38 |
39 | @property
40 | def with_pool(self):
41 | return hasattr(self, 'pool') and self.pool is not None
42 |
43 | def forward(self, x_input):
44 | x = x_input[self.from_layer]
45 | if self.export:
46 | x = x.mean(2).permute(0, 2, 1)
47 | elif self.with_pool:
48 | x = self.pool(x).permute(0, 3, 1, 2).squeeze(3)
49 | out = self.fc(x)
50 |
51 | return out
52 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/fc_head.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.utils import FCModules, build_torch_nn
6 | from vedastr.models.weight_init import init_weights
7 | from .registry import HEADS
8 |
9 | logger = logging.getLogger()
10 |
11 |
12 | @HEADS.register_module
13 | class FCHead(nn.Module):
14 | """FCHead
15 |
16 | Args:
17 | """
18 |
19 | def __init__(self,
20 | in_channels,
21 | out_channels,
22 | num_class,
23 | batch_max_length,
24 | from_layer,
25 | inner_channels=None,
26 | bias=True,
27 | activation='relu',
28 | inplace=True,
29 | dropouts=None,
30 | num_fcs=0,
31 | pool=None):
32 | super(FCHead, self).__init__()
33 |
34 | self.num_class = num_class
35 | self.batch_max_length = batch_max_length
36 | self.from_layer = from_layer
37 |
38 | if num_fcs > 0:
39 | inter_fc = FCModules(in_channels, inner_channels, bias, activation,
40 | inplace, dropouts, num_fcs)
41 | fc = nn.Linear(inner_channels, out_channels)
42 | else:
43 | inter_fc = nn.Sequential()
44 | fc = nn.Linear(in_channels, out_channels)
45 |
46 | if pool is not None:
47 | self.pool = build_torch_nn(pool)
48 |
49 | self.inter_fc = inter_fc
50 | self.fc = fc
51 |
52 | logger.info('FCHead init weights')
53 | init_weights(self.modules())
54 |
55 | @property
56 | def with_pool(self):
57 | return hasattr(self, 'pool') and self.pool is not None
58 |
59 | def forward(self, x_input):
60 | x = x_input[self.from_layer]
61 | batch_size = x.size(0)
62 |
63 | if self.with_pool:
64 | x = self.pool(x)
65 |
66 | x = x.contiguous().view(batch_size, -1)
67 |
68 | out = self.inter_fc(x)
69 | out = self.fc(out)
70 |
71 | return out.reshape(-1, self.batch_max_length + 1, self.num_class)
72 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/head.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | from vedastr.models.utils import build_module
6 | from vedastr.models.weight_init import init_weights
7 | from .registry import HEADS
8 |
9 | logger = logging.getLogger()
10 |
11 |
12 | @HEADS.register_module
13 | class Head(nn.Module):
14 | """Head
15 |
16 | Args:
17 | """
18 |
19 | def __init__(
20 | self,
21 | from_layer,
22 | generator,
23 | ):
24 | super(Head, self).__init__()
25 |
26 | self.from_layer = from_layer
27 | self.generator = build_module(generator)
28 |
29 | logger.info('Head init weights')
30 | init_weights(self.modules())
31 |
32 | def forward(self, feats):
33 | x = feats[self.from_layer]
34 | out = self.generator(x)
35 |
36 | return out
37 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/multi_head.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from vedastr.models.utils import FCModules, build_module, build_torch_nn
7 | from vedastr.models.weight_init import init_weights
8 | from .registry import HEADS
9 |
10 | logger = logging.getLogger()
11 |
12 |
13 | @HEADS.register_module
14 | class MultiHead(nn.Module):
15 | """MultiHead
16 |
17 | Args:
18 | """
19 |
20 | def __init__(self,
21 | in_channels,
22 | num_class,
23 | batch_max_length,
24 | from_layer,
25 | inners=None,
26 | skip_connections=None,
27 | inner_channels=None,
28 | bias=True,
29 | activation='relu',
30 | inplace=True,
31 | dropouts=None,
32 | embedding=False,
33 | num_fcs=0,
34 | pool=None):
35 | super(MultiHead, self).__init__()
36 |
37 | self.num_class = num_class
38 | self.batch_max_length = batch_max_length
39 | self.embedding = embedding
40 | self.from_layer = from_layer
41 |
42 | if inners is not None:
43 | self.inners = []
44 | for inner_cfg in inners:
45 | self.inners.append(build_module(inner_cfg))
46 | self.inners = nn.Sequential(*self.inners)
47 |
48 | if skip_connections is not None:
49 | self.skip_layer = []
50 | for skip_cfg in skip_connections:
51 | self.skip_layer.append(build_module(skip_cfg))
52 | self.skip_layer = nn.Sequential(*self.skip_layer)
53 |
54 | self.fcs = nn.ModuleList()
55 | for i in range(batch_max_length + 1):
56 | if num_fcs > 0:
57 | inter_fc = FCModules(in_channels, inner_channels, bias,
58 | activation, inplace, dropouts, num_fcs)
59 | else:
60 | inter_fc = nn.Sequential()
61 | fc = nn.Linear(in_channels, num_class)
62 | self.fcs.append(nn.Sequential(inter_fc, fc))
63 |
64 | if pool is not None:
65 | self.pool = build_torch_nn(pool)
66 |
67 | logger.info('MultiHead init weights')
68 | init_weights(self.modules())
69 |
70 | @property
71 | def with_pool(self):
72 | return hasattr(self, 'pool') and self.pool is not None
73 |
74 | @property
75 | def with_inners(self):
76 | return hasattr(self, 'inners') and self.inners is not None
77 |
78 | @property
79 | def with_skip_layer(self):
80 | return hasattr(self, 'skip_layer') and self.skip_layer is not None
81 |
82 | def forward(self, x_input):
83 | x = x_input[self.from_layer]
84 | batch_size = x.size(0)
85 |
86 | if self.with_pool:
87 | x = self.pool(x)
88 | if self.with_inners:
89 | inner_x = self.inners(x)
90 | if self.with_skip_layer:
91 | short_x = self.skip_layer(x)
92 | x = inner_x + short_x
93 | else:
94 | x = inner_x + x
95 | x = x.contiguous().view(batch_size, -1)
96 | else:
97 | x = x.squeeze()
98 | x = torch.split(x.squeeze(), (1, ) * x.size(2), dim=2)
99 |
100 | outs = []
101 | for idx, layer in enumerate(self.fcs):
102 | out = layer(x) if not isinstance(x, tuple) else layer(
103 | x[idx].squeeze())
104 | out = out.unsqueeze(1)
105 | outs.append(out)
106 | outs = torch.cat(outs, dim=1)
107 |
108 | return outs
109 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils.registry import Registry
2 |
3 | HEADS = Registry('head')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/heads/transformer_head.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from vedastr.models.bodies import build_sequence_decoder
8 | from vedastr.models.utils import build_torch_nn
9 | from vedastr.models.weight_init import init_weights
10 | from .registry import HEADS
11 |
12 | logger = logging.getLogger()
13 |
14 |
15 | @HEADS.register_module
16 | class TransformerHead(nn.Module):
17 |
18 | def __init__(
19 | self,
20 | decoder,
21 | generator,
22 | embedding,
23 | num_steps,
24 | pad_id,
25 | src_from,
26 | src_mask_from=None,
27 | ):
28 | super(TransformerHead, self).__init__()
29 |
30 | self.decoder = build_sequence_decoder(decoder)
31 | self.generator = build_torch_nn(generator)
32 | self.embedding = build_torch_nn(embedding)
33 | self.num_steps = num_steps
34 | self.pad_id = pad_id
35 | self.src_from = src_from
36 | self.src_mask_from = src_mask_from
37 |
38 | logger.info('TransformerHead init weights')
39 | init_weights(self.modules())
40 |
41 | def pad_mask(self, text):
42 | pad_mask = (text == self.pad_id)
43 | pad_mask[:, 0] = False
44 | pad_mask = pad_mask.unsqueeze(1)
45 |
46 | return pad_mask
47 |
48 | def order_mask(self, text):
49 | t = text.size(1)
50 | order_mask = torch.triu(torch.ones(t, t), diagonal=1).bool()
51 | order_mask = order_mask.unsqueeze(0).to(text.device)
52 |
53 | return order_mask
54 |
55 | def text_embedding(self, texts):
56 | tgt = self.embedding(texts)
57 | tgt *= math.sqrt(tgt.size(2))
58 |
59 | return tgt
60 |
61 | def forward(self, feats, texts):
62 | src = feats[self.src_from]
63 | if self.src_mask_from:
64 | src_mask = feats[self.src_mask_from]
65 | else:
66 | src_mask = None
67 |
68 | if self.training:
69 | tgt = self.text_embedding(texts)
70 | tgt_mask = (self.pad_mask(texts) | self.order_mask(texts))
71 |
72 | out = self.decoder(tgt, src, tgt_mask, src_mask)
73 | out = self.generator(out)
74 | else:
75 | out = None
76 | for _ in range(self.num_steps):
77 | tgt = self.text_embedding(texts)
78 | tgt_mask = self.order_mask(texts)
79 | out = self.decoder(tgt, src, tgt_mask, src_mask)
80 | out = self.generator(out)
81 | next_text = torch.argmax(out[:, -1:, :], dim=-1)
82 |
83 | texts = torch.cat([texts, next_text], dim=-1)
84 |
85 | return out
86 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .bodies import build_body
4 | from .heads import build_head
5 | from .registry import MODELS
6 |
7 |
8 | @MODELS.register_module
9 | class GModel(nn.Module):
10 |
11 | def __init__(self, body, head, need_text=True):
12 | super(GModel, self).__init__()
13 |
14 | self.body = build_body(body)
15 | self.head = build_head(head)
16 | self.need_text = need_text
17 |
18 | def forward(self, inputs):
19 | if not isinstance(inputs, (tuple, list)):
20 | inputs = [inputs]
21 | x = self.body(inputs[0])
22 | if self.need_text:
23 | out = self.head(x, inputs[1])
24 | else:
25 | out = self.head(x)
26 |
27 | return out
28 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/registry.py:
--------------------------------------------------------------------------------
1 | from ..utils import Registry
2 |
3 | MODELS = Registry('model')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_module, build_torch_nn # noqa 401
2 | from .cbam import CBAM # noqa 401
3 | from .conv_module import ConvModule, ConvModules # noqa 401
4 | from .fc_module import FCModule, FCModules # noqa 401
5 | from .non_local import NonLocal2d # noqa 401
6 | from .norm import build_norm_layer # noqa 401
7 | from .residual_module import BasicBlock, Bottleneck # noqa 401
8 | from .upsample import Upsample # noqa 401
9 | from .squeeze_excitation_module import SE # noqa 401
10 |
11 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/builder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from vedastr.utils import build_from_cfg
4 | from .registry import UTILS
5 |
6 |
7 | def build_module(cfg, default_args=None):
8 | util = build_from_cfg(cfg, UTILS, default_args)
9 | return util
10 |
11 |
12 | def build_torch_nn(cfg, default_args=None):
13 | module = build_from_cfg(cfg, nn, default_args, 'module')
14 | return module
15 |
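
A rough usage sketch (the channel sizes are arbitrary): build_torch_nn resolves the 'type' key on torch.nn, while build_module resolves it in the UTILS registry that the utility modules below (CBAM, SE, Upsample, FCModule, ...) register themselves into.

    import torch

    from vedastr.models.utils import build_module, build_torch_nn

    # A plain torch.nn layer, resolved by name on torch.nn.
    pool = build_torch_nn(dict(type='AdaptiveAvgPool2d', output_size=(1, 1)))

    # A registered utility module, resolved through the UTILS registry.
    cbam = build_module(dict(type='CBAM', gate_channels=64))

    x = torch.randn(2, 64, 8, 25)
    print(pool(x).shape)   # torch.Size([2, 64, 1, 1])
    print(cbam(x).shape)   # torch.Size([2, 64, 8, 25])
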
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/cbam.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .conv_module import ConvModule
8 | from .registry import UTILS
9 |
10 |
11 | class Flatten(nn.Module):
12 |
13 | def forward(self, x):
14 | return x.view(x.size(0), -1)
15 |
16 |
17 | class ChannelGate(nn.Module):
18 |
19 | def __init__(self,
20 | gate_channels,
21 | reduction_ratio=16,
22 | pool_types=['avg', 'max']):
23 | super(ChannelGate, self).__init__()
24 | self.gate_channels = gate_channels
25 | self.mlp = nn.Sequential(
26 | Flatten(),
27 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
28 | nn.ReLU(),
29 | nn.Linear(gate_channels // reduction_ratio, gate_channels))
30 | self.pool_types = pool_types
31 |
32 | def forward(self, x):
33 | channel_att_sum = None
34 | for pool_type in self.pool_types:
35 | if pool_type == 'avg':
36 | avg_pool = F.avg_pool2d(
37 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
38 | channel_att_raw = self.mlp(avg_pool)
39 | elif pool_type == 'max':
40 | max_pool = F.max_pool2d(
41 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 | channel_att_raw = self.mlp(max_pool)
43 | elif pool_type == 'lp':
44 | lp_pool = F.lp_pool2d(
45 | x,
46 | 2, (x.size(2), x.size(3)),
47 | stride=(x.size(2), x.size(3)))
48 | channel_att_raw = self.mlp(lp_pool)
49 | elif pool_type == 'lse':
50 | # LSE pool only
51 | lse_pool = logsumexp_2d(x)
52 | channel_att_raw = self.mlp(lse_pool)
53 |
54 | if channel_att_sum is None:
55 | channel_att_sum = channel_att_raw
56 | else:
57 | channel_att_sum = channel_att_sum + channel_att_raw
58 |
59 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(
60 | x)
61 | return x * scale
62 |
63 |
64 | def logsumexp_2d(tensor):
65 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
66 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
67 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
68 | return outputs
69 |
70 |
71 | class ChannelPool(nn.Module):
72 |
73 | def forward(self, x):
74 | return torch.cat(
75 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)),
76 | dim=1)
77 |
78 |
79 | class SpatialGate(nn.Module):
80 |
81 | def __init__(self, norm_cfg=None):
82 | super(SpatialGate, self).__init__()
83 | kernel_size = 7
84 | self.compress = ChannelPool()
85 | self.spatial = ConvModule(
86 | 2,
87 | 1,
88 | kernel_size,
89 | stride=1,
90 | padding=(kernel_size - 1) // 2,
91 | activation=None,
92 | norm_cfg=norm_cfg)
93 |
94 | def forward(self, x):
95 | x_compress = self.compress(x)
96 | x_out = self.spatial(x_compress)
97 | scale = torch.sigmoid(x_out)  # broadcasting
98 | return x * scale
99 |
100 |
101 | @UTILS.register_module
102 | class CBAM(nn.Module):
103 |
104 | def __init__(
105 | self,
106 | gate_channels,
107 | reduction_ratio=16,
108 | pool_types=['avg', 'max'],
109 | no_spatial=False,
110 | norm_cfg=None,
111 | ):
112 | # TODO, default CBAM BN args
113 | # out_planes, eps=1e-5, momentum=0.01, affine=True
114 | super(CBAM, self).__init__()
115 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio,
116 | pool_types)
117 | self.no_spatial = no_spatial
118 | if not no_spatial:
119 | self.SpatialGate = SpatialGate(norm_cfg)
120 |
121 | def forward(self, x):
122 | x_out = self.ChannelGate(x)
123 | if not self.no_spatial:
124 | x_out = self.SpatialGate(x_out)
125 | return x_out
126 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/fc_module.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .registry import UTILS
4 |
5 |
6 | @UTILS.register_module
7 | class FCModule(nn.Module):
8 | """FCModule
9 |
10 | Args:
11 | """
12 |
13 | def __init__(self,
14 | in_channels,
15 | out_channels,
16 | bias=True,
17 | activation='relu',
18 | inplace=True,
19 | dropout=None,
20 | order=('fc', 'act')):
21 | super(FCModule, self).__init__()
22 | self.order = order
23 | self.activation = activation
24 | self.inplace = inplace
25 |
26 | self.with_activation = activation is not None
27 | self.with_dropout = dropout is not None
28 |
29 | self.fc = nn.Linear(in_channels, out_channels, bias)
30 |
31 | # build activation layer
32 | if self.with_activation:
33 | # TODO: introduce `activation` and support more activation layers
34 | if self.activation not in ['relu', 'tanh', 'sigmoid']:
35 | raise ValueError('{} is currently not supported.'.format(
36 | self.activation))
37 | if self.activation == 'relu':
38 | self.activate = nn.ReLU(inplace=inplace)
39 | elif self.activation == 'tanh':
40 | self.activate = nn.Tanh()
41 | elif self.activation == 'sigmoid':
42 | self.activate = nn.Sigmoid()
43 |
44 | if self.with_dropout:
45 | self.dropout = nn.Dropout(p=dropout)
46 |
47 | def forward(self, x):
48 | if self.order == ('fc', 'act'):
49 | x = self.fc(x)
50 |
51 | if self.with_activation:
52 | x = self.activate(x)
53 | elif self.order == ('act', 'fc'):
54 | if self.with_activation:
55 | x = self.activate(x)
56 | x = self.fc(x)
57 |
58 | if self.with_dropout:
59 | x = self.dropout(x)
60 |
61 | return x
62 |
63 |
64 | @UTILS.register_module
65 | class FCModules(nn.Module):
66 | """FCModules
67 |
68 | Args:
69 | """
70 |
71 | def __init__(self,
72 | in_channels,
73 | out_channels,
74 | bias=True,
75 | activation='relu',
76 | inplace=True,
77 | dropouts=None,
78 | num_fcs=1):
79 | super().__init__()
80 |
81 | if dropouts is not None:
82 | assert num_fcs == len(dropouts)
83 | dropout = dropouts[0]
84 | else:
85 | dropout = None
86 |
87 | layers = [
88 | FCModule(in_channels, out_channels, bias, activation, inplace,
89 | dropout)
90 | ]
91 | for ii in range(1, num_fcs):
92 | if dropouts is not None:
93 | dropout = dropouts[ii]
94 | else:
95 | dropout = None
96 | layers.append(
97 | FCModule(out_channels, out_channels, bias, activation, inplace,
98 | dropout))
99 |
100 | self.block = nn.Sequential(*layers)
101 |
102 | def forward(self, x):
103 | feat = self.block(x)
104 |
105 | return feat
106 |
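
A small sketch of FCModules with assumed sizes: two stacked FC blocks, 512 -> 256 -> 256, a ReLU after each, and a dropout rate given per layer.

    import torch

    from vedastr.models.utils import FCModules

    fc_block = FCModules(
        in_channels=512,
        out_channels=256,
        dropouts=[0.0, 0.1],   # one rate per FC layer (len must equal num_fcs)
        num_fcs=2,
    )

    x = torch.randn(4, 512)
    print(fc_block(x).shape)   # torch.Size([4, 256])
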
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/norm.py:
--------------------------------------------------------------------------------
1 | # modify from mmcv and mmdetection
2 |
3 | import torch.nn as nn
4 |
5 | norm_cfg = {
6 | # format: layer_type: (abbreviation, module)
7 | 'BN': ('bn', nn.BatchNorm2d),
8 | 'SyncBN': ('bn', nn.SyncBatchNorm),
9 | 'GN': ('gn', nn.GroupNorm),
10 | # and potentially 'SN'
11 | }
12 |
13 |
14 | def build_norm_layer(cfg, num_features, postfix='', layer_only=False):
15 | """ Build normalization layer
16 |
17 | Args:
18 | cfg (dict): cfg should contain:
19 | type (str): identify norm layer type.
20 | layer args: args needed to instantiate a norm layer.
21 | requires_grad (bool): [optional] whether the layer parameters require gradient updates.
22 | num_features (int): number of channels of the input.
23 | postfix (int, str): appended to the norm abbreviation to
24 | create the layer name.
25 |
26 | Returns:
27 | name (str): abbreviation + postfix
28 | layer (nn.Module): created norm layer
29 | """
30 | assert isinstance(cfg, dict) and 'type' in cfg
31 | cfg_ = cfg.copy()
32 |
33 | layer_type = cfg_.pop('type')
34 | if layer_type not in norm_cfg:
35 | raise KeyError('Unrecognized norm type {}'.format(layer_type))
36 | else:
37 | abbr, norm_layer = norm_cfg[layer_type]
38 | if norm_layer is None:
39 | raise NotImplementedError
40 |
41 | assert isinstance(postfix, (int, str))
42 | name = abbr + str(postfix)
43 |
44 | requires_grad = cfg_.pop('requires_grad', True)
45 | if layer_type != 'GN':
46 | layer = norm_layer(num_features, **cfg_)
47 | if layer_type == 'SyncBN':
48 | layer._specify_ddp_gpu_num(1)
49 | else:
50 | assert 'num_groups' in cfg_
51 | layer = norm_layer(num_channels=num_features, **cfg_)
52 |
53 | for param in layer.parameters():
54 | param.requires_grad = requires_grad
55 |
56 | if layer_only:
57 | return layer
58 |
59 | return name, layer
60 |
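
A short sketch of build_norm_layer with assumed channel counts: by default it returns a (name, layer) pair, and layer_only=True returns just the module.

    import torch.nn as nn

    from vedastr.models.utils import build_norm_layer

    name, bn = build_norm_layer(dict(type='BN'), 64, postfix=1)
    print(name, isinstance(bn, nn.BatchNorm2d))   # bn1 True

    # GroupNorm needs num_groups in the config.
    gn = build_norm_layer(dict(type='GN', num_groups=8), 64, layer_only=True)
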
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | UTILS = Registry('utils')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/squeeze_excitation_module.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .fc_module import FCModule
4 | from .registry import UTILS
5 |
6 |
7 | @UTILS.register_module
8 | class SE(nn.Module):
9 |
10 | def __init__(self, channel, reduction):
11 | # TODO: input channel should have the same name as in other modules
12 |
13 | super(SE, self).__init__()
14 | assert channel % reduction == 0, \
15 | "Input_channel can't be evenly divided by reduction."
16 |
17 | self.pool = nn.AdaptiveAvgPool2d(1)
18 | self.layer = nn.Sequential(
19 | FCModule(channel, channel // reduction, bias=False),
20 | FCModule(
21 | channel // reduction,
22 | channel,
23 | bias=False,
24 | activation='sigmoid'),
25 | )
26 |
27 | def forward(self, x):
28 | b, c, _, _ = x.size()
29 | y = self.pool(x).view(b, c)
30 | y = self.layer(y).view(b, c, 1, 1)
31 |
32 | return x * y.expand_as(x)
33 |
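
A one-line usage sketch with assumed sizes: SE rescales each channel by a learned gate, so the output shape matches the input.

    import torch

    from vedastr.models.utils import SE

    se = SE(channel=64, reduction=16)     # channel must be divisible by reduction
    y = se(torch.randn(2, 64, 8, 25))
    print(y.shape)                        # torch.Size([2, 64, 8, 25])
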
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/utils/upsample.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from .registry import UTILS
5 |
6 |
7 | @UTILS.register_module
8 | class Upsample(nn.Module):
9 | __constants__ = [
10 | 'size', 'scale_factor', 'scale_bias', 'mode', 'align_corners', 'name'
11 | ]
12 |
13 | def __init__(self,
14 | size=None,
15 | scale_factor=None,
16 | scale_bias=0,
17 | mode='nearest',
18 | align_corners=None):
19 | super(Upsample, self).__init__()
20 | self.size = size
21 | self.scale_factor = scale_factor
22 | self.scale_bias = scale_bias
23 | self.mode = mode
24 | self.align_corners = align_corners
25 |
26 | assert (self.size is None) ^ (self.scale_factor is None)
27 |
28 | def forward(self, x):
29 | if self.size:
30 | size = self.size
31 | else:
32 | n, c, h, w = x.size()
33 | new_h = int(h * self.scale_factor + self.scale_bias)
34 | new_w = int(w * self.scale_factor + self.scale_bias)
35 |
36 | size = (new_h, new_w)
37 |
38 | return F.interpolate(
39 | x, size=size, mode=self.mode, align_corners=self.align_corners)
40 |
41 | def extra_repr(self):
42 | if self.size is not None:
43 | info = 'size=' + str(self.size)
44 | else:
45 | info = 'scale_factor=' + str(self.scale_factor)
46 | info += ', scale_bias=' + str(self.scale_bias)
47 | info += ', mode=' + self.mode
48 | return info
49 |
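
A quick sketch with assumed sizes: exactly one of `size` and `scale_factor` may be given, and `scale_bias` is added to the computed output size.

    import torch

    from vedastr.models.utils import Upsample

    up = Upsample(scale_factor=2, mode='nearest')
    y = up(torch.randn(1, 32, 8, 25))
    print(y.shape)   # torch.Size([1, 32, 16, 50])
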
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/models/weight_init.py:
--------------------------------------------------------------------------------
1 | # modify from mmcv and mmdetection
2 |
3 | import torch.nn as nn
4 |
5 |
6 | def constant_init(module, val, bias=0):
7 | nn.init.constant_(module.weight, val)
8 | if hasattr(module, 'bias') and module.bias is not None:
9 | nn.init.constant_(module.bias, bias)
10 |
11 |
12 | def xavier_init(module, gain=1, bias=0, distribution='normal'):
13 | assert distribution in ['uniform', 'normal']
14 | if distribution == 'uniform':
15 | nn.init.xavier_uniform_(module.weight, gain=gain)
16 | else:
17 | nn.init.xavier_normal_(module.weight, gain=gain)
18 | if hasattr(module, 'bias') and module.bias is not None:
19 | nn.init.constant_(module.bias, bias)
20 |
21 |
22 | def normal_init(module, mean=0, std=1, bias=0):
23 | nn.init.normal_(module.weight, mean, std)
24 | if hasattr(module, 'bias') and module.bias is not None:
25 | nn.init.constant_(module.bias, bias)
26 |
27 |
28 | def uniform_init(module, a=0, b=1, bias=0):
29 | nn.init.uniform_(module.weight, a, b)
30 | if hasattr(module, 'bias') and module.bias is not None:
31 | nn.init.constant_(module.bias, bias)
32 |
33 |
34 | def kaiming_init(module,
35 | a=0,
36 | is_rnn=False,
37 | mode='fan_in',
38 | nonlinearity='leaky_relu',
39 | bias=0,
40 | distribution='normal'):
41 | assert distribution in ['uniform', 'normal']
42 | if distribution == 'uniform':
43 | if is_rnn:
44 | for name, param in module.named_parameters():
45 | if 'bias' in name:
46 | nn.init.constant_(param, bias)
47 | elif 'weight' in name:
48 | nn.init.kaiming_uniform_(
49 | param, a=a, mode=mode, nonlinearity=nonlinearity)
50 | else:
51 | nn.init.kaiming_uniform_(
52 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
53 |
54 | else:
55 | if is_rnn:
56 | for name, param in module.named_parameters():
57 | if 'bias' in name:
58 | nn.init.constant_(param, bias)
59 | elif 'weight' in name:
60 | nn.init.kaiming_normal_(
61 | param, a=a, mode=mode, nonlinearity=nonlinearity)
62 | else:
63 | nn.init.kaiming_normal_(
64 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
65 |
66 | if not is_rnn and hasattr(module, 'bias') and module.bias is not None:
67 | nn.init.constant_(module.bias, bias)
68 |
69 |
70 | def caffe2_xavier_init(module, bias=0):
71 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
72 | # Acknowledgment to FAIR's internal code
73 | kaiming_init(
74 | module,
75 | a=1,
76 | mode='fan_in',
77 | nonlinearity='leaky_relu',
78 | distribution='uniform')
79 |
80 |
81 | def init_weights(modules):
82 | for m in modules:
83 | if isinstance(m, nn.Conv2d):
84 | kaiming_init(m)
85 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
86 | constant_init(m, 1)
87 | elif isinstance(m, nn.Linear):
88 | xavier_init(m)
89 | elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
90 | kaiming_init(m, is_rnn=True)
91 |
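
A minimal sketch of init_weights on a throwaway container (the layer shapes are arbitrary and the Sequential is never run forward, it only shows which initializer each layer type receives): Conv2d gets Kaiming init, BatchNorm2d/GroupNorm get constant init, Linear gets Xavier init, and LSTM layers get the RNN-aware Kaiming init.

    import torch.nn as nn

    from vedastr.models.weight_init import init_weights

    net = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),   # kaiming_init
        nn.BatchNorm2d(16),               # constant_init: weight=1, bias=0
        nn.Linear(16, 10),                # xavier_init
    )
    init_weights(net.modules())
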
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_optimizer # noqa 401
2 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/optimizers/builder.py:
--------------------------------------------------------------------------------
1 | import torch.optim as torch_optim
2 |
3 | from vedastr.utils import build_from_cfg
4 |
5 |
6 | def build_optimizer(cfg, default_args=None):
7 | optim = build_from_cfg(cfg, torch_optim, default_args, 'module')
8 |
9 | return optim
10 |
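
A minimal sketch, assuming a toy model: the 'type' key names an optimizer class in torch.optim and the remaining keys are its constructor arguments; the model parameters can be injected through default_args.

    import torch.nn as nn

    from vedastr.optimizers import build_optimizer

    model = nn.Linear(10, 2)
    optimizer = build_optimizer(
        dict(type='Adam', lr=1e-3, weight_decay=1e-4),
        default_args=dict(params=model.parameters()),
    )
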
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/runners/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference_runner import InferenceRunner # noqa 401
2 | from .test_runner import TestRunner # noqa 401
3 | from .train_runner import TrainRunner # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/runners/base.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | from torch.backends import cudnn
6 |
7 | from ..dataloaders import build_dataloader
8 | from ..dataloaders.samplers import build_sampler
9 | from ..datasets import build_datasets
10 | from ..logger import build_logger
11 | from ..metrics import build_metric
12 | from ..transforms import build_transform
13 | from ..utils import get_dist_info, init_dist_pytorch
14 |
15 |
16 | class Common(object):
17 |
18 | def __init__(self, cfg):
19 | super(Common, self).__init__()
20 |
21 | # build logger
22 | logger_cfg = cfg.get('logger')
23 | if logger_cfg is None:
24 | logger_cfg = dict(
25 | handlers=(dict(type='StreamHandler', level='INFO'),
26 | ),
27 | )
28 | self.workdir = cfg.get('workdir')
29 | self.distribute = cfg.get('distribute', False)
30 |
31 | # set gpu devices
32 | self.use_gpu = self._set_device()
33 |
34 | # set distribute setting
35 | if self.distribute and self.use_gpu:
36 | init_dist_pytorch(**cfg.dist_params)
37 |
38 | self.rank, self.world_size = get_dist_info()
39 | self.logger = self._build_logger(logger_cfg)
40 |
41 | # set cudnn configuration
42 | self._set_cudnn(
43 | cfg.get('cudnn_deterministic', False),
44 | cfg.get('cudnn_benchmark', False))
45 |
46 | # set seed
47 | self._set_seed(cfg.get('seed', None))
48 | self.seed = cfg.get('seed', None)
49 |
50 | # build metric
51 | if 'metric' in cfg:
52 | self.metric = self._build_metric(cfg['metric'])
53 | self.backup_metric = self._build_metric(cfg['metric'])
54 | else:
55 | raise KeyError('Please set metric in config file.')
56 |
57 | # set need_text
58 | self.need_text = False
59 |
60 | def _build_logger(self, cfg):
61 | return build_logger(cfg, dict(workdir=self.workdir))
62 |
63 | def _set_device(self):
64 | self.gpu_num = torch.cuda.device_count()
65 | if torch.cuda.is_available():
66 | use_gpu = True
67 | else:
68 | use_gpu = False
69 |
70 | return use_gpu
71 |
72 | def _set_seed(self, seed):
73 | if seed is not None:
74 | self.logger.info('Set seed {}'.format(seed))
75 | random.seed(seed)
76 | np.random.seed(seed)
77 | torch.manual_seed(seed)
78 | torch.cuda.manual_seed_all(seed)
79 |
80 | def _set_cudnn(self, deterministic, benchmark):
81 | self.logger.info('Set cudnn deterministic {}'.format(deterministic))
82 | cudnn.deterministic = deterministic
83 |
84 | self.logger.info('Set cudnn benchmark {}'.format(benchmark))
85 | cudnn.benchmark = benchmark
86 |
87 | def _build_metric(self, cfg):
88 | return build_metric(cfg)
89 |
90 | def _build_transform(self, cfg):
91 | return build_transform(cfg)
92 |
93 | def _build_dataloader(self, cfg):
94 | transform = build_transform(cfg['transform'])
95 | dataset = build_datasets(cfg['dataset'], dict(transform=transform))
96 |
97 | # TODO, distributed sampler or not
98 | if not cfg.get('sampler'):
99 | sampler = None
100 | else:
101 | if isinstance(dataset, list):
102 | sampler = [
103 | build_sampler(self.distribute, cfg['sampler'],
104 | dict(dataset=d)) for d in dataset
105 | ]
106 | else:
107 | sampler = build_sampler(self.distribute,
108 | cfg['sampler'],
109 | dict(dataset=dataset))
110 | dataloader = build_dataloader(
111 | self.distribute,
112 | self.gpu_num,
113 | cfg['dataloader'],
114 | dict(dataset=dataset, sampler=sampler),
115 | seed=self.seed,
116 | )
117 | return dataloader
118 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/runners/inference_runner.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from .base import Common
8 | from ..converter import build_converter
9 | from ..models import build_model
10 | from ..utils import load_checkpoint
11 |
12 |
13 | class InferenceRunner(Common):
14 |
15 | def __init__(self, inference_cfg, common_cfg=None):
16 | inference_cfg = inference_cfg.copy()
17 | common_cfg = {} if common_cfg is None else common_cfg.copy()
18 | super(InferenceRunner, self).__init__(common_cfg)
19 |
20 | # build test transform
21 | self.transform = self._build_transform(inference_cfg['transform'])
22 | # build converter
23 | self.converter = self._build_converter(inference_cfg['converter'])
24 | # build model
25 | self.model = self._build_model(inference_cfg['model'])
26 | self.logger.info(self.model)
27 | self.postprocess_cfg = inference_cfg.get('postprocess', None)
28 | self.model.eval()
29 |
30 | def _build_model(self, cfg):
31 | self.logger.info('Build model')
32 |
33 | model = build_model(cfg)
34 | params_num = []
35 | for p in filter(lambda p: p.requires_grad, model.parameters()):
36 | params_num.append(np.prod(p.size()))
37 | self.logger.info('Trainable params num : %s' % (sum(params_num)))
38 | self.need_text = model.need_text
39 |
40 | if self.use_gpu:
41 | if self.distribute:
42 | model = torch.nn.parallel.DistributedDataParallel(
43 | model.cuda(),
44 | device_ids=[torch.cuda.current_device()],
45 | broadcast_buffers=True,
46 | )
47 | self.logger.info('Using distributed training')
48 | else:
49 | if torch.cuda.device_count() > 1:
50 | model = torch.nn.DataParallel(model)
51 | model.cuda()
52 | return model
53 |
54 | def _build_converter(self, cfg):
55 | return build_converter(cfg)
56 |
57 | def load_checkpoint(self, filename, map_location='default', strict=True):
58 | self.logger.info('Load checkpoint from {}'.format(filename))
59 |
60 | if map_location == 'default':
61 | if self.use_gpu:
62 | device_id = torch.cuda.current_device()
63 | map_location = lambda storage, loc: storage.cuda(device_id)
64 | else:
65 | map_location = 'cpu'
66 |
67 | return load_checkpoint(self.model, filename, map_location, strict)
68 |
69 | def postprocess(self, preds, cfg=None):
70 | if cfg is not None:
71 | sensitive = cfg.get('sensitive', True)
72 | character = cfg.get('character', '')
73 | else:
74 | sensitive = True
75 | character = ''
76 |
77 | probs = F.softmax(preds, dim=2)
78 | max_probs, indexes = probs.max(dim=2)
79 | preds_str = []
80 | preds_prob = []
81 | for i, pstr in enumerate(self.converter.decode(indexes)):
82 | str_len = len(pstr)
83 | if str_len == 0:
84 | prob = 0
85 | else:
86 | prob = max_probs[i, :str_len].cumprod(dim=0)[-1]
87 | preds_prob.append(prob)
88 | if not sensitive:
89 | pstr = pstr.lower()
90 |
91 | if character:
92 | pstr = re.sub('[^{}]'.format(character), '', pstr)
93 |
94 | preds_str.append(pstr)
95 | return preds_str, preds_prob
96 |
97 | def __call__(self, image):
98 | with torch.no_grad():
99 | dummy_text = ''
100 | aug = self.transform(image=image, label=dummy_text)
101 | image, text = aug['image'], aug['label']
102 | image = image.unsqueeze(0)
103 | label_input, label_length, label_target = self.converter.test_encode([text]) # noqa 501
104 | if self.use_gpu:
105 | image = image.cuda()
106 | label_input = label_input.cuda()
107 |
108 | if self.need_text:
109 | pred = self.model((image, label_input))
110 | else:
111 | pred = self.model((image,))
112 |
113 | pred, prob = self.postprocess(pred, self.postprocess_cfg)
114 |
115 | return pred, prob
116 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/runners/test_runner.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | from .inference_runner import InferenceRunner
6 |
7 |
8 | class TestRunner(InferenceRunner):
9 |
10 | def __init__(self, test_cfg, inference_cfg, common_cfg=None):
11 | super(TestRunner, self).__init__(inference_cfg, common_cfg)
12 |
13 | self.test_dataloader = self._build_dataloader(test_cfg['data'])
14 | if not isinstance(self.test_dataloader, dict):
15 | self.test_dataloader = dict(all=self.test_dataloader)
16 | self.test_exclude_num = dict()
17 | for k, v in self.test_dataloader.items():
18 | extra_data = len(v.dataset) % self.world_size
19 | self.test_exclude_num[
20 | k] = self.world_size - extra_data if extra_data != 0 else 0
21 | self.postprocess_cfg = test_cfg.get('postprocess_cfg', None)
22 |
23 | def test_batch(self, img, label, save_path=None, exclude_num=0):
24 | self.model.eval()
25 | with torch.no_grad():
26 | label_input, label_length, label_target = self.converter.test_encode(label) # noqa 501
27 | if self.use_gpu:
28 | img = img.cuda()
29 | label_input = label_input.cuda()
30 |
31 | if self.need_text:
32 | pred = self.model((img, label_input))
33 | else:
34 | pred = self.model((img,))
35 |
36 | pred, prob = self.postprocess(pred, self.postprocess_cfg)
37 |
38 | if save_path is not None:
39 | for idx, (p, l) in enumerate(zip(pred, label)):
40 | if p == l:
41 | print(p, '\t', l)
42 | cimg = img[idx][0, :, :].cpu().numpy()
43 | cimg = (cimg * 0.5) + 0.5
44 | cv2.imwrite(save_path + f'/%s_{p}_{l}.png' % idx,
45 | (cimg * 255).astype(np.uint8))
46 | self.metric.measure(pred, prob, label, exclude_num)
47 | self.backup_metric.measure(pred, prob, label, exclude_num)
48 |
49 | def __call__(self):
50 | self.logger.info('Start testing')
51 | self.logger.info('test info: %s' % self.postprocess_cfg)
52 | self.metric.reset()
53 | accs = []
54 | for name, dataloader in self.test_dataloader.items():
55 | test_exclude_num = self.test_exclude_num[name]
56 | save_path = None
57 | self.backup_metric.reset()
58 | for tidx, (img, label) in enumerate(dataloader):
59 | exclude_num = test_exclude_num if (tidx +
60 | 1) == len(dataloader) else 0
61 |
62 | self.test_batch(img, label, save_path, exclude_num)
63 | accs.append(self.backup_metric.avg['acc']['true'])
64 | self.logger.info(
65 | 'Test, current dataset root %s, acc %.4f, edit distance %.4f' %
66 | (name, self.backup_metric.avg['acc']['true'],
67 | self.backup_metric.avg['edit']))
68 | self.logger.info(
69 | 'Test, average acc %.4f, edit distance %s' %
70 | (self.metric.avg['acc']['true'], self.metric.avg['edit']))
71 | acc_str = ' '.join(list(map(lambda x: str(x)[:6], accs)))
72 | self.logger.info('For copy and record, %s' % acc_str)
73 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_transform # noqa 401
2 | from .transforms import (FactorScale, LongestMaxSize, PadIfNeeded, RandomScale, # noqa 401
3 | Sensitive, ToTensor) # noqa 401
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/transforms/builder.py:
--------------------------------------------------------------------------------
1 | import albumentations as albu
2 |
3 | from vedastr.utils import build_from_cfg
4 | from .registry import TRANSFORMS
5 |
6 |
7 | def build_transform(cfgs):
8 | tfs = []
9 | for cfg in cfgs:
10 | if TRANSFORMS.get(cfg['type']):
11 | tf = build_from_cfg(cfg, TRANSFORMS)
12 | else:
13 | tf = build_from_cfg(cfg, albu, src='module')
14 | tfs.append(tf)
15 | aug = albu.Compose(tfs)
16 |
17 | return aug
18 |
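
A rough sketch (Resize and Normalize are stock albumentations transforms chosen only for illustration): types found in the TRANSFORMS registry are built from it, anything else falls back to albumentations, and the result is a single albu.Compose pipeline.

    import numpy as np

    from vedastr.transforms import build_transform

    transform = build_transform([
        dict(type='Resize', height=32, width=100),
        dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    out = transform(image=np.zeros((48, 160, 3), dtype=np.uint8))
    print(out['image'].shape)   # (32, 100, 3)
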
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/transforms/registry.py:
--------------------------------------------------------------------------------
1 | from vedastr.utils import Registry
2 |
3 | TRANSFORMS = Registry('transforms')
4 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .checkpoint import (load_checkpoint, load_state_dict, load_url_dist, # noqa 401
2 | save_checkpoint) # noqa 401
3 | from .common import (WorkerInit, build_from_cfg, get_root_logger, # noqa 401
4 | set_random_seed) # noqa 401
5 | from .config import Config, ConfigDict # noqa 401
6 | from .dist_utils import (gather_tensor, get_dist_info, init_dist_pytorch, # noqa 401
7 | master_only, reduce_tensor) # noqa 401
8 | from .registry import Registry # noqa 401
9 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/utils/common.py:
--------------------------------------------------------------------------------
1 | # modify from mmcv and mmdetection
2 |
3 | import inspect
4 | import logging
5 | import random
6 | import sys
7 |
8 | import numpy as np
9 | import torch
10 |
11 |
12 | def build_from_cfg(cfg, parent, default_args=None, src='registry'):
13 | if src == 'registry':
14 | return obj_from_dict_registry(cfg, parent, default_args)
15 | elif src == 'module':
16 | return obj_from_dict_module(cfg, parent, default_args)
17 | else:
18 | raise ValueError('Method %s is not supported' % src)
19 |
20 |
21 | def obj_from_dict_module(info, parent=None, default_args=None):
22 | """Initialize an object from dict.
23 | The dict must contain the key "type", which indicates the object type; it
24 | can be either a string or a type, such as "list" or ``list``. Remaining
25 | fields are treated as the arguments for constructing the object.
26 | Args:
27 | info (dict): Object types and arguments.
28 | parent (:class:`module`): Module which may contain the expected object
29 | classes.
30 | default_args (dict, optional): Default arguments for initializing the
31 | object.
32 | Returns:
33 | any type: Object built from the dict.
34 | """
35 | assert isinstance(info, dict) and 'type' in info
36 | assert isinstance(default_args, dict) or default_args is None
37 | args = info.copy()
38 | obj_type = args.pop('type')
39 | if isinstance(obj_type, str):
40 | if parent is not None:
41 | obj_type = getattr(parent, obj_type)
42 | else:
43 | obj_type = sys.modules[obj_type]
44 | elif not isinstance(obj_type, type):
45 | raise TypeError('type must be a str or valid type, but got {}'.format(
46 | type(obj_type)))
47 | if default_args is not None:
48 | for name, value in default_args.items():
49 | args.setdefault(name, value)
50 | return obj_type(**args)
51 |
52 |
53 | def obj_from_dict_registry(cfg, registry, default_args=None):
54 | """Build a module from config dict.
55 | Args:
56 | cfg (dict): Config dict. It should at least contain the key "type".
57 | registry (:obj:`Registry`): The registry to search the type from.
58 | default_args (dict, optional): Default initialization arguments.
59 | Returns:
60 | obj: The constructed object.
61 | """
62 | assert isinstance(cfg, dict) and 'type' in cfg
63 | assert isinstance(default_args, dict) or default_args is None
64 | args = cfg.copy()
65 | obj_type = args.pop('type')
66 | if isinstance(obj_type, str):
67 | obj_cls = registry.get(obj_type)
68 | if obj_cls is None:
69 | raise KeyError('{} is not in the {} registry'.format(
70 | obj_type, registry.name))
71 | elif inspect.isclass(obj_type):
72 | obj_cls = obj_type
73 | else:
74 | raise TypeError('type must be a str or valid type, but got {}'.format(
75 | type(obj_type)))
76 | if default_args is not None:
77 | for name, value in default_args.items():
78 | args.setdefault(name, value)
79 | return obj_cls(**args)
80 |
81 |
82 | def set_random_seed(seed):
83 | random.seed(seed)
84 | np.random.seed(seed)
85 | torch.manual_seed(seed)
86 | torch.cuda.manual_seed_all(seed)
87 |
88 |
89 | def get_root_logger(log_level=logging.INFO):
90 | logger = logging.getLogger()
91 | np.set_printoptions(precision=4)
92 | if not logger.hasHandlers():
93 | logging.basicConfig(
94 | format='%(asctime)s - %(levelname)s - %(message)s',
95 | level=log_level)
96 | return logger
97 |
98 |
99 | class WorkerInit:
100 |
101 | def __init__(self, num_workers, rank, seed, epoch):
102 | self.num_workers = num_workers
103 | self.rank = rank
104 | self.seed = seed if seed is not None else 0
105 | self.epoch = epoch
106 |
107 | def __call__(self, worker_id):
108 | worker_seed = self.num_workers * self.rank + \
109 | worker_id + self.seed + self.epoch
110 | np.random.seed(worker_seed)
111 | random.seed(worker_seed)
112 |
113 | def set_epoch(self, n):
114 | assert isinstance(n, int)
115 | self.epoch = n
116 |
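
A self-contained sketch of the registry path of build_from_cfg (TinyBackbone and the BACKBONES registry are throwaway names used only for this example): 'type' selects the registered class, and the remaining keys plus default_args become constructor kwargs.

    from vedastr.utils import Registry, build_from_cfg

    BACKBONES = Registry('backbone')


    @BACKBONES.register_module
    class TinyBackbone:

        def __init__(self, depth, pretrained=False):
            self.depth = depth
            self.pretrained = pretrained


    net = build_from_cfg(dict(type='TinyBackbone', depth=18), BACKBONES,
                         default_args=dict(pretrained=True))
    print(net.depth, net.pretrained)   # 18 True
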
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # adapted from mmcv and mmdetection
2 |
3 | import functools
4 | import os
5 |
6 | import torch
7 | import torch.distributed as dist
8 |
9 |
10 | def init_dist_pytorch(backend='nccl', **kwargs):
11 | rank = int(os.environ['RANK'])
12 | num_gpus = torch.cuda.device_count()
13 | torch.cuda.set_device(rank % num_gpus)
14 | dist.init_process_group(backend=backend, **kwargs)
15 |
16 |
17 | def get_dist_info():
18 | if dist.is_available():
19 | initialized = dist.is_initialized()
20 | else:
21 | initialized = False
22 |
23 | if initialized:
24 | rank = dist.get_rank()
25 | world_size = dist.get_world_size()
26 | else:
27 | rank = 0
28 | world_size = 1
29 |
30 | return rank, world_size
31 |
32 |
33 | def reduce_tensor(data, average=True):
34 | rank, world_size = get_dist_info()
35 | if world_size < 2:
36 | return data
37 |
38 | with torch.no_grad():
39 | if not isinstance(data, torch.Tensor):
40 | data = torch.tensor(data).cuda()
41 | dist.reduce(data, dst=0)
42 | if rank == 0 and average:
43 | data /= world_size
44 | return data
45 |
46 |
47 | def gather_tensor(data):
48 | _, world_size = get_dist_info()
49 | if world_size < 2:
50 | return data
51 |
52 | with torch.no_grad():
53 | if not isinstance(data, torch.Tensor):
54 | data = torch.tensor(data).cuda()
55 | if not data.size():
56 | data = data.unsqueeze(0)
57 |
58 | gather_list = [torch.ones_like(data) for _ in range(world_size)]
59 | dist.all_gather(gather_list, data)
60 | gather_data = torch.cat(gather_list, 0)
61 |
62 | return gather_data
63 |
64 |
65 | def synchronize():
66 | if not dist.is_available():
67 | return
68 | if not dist.is_initialized():
69 | return
70 | world_size = dist.get_world_size()
71 | if world_size == 1:
72 | return
73 | dist.barrier()
74 |
75 |
76 | def master_only(func):
77 | """Don't use master_only to decorate function which have random state.
78 |
79 | """
80 |
81 | @functools.wraps(func)
82 | def wrapper(*args, **kwargs):
83 | rank, _ = get_dist_info()
84 | if rank == 0:
85 | return func(*args, **kwargs)
86 |
87 | return wrapper
88 |
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/utils/path.py:
--------------------------------------------------------------------------------
1 | # modify from mmcv and mmdetection
2 |
3 | import os
4 | import os.path as osp
5 | import sys
6 | from pathlib import Path
7 |
8 | import six
9 |
10 | from .misc import is_str
11 |
12 | if sys.version_info <= (3, 3):
13 | FileNotFoundError = IOError
14 | else:
15 | FileNotFoundError = FileNotFoundError
16 |
17 |
18 | def is_filepath(x):
19 | if is_str(x) or isinstance(x, Path):
20 | return True
21 | else:
22 | return False
23 |
24 |
25 | def fopen(filepath, *args, **kwargs):
26 | if is_str(filepath):
27 | return open(filepath, *args, **kwargs)
28 | elif isinstance(filepath, Path):
29 | return filepath.open(*args, **kwargs)
30 |
31 |
32 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
33 | if not osp.isfile(filename):
34 | raise FileNotFoundError(msg_tmpl.format(filename))
35 |
36 |
37 | def mkdir_or_exist(dir_name, mode=0o777):
38 | if dir_name == '':
39 | return
40 | dir_name = osp.expanduser(dir_name)
41 | if six.PY3:
42 | os.makedirs(dir_name, mode=mode, exist_ok=True)
43 | else:
44 | if not osp.isdir(dir_name):
45 | os.makedirs(dir_name, mode=mode)
46 |
47 |
48 | def symlink(src, dst, overwrite=True, **kwargs):
49 | if os.path.lexists(dst) and overwrite:
50 | os.remove(dst)
51 | os.symlink(src, dst, **kwargs)
52 |
53 |
54 | def _scandir_py35(dir_path, suffix=None):
55 | for entry in os.scandir(dir_path):
56 | if not entry.is_file():
57 | continue
58 | filename = entry.name
59 | if suffix is None:
60 | yield filename
61 | elif filename.endswith(suffix):
62 | yield filename
63 |
64 |
65 | def _scandir_py(dir_path, suffix=None):
66 | for filename in os.listdir(dir_path):
67 | if not osp.isfile(osp.join(dir_path, filename)):
68 | continue
69 | if suffix is None:
70 | yield filename
71 | elif filename.endswith(suffix):
72 | yield filename
73 |
74 |
75 | def scandir(dir_path, suffix=None):
76 | if suffix is not None and not isinstance(suffix, (str, tuple)):
77 | raise TypeError('"suffix" must be a string or tuple of strings')
78 | if sys.version_info >= (3, 5):
79 | return _scandir_py35(dir_path, suffix)
80 | else:
81 | return _scandir_py(dir_path, suffix)
82 |
83 |
84 | def find_vcs_root(path, markers=('.git', )):
85 | """Finds the root directory (including itself) of specified markers.
86 |
87 | Args:
88 | path (str): Path of directory or file.
89 | markers (list[str], optional): List of file or directory names.
90 |
91 | Returns:
92 | The directory that contains one of the markers, or None if not found.
93 | """
94 | if osp.isfile(path):
95 | path = osp.dirname(path)
96 |
97 | prev, cur = None, osp.abspath(osp.expanduser(path))
98 | while cur != prev:
99 | if any(osp.exists(osp.join(cur, marker)) for marker in markers):
100 | return cur
101 | prev, cur = cur, osp.split(cur)[0]
102 | return None
103 |
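
A few one-line sketches (the paths are placeholders): these helpers are thin wrappers around os / os.path.

    from vedastr.utils.path import find_vcs_root, mkdir_or_exist, scandir

    mkdir_or_exist('./workdir')                  # no error if it already exists
    py_files = list(scandir('.', suffix='.py'))  # non-recursive, files only
    repo_root = find_vcs_root('.')               # nearest ancestor with .git, or None
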
--------------------------------------------------------------------------------
/vedastr_cstr/vedastr/utils/registry.py:
--------------------------------------------------------------------------------
1 | # modify from mmcv and mmdetection
2 | import inspect
3 |
4 |
5 | class Registry(object):
6 |
7 | def __init__(self, name):
8 | self._name = name
9 | self._module_dict = dict()
10 |
11 | def __repr__(self):
12 | format_str = self.__class__.__name__ + '(name={}, items={})'.format(
13 | self._name, list(self._module_dict.keys()))
14 | return format_str
15 |
16 | @property
17 | def name(self):
18 | return self._name
19 |
20 | @property
21 | def module_dict(self):
22 | return self._module_dict
23 |
24 | def get(self, key):
25 | return self._module_dict.get(key, None)
26 |
27 | def _register_module(self, module_class):
28 | """Register a module.
29 | Args:
30 | module_class (type): Class to be registered.
31 | """
32 | if not inspect.isclass(module_class):
33 | raise TypeError('module must be a class, but got {}'.format(
34 | type(module_class)))
35 | module_name = module_class.__name__
36 | if module_name in self._module_dict:
37 | raise KeyError('{} is already registered in {}'.format(
38 | module_name, self.name))
39 | self._module_dict[module_name] = module_class
40 |
41 | def register_module(self, cls):
42 | self._register_module(cls)
43 | return cls
44 |
--------------------------------------------------------------------------------