├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature-request.md │ └── write-document.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── Attention.png ├── SATRN.png ├── SATRN_feedforward.png ├── competition-overview.png ├── demo.png └── logo.png ├── checkpoint.py ├── configs ├── Attention.yaml └── SATRN.yaml ├── custom_augment.py ├── data_tools ├── download.sh ├── extract_tokens.py ├── make_dataset.py ├── parse_upstage.py └── train_test_split.py ├── dataset.py ├── flags.py ├── inference.py ├── metrics.py ├── networks ├── Attention.py ├── README.md ├── SATRN.py ├── loss.py └── spatial_transformation.py ├── pre_processing.py ├── requirements.txt ├── scheduler.py ├── train.py ├── transform.py ├── utils.py └── vedastr_cstr ├── LICENSE ├── README.md ├── configs └── cstr.py ├── requirements.txt ├── tools ├── deploy │ ├── benchmark.py │ ├── export.py │ └── utils │ │ ├── __init__.py │ │ └── common.py ├── dist_test.sh ├── dist_train.sh ├── inference.py ├── test.py └── train.py └── vedastr ├── __init__.py ├── converter ├── __init__.py ├── attn_converter.py ├── base_convert.py ├── builder.py ├── ctc_converter.py ├── custom_converter.py ├── fc_converter.py └── registry.py ├── criteria ├── __init__.py ├── builder.py ├── cross_entropy_loss.py ├── ctc_loss.py ├── label_smooth_cross_entropy_loss.py └── registry.py ├── dataloaders ├── __init__.py ├── builder.py ├── registry.py └── samplers │ ├── __init__.py │ ├── balance_sampler.py │ ├── builder.py │ ├── default_sampler.py │ ├── dist_balance_sampler.py │ ├── dist_default_sampler.py │ └── registry.py ├── datasets ├── __init__.py ├── base.py ├── builder.py ├── concat_dataset.py ├── fold_dataset.py ├── lmdb_dataset.py ├── paste_dataset.py ├── registry.py └── txt_datasets.py ├── logger ├── __init__.py └── builder.py ├── lr_schedulers ├── __init__.py ├── base.py ├── builder.py ├── constant_lr.py ├── cosine_lr.py ├── exponential_lr.py ├── poly_lr.py ├── registry.py └── step_lr.py ├── metrics ├── __init__.py ├── accuracy.py ├── builder.py └── registry.py ├── models ├── __init__.py ├── bodies │ ├── __init__.py │ ├── body.py │ ├── builder.py │ ├── component.py │ ├── feature_extractors │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── decoders │ │ │ ├── __init__.py │ │ │ ├── bricks │ │ │ │ ├── __init__.py │ │ │ │ ├── bricks.py │ │ │ │ ├── builder.py │ │ │ │ ├── pva.py │ │ │ │ └── registry.py │ │ │ ├── builder.py │ │ │ ├── gfpn.py │ │ │ └── registry.py │ │ └── encoders │ │ │ ├── __init__.py │ │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── general_backbone.py │ │ │ ├── registry.py │ │ │ └── resnet.py │ │ │ ├── builder.py │ │ │ └── enhance_modules │ │ │ ├── __init__.py │ │ │ ├── aspp.py │ │ │ ├── builder.py │ │ │ ├── ppm.py │ │ │ └── registry.py │ ├── rectificators │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── registry.py │ │ ├── spin.py │ │ ├── sspin.py │ │ └── tps_stn.py │ ├── registry.py │ └── sequences │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── registry.py │ │ ├── rnn │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── encoder.py │ │ └── transformer │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── position_encoder │ │ ├── __init__.py │ │ ├── adaptive_2d_encoder.py │ │ ├── builder.py │ │ ├── encoder.py │ │ ├── registry.py │ │ └── utils.py │ │ └── unit │ │ ├── __init__.py │ │ ├── attention │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── multihead_attention.py │ │ └── registry.py │ │ ├── builder.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── feedforward │ │ ├── __init__.py │ │ ├── 
builder.py │ │ ├── feedforward.py │ │ └── registry.py │ │ └── registry.py ├── builder.py ├── heads │ ├── __init__.py │ ├── att_head.py │ ├── builder.py │ ├── conv_head.py │ ├── ctc_head.py │ ├── fc_head.py │ ├── head.py │ ├── multi_head.py │ ├── registry.py │ └── transformer_head.py ├── model.py ├── registry.py ├── utils │ ├── __init__.py │ ├── builder.py │ ├── cbam.py │ ├── conv_module.py │ ├── fc_module.py │ ├── non_local.py │ ├── norm.py │ ├── registry.py │ ├── residual_module.py │ ├── squeeze_excitation_module.py │ └── upsample.py └── weight_init.py ├── optimizers ├── __init__.py └── builder.py ├── runners ├── __init__.py ├── base.py ├── inference_runner.py ├── test_runner.py └── train_runner.py ├── transforms ├── __init__.py ├── builder.py ├── registry.py └── transforms.py └── utils ├── __init__.py ├── checkpoint.py ├── common.py ├── config.py ├── dist_utils.py ├── misc.py ├── path.py └── registry.py /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: 버그 보고 4 | title: "[BUG] " 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### 🐞 Describe the bug 11 | 12 | ### 📷 Screenshots 13 | 14 | ### 📄 Additional context 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: 새로운 기능 제안 4 | title: "[FEAT] " 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### 😥 Explain the Problem 11 | 12 | ### ✨ Describe A New Feature 13 | 14 | ### 📄 Additional Context 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/write-document.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Write Document 3 | about: 문서화 작업 4 | title: "[DOC] " 5 | labels: documentation 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### ✅ Check the Type 11 | 12 | - [ ] Comments to explain Code 13 | - [ ] Documentation to explain Project 14 | 15 | ### 📝 What to Document 16 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## ✨ Types of Changes 2 | 3 | - [ ] Bugfix 4 | - [ ] New Feature 5 | - [ ] Documentation 6 | 7 | ## ✅ Checklist 8 | 9 | - [ ] Enough documentation in Pull Request or README.md 10 | - [ ] Add comments to explain code 11 | - [ ] Observe python PEP8 rule 12 | - [ ] Check if code works well 13 | 14 | ## 📝 Proposed Changes 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/linux,python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,python 4 | 5 | ### Linux ### 6 | *~ 7 | 8 | # temporary files which can be created if a process still has a handle open of a deleted file 9 | .fuse_hidden* 10 | 11 | # KDE directory preferences 12 | .directory 13 | 14 | # Linux trash folder which might appear on any partition or disk 15 | .Trash-* 16 | 17 | # .nfs files are created when an open file is removed but is still being accessed 18 | .nfs* 19 | 20 | ### Python ### 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | 
*$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | pytestdebug.log 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | doc/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # poetry 115 | #poetry.lock 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | # .env 129 | .env/ 130 | .venv/ 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | pythonenv* 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # operating system-related files 160 | # file properties cache/storage on macOS 161 | *.DS_Store 162 | # thumbnail cache on Windows 163 | Thumbs.db 164 | 165 | # profiling data 166 | .prof 167 | 168 | 169 | # End of https://www.toptal.com/developers/gitignore/api/linux,python 170 | 171 | #log 172 | log/ 173 | wandb/ 174 | .vscode/ 175 | .idea/ 176 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Team DKT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/Attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/Attention.png -------------------------------------------------------------------------------- /assets/SATRN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/SATRN.png -------------------------------------------------------------------------------- /assets/SATRN_feedforward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/SATRN_feedforward.png -------------------------------------------------------------------------------- /assets/competition-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/competition-overview.png -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/demo.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/assets/logo.png -------------------------------------------------------------------------------- /checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tensorboardX import SummaryWriter 4 | 5 | use_cuda = torch.cuda.is_available() 6 | 7 | default_checkpoint = { 8 | "epoch": 0, 9 | "train_losses": [], 10 | "train_symbol_accuracy": [], 11 | "train_sentence_accuracy": [], 12 | "train_wer": [], 13 | "train_score": [], 14 | "validation_losses": [], 15 | "validation_symbol_accuracy": [], 16 | "validation_sentence_accuracy": [], 17 | "validation_wer": [], 18 | "validation_score": [], 19 | "lr": [], 20 | "grad_norm": [], 21 | "model": {}, 22 | "configs":{}, 23 | "token_to_id":{}, 24 | "id_to_token":{}, 25 | } 26 | 27 | 28 | def save_checkpoint(checkpoint, dir="./checkpoints", prefix=""): 29 | """ Saving check point 30 | 31 | Args: 32 | checkpoint(dict) : Checkpoint to save 33 | dir(str) : Path to save the checkpoint 34 | prefix(str) : Path of location of dir 35 | """ 36 | # Padded to 4 digits because of lexical sorting of numbers. 37 | # e.g. 
0009.pth 38 | filename = "{num:0>4}.pth".format(num=checkpoint["epoch"]) 39 | if not os.path.exists(os.path.join(prefix, dir)): 40 | os.makedirs(os.path.join(prefix, dir)) 41 | torch.save(checkpoint, os.path.join(prefix, dir, filename)) 42 | 43 | 44 | def load_checkpoint(path, cuda=use_cuda): 45 | """ Load check point 46 | 47 | Args: 48 | path(str) : Path checkpoint located 49 | cuda : Whether use cuda or not [Default: use_cuda] 50 | Returns 51 | Loaded checkpoints 52 | """ 53 | if cuda: 54 | return torch.load(path) 55 | else: 56 | # Load GPU model on CPU 57 | return torch.load(path, map_location=lambda storage, loc: storage) 58 | 59 | 60 | def init_tensorboard(name="", base_dir="./tensorboard"): 61 | """Init tensorboard 62 | Args: 63 | name(str) : name of tensorboard 64 | base_dir(str): path of tesnorboard 65 | """ 66 | return SummaryWriter(os.path.join(name, base_dir)) 67 | 68 | 69 | def write_tensorboard( 70 | writer, 71 | epoch, 72 | grad_norm, 73 | train_loss, 74 | train_symbol_accuracy, 75 | train_sentence_accuracy, 76 | train_wer, 77 | train_score, 78 | validation_loss, 79 | validation_symbol_accuracy, 80 | validation_sentence_accuracy, 81 | validation_wer, 82 | validation_score, 83 | model, 84 | ): 85 | writer.add_scalar("train_loss", train_loss, epoch) 86 | writer.add_scalar("train_symbol_accuracy", train_symbol_accuracy, epoch) 87 | writer.add_scalar("train_sentence_accuracy",train_sentence_accuracy,epoch) 88 | writer.add_scalar("train_wer", train_wer, epoch) 89 | writer.add_scalar("train_score", train_score, epoch) 90 | writer.add_scalar("validation_loss", validation_loss, epoch) 91 | writer.add_scalar("validation_symbol_accuracy", validation_symbol_accuracy, epoch) 92 | writer.add_scalar("validation_sentence_accuracy",validation_sentence_accuracy,epoch) 93 | writer.add_scalar("validation_wer",validation_wer,epoch) 94 | writer.add_scalar("validation_score", validation_score, epoch) 95 | writer.add_scalar("grad_norm", grad_norm, epoch) 96 | 97 | for name, param in model.encoder.named_parameters(): 98 | writer.add_histogram( 99 | "encoder/{}".format(name), param.detach().cpu().numpy(), epoch 100 | ) 101 | if param.grad is not None: 102 | writer.add_histogram( 103 | "encoder/{}/grad".format(name), param.grad.detach().cpu().numpy(), epoch 104 | ) 105 | 106 | for name, param in model.decoder.named_parameters(): 107 | writer.add_histogram( 108 | "decoder/{}".format(name), param.detach().cpu().numpy(), epoch 109 | ) 110 | if param.grad is not None: 111 | writer.add_histogram( 112 | "decoder/{}/grad".format(name), param.grad.detach().cpu().numpy(), epoch 113 | ) 114 | -------------------------------------------------------------------------------- /configs/Attention.yaml: -------------------------------------------------------------------------------- 1 | network: "Attention" 2 | input_size: 3 | height: 128 4 | width: 128 5 | SATRN: 6 | encoder: 7 | hidden_dim: 300 8 | filter_dim: 600 9 | layer_num: 6 10 | head_num: 8 11 | decoder: 12 | src_dim: 300 13 | hidden_dim: 128 14 | filter_dim: 512 15 | layer_num: 3 16 | head_num: 8 17 | Attention: 18 | src_dim: 512 19 | hidden_dim: 128 20 | embedding_dim: 128 21 | layer_num: 1 22 | cell_type: "LSTM" 23 | checkpoint: "" 24 | prefix: "./log/attention_50" 25 | 26 | data: 27 | train: 28 | - "/opt/ml/input/data/train_dataset/gt.txt" 29 | test: 30 | - "" 31 | token_paths: 32 | - "/opt/ml/input/data/train_dataset/tokens.txt" # 241 tokens 33 | dataset_proportions: # proportion of data to take from train (not test) 34 | - 1.0 35 | random_split: True # 
if True, random split from train files 36 | test_proportions: 0.2 # only if random_split is True 37 | crop: True 38 | rgb: 1 # 3 for color, 1 for greyscale 39 | 40 | batch_size: 96 41 | num_workers: 8 42 | num_epochs: 50 43 | print_epochs: 1 44 | dropout_rate: 0.1 45 | teacher_forcing_ratio: 0.5 46 | max_grad_norm: 2.0 47 | seed: 1234 48 | optimizer: 49 | optimizer: 'Adam' # Adam, Adadelta 50 | lr: 5e-4 # 1e-4 51 | weight_decay: 1e-4 52 | is_cycle: True 53 | 54 | patience: -1 # -1 for off 55 | save_best_only: False 56 | 57 | wandb: 58 | wandb: True 59 | run_name: "" -------------------------------------------------------------------------------- /configs/SATRN.yaml: -------------------------------------------------------------------------------- 1 | network: SATRN 2 | input_size: 3 | height: 48 4 | width: 192 5 | SATRN: 6 | encoder: 7 | hidden_dim: 300 8 | filter_dim: 1200 9 | layer_num: 6 10 | head_num: 8 11 | 12 | shallower_cnn: True # shallow CNN 13 | adaptive_gate: True # A2DPE 14 | conv_ff: True # locality-aware feedforward 15 | separable_ff: True # only if conv_ff is True 16 | decoder: 17 | src_dim: 300 18 | hidden_dim: 300 19 | filter_dim: 1200 20 | layer_num: 3 21 | head_num: 8 22 | 23 | checkpoint: "" 24 | prefix: "./log/satrn" 25 | 26 | data: 27 | train: 28 | - "/opt/ml/input/data/train_dataset/gt.txt" 29 | # - /opt/ml/input/data/train_dataset/custom_train.txt # for experiments 30 | test: 31 | - 32 | # - /opt/ml/input/data/train_dataset/custom_test.txt # for experiments 33 | token_paths: 34 | - "/opt/ml/input/data/train_dataset/tokens.txt" # 241 tokens 35 | dataset_proportions: # proportion of data to take from train (not test) 36 | - 1.0 37 | random_split: True # if True, random split from train files 38 | test_proportions: 0.2 # only if random_split is True 39 | crop: True 40 | rgb: 1 # 3 for color, 1 for greyscale 41 | 42 | batch_size: 16 43 | num_workers: 8 44 | num_epochs: 200 45 | print_epochs: 1 46 | dropout_rate: 0.1 47 | teacher_forcing_ratio: 0.5 48 | teacher_forcing_damp: 5e-3 # 0 to turn off 49 | max_grad_norm: 2.0 50 | seed: 1234 51 | optimizer: 52 | optimizer: AdamP # Adam, Adadelta 53 | lr: 5e-4 # 1e-4 54 | weight_decay: 1e-4 55 | selective_weight_decay: True 56 | is_cycle: True 57 | label_smoothing: 0.2 # 0 to off 58 | 59 | patience: 30 # -1 for off 60 | save_best_only: True 61 | 62 | fp16: True 63 | 64 | wandb: 65 | wandb: True 66 | run_name: ^____________^ -------------------------------------------------------------------------------- /data_tools/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | wget -P /opt/ml/input/data/ https://prod-aistages-public.s3.ap-northeast-2.amazonaws.com/app/Competitions/000043/data/train_dataset.zip 3 | cd /opt/ml/input/data 4 | unzip train_dataset.zip 5 | cd ~ 6 | -------------------------------------------------------------------------------- /data_tools/extract_tokens.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import sys 4 | 5 | 6 | def parse_symbols(truth): 7 | """Returns the unique tokens of the groundtrurh. 8 | 9 | Args: 10 | truth(string) : groundtruth 11 | 12 | Returns: 13 | unique_symbols(set): unique_symbols 14 | """ 15 | unique_symbols = set(truth.split()) 16 | return unique_symbols 17 | 18 | 19 | def create_tokens(groundtruth, output="tokens.txt"): 20 | """Save a unique tokens file for the ground trurh. 
21 | 22 | Args: 23 | groundtruth (text file) : groundtruth file 24 | output (str, optional): output filename. Defaults to "tokens.txt". 25 | """ 26 | with open(groundtruth, "r") as fd: 27 | data = fd.read() 28 | 29 | unique_symbols = set() 30 | data = data.split("\n") 31 | data = [x.split("\t") for x in data] 32 | for _, truth in data: 33 | truth_symbols = parse_symbols(truth) 34 | unique_symbols = unique_symbols.union(truth_symbols) 35 | 36 | symbols = list(unique_symbols) 37 | symbols.sort() 38 | with open(output, "w") as output_fd: 39 | writer = csv.writer(output_fd, delimiter="\n") 40 | writer.writerow(symbols) 41 | 42 | 43 | if __name__ == "__main__": 44 | """ 45 | extract_tokens path/to/groundtruth.tsv [-o OUTPUT] 46 | """ 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument( 49 | "-o", 50 | "--output", 51 | dest="output", 52 | default="tokens.txt", 53 | help="Output path of the tokens text file", 54 | ) 55 | parser.add_argument("groundtruth", nargs=1, help="Ground truth TXT file") 56 | args = parser.parse_args() 57 | create_tokens(args.groundtruth[0], args.output) 58 | -------------------------------------------------------------------------------- /data_tools/train_test_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import random 5 | 6 | test_percent = 0.2 7 | output_dir = "gt-split" 8 | 9 | 10 | # Split the ground truth into train, test sets 11 | def split_gt(groundtruth, test_percent=0.2, data_num=None): 12 | """Split the ground truth into train, test sets 13 | 14 | Args: 15 | groundtruth (text file) : ground truth file 16 | test_percent (float) : represent the proportion of the dataset to include in the test split. Defaults to 0.2. 17 | data_num (int) : represents the absolute number of test samples. Defaults to None. 
18 | 19 | Returns: 20 | train dataset 21 | test dataset 22 | """ 23 | with open(groundtruth, "r") as fd: 24 | data = fd.read() 25 | data = data.split('\n') 26 | data = [x.split('\t') for x in data] 27 | random.shuffle(data) 28 | if data_num: 29 | assert sum(data_num) < len(data) 30 | return data[:data_num[0]], data[data_num[0]:data_num[0] + data_num[1]] 31 | test_len = round(len(data) * test_percent) 32 | return data[test_len:], data[:test_len] # train, test 33 | 34 | 35 | def write_tsv(data, path): 36 | with open(path, "w") as fd: 37 | writer = csv.writer(fd, delimiter="\t") 38 | writer.writerows(data) 39 | 40 | 41 | def parse_args(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument( 44 | "-p", 45 | "--test-percent", 46 | dest="test_percent", 47 | default=test_percent, 48 | type=float, 49 | help="Percent of data to use for test [Default: {}]".format(test_percent) 50 | ) 51 | parser.add_argument( 52 | "-n", 53 | "--data_num", 54 | nargs=2, 55 | type=int, 56 | help="Number of train data and test data", 57 | ) 58 | parser.add_argument( 59 | "-i", 60 | "--input", 61 | dest="input", 62 | required=True, 63 | type=str, 64 | help="Path to input ground truth file", 65 | ) 66 | parser.add_argument( 67 | "-o", 68 | "--output-dir", 69 | dest="output_dir", 70 | default=output_dir, 71 | type=str, 72 | help="Directory to save the split ground truth files", 73 | ) 74 | return parser.parse_args() 75 | 76 | 77 | if __name__ == "__main__": 78 | options = parse_args() 79 | train_gt, test_gt = split_gt(options.input, options.test_percent, options.data_num) 80 | if not os.path.exists(options.output_dir): 81 | os.makedirs(options.output_dir) 82 | write_tsv(train_gt, os.path.join(options.output_dir, "train.txt")) 83 | write_tsv(test_gt, os.path.join(options.output_dir, "test.txt")) 84 | -------------------------------------------------------------------------------- /flags.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code from clovaai/SATRN 3 | """ 4 | import os 5 | import yaml 6 | import collections 7 | 8 | 9 | def dict_to_namedtuple(d): 10 | """ Convert dictionary to named tuple. 11 | """ 12 | FLAGSTuple = collections.namedtuple('FLAGS', sorted(d.keys())) 13 | 14 | for k, v in d.items(): 15 | 16 | if k == 'prefix': 17 | v = os.path.join('./', v) 18 | 19 | if type(v) is dict: 20 | d[k] = dict_to_namedtuple(v) 21 | 22 | elif type(v) is str: 23 | try: 24 | d[k] = eval(v) 25 | except: 26 | d[k] = v 27 | 28 | nt = FLAGSTuple(**d) 29 | 30 | return nt 31 | 32 | 33 | class Flags: 34 | """ Flags object. 
35 | """ 36 | 37 | def __init__(self, config_file): 38 | try: 39 | with open(config_file, 'r') as f: 40 | d = yaml.safe_load(f) 41 | except: 42 | d = config_file 43 | 44 | 45 | self.flags = dict_to_namedtuple(d) 46 | 47 | def get(self): 48 | return self.flags -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from train import id_to_string 4 | from checkpoint import load_checkpoint 5 | from torchvision import transforms 6 | from dataset import LoadEvalDataset, collate_eval_batch 7 | from flags import Flags 8 | from utils import get_network 9 | import csv 10 | from torch.utils.data import DataLoader 11 | import argparse 12 | import random 13 | from tqdm import tqdm 14 | 15 | 16 | def main(parser): 17 | """Inference code 18 | """ 19 | is_cuda = torch.cuda.is_available() 20 | # load pretrained model checkpoint 21 | checkpoint = load_checkpoint(parser.checkpoint, cuda=is_cuda) 22 | options = Flags(checkpoint["configs"]).get() 23 | torch.manual_seed(options.seed) 24 | random.seed(options.seed) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | hardware = "cuda" if is_cuda else "cpu" 29 | device = torch.device(hardware) 30 | print("--------------------------------") 31 | print("Running {} on device {}\n".format(options.network, device)) 32 | 33 | model_checkpoint = checkpoint["model"] 34 | if model_checkpoint: 35 | print( 36 | "[+] Checkpoint\n", 37 | "Resuming from epoch : {}\n".format(checkpoint["epoch"]), 38 | ) 39 | print(options.input_size.height) 40 | 41 | # transform to be applied on a sample. 42 | transformed = transforms.Compose( 43 | [ 44 | transforms.Resize((options.input_size.height, options.input_size.width)), 45 | transforms.ToTensor(), 46 | ] 47 | ) 48 | 49 | dummy_gt = "\sin " * parser.max_sequence # set maximum inference sequence 50 | # make dataset from test folder 51 | root = os.path.join(os.path.dirname(parser.file_path), "images") 52 | with open(parser.file_path, "r") as fd: 53 | reader = csv.reader(fd, delimiter="\t") 54 | data = list(reader) 55 | test_data = [[os.path.join(root, x[0]), x[0], dummy_gt] for x in data] 56 | test_dataset = LoadEvalDataset( 57 | test_data, checkpoint["token_to_id"], checkpoint["id_to_token"], crop=False, transform=transformed, 58 | rgb=options.data.rgb 59 | ) 60 | test_data_loader = DataLoader( 61 | test_dataset, 62 | batch_size=parser.batch_size, 63 | shuffle=False, 64 | num_workers=options.num_workers, 65 | collate_fn=collate_eval_batch, 66 | ) 67 | 68 | print( 69 | "[+] Data\n", 70 | "The number of test samples : {}\n".format(len(test_dataset)), 71 | ) 72 | 73 | model = get_network( 74 | options.network, 75 | options, 76 | model_checkpoint, 77 | device, 78 | test_dataset, 79 | ) 80 | model.eval() 81 | results = [] 82 | for d in tqdm(test_data_loader): 83 | input = d["image"].to(device) 84 | expected = d["truth"]["encoded"].to(device) 85 | 86 | output = model(input, expected, False, 0.0) 87 | decoded_values = output.transpose(1, 2) 88 | _, sequence = torch.topk(decoded_values, 1, dim=1) 89 | sequence = sequence.squeeze(1) 90 | sequence_str = id_to_string(sequence, test_data_loader, do_eval=1) 91 | for path, predicted in zip(d["file_path"], sequence_str): 92 | results.append((path, predicted)) 93 | # save inference results as csv file 94 | os.makedirs(parser.output_dir, exist_ok=True) 95 | with open(os.path.join(parser.output_dir, 
"output.csv"), "w") as w: 96 | for path, predicted in results: 97 | w.write(path + "\t" + predicted + "\n") 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument( 103 | "--checkpoint", 104 | dest="checkpoint", 105 | default="./log/satrn/checkpoints/0015.pth", 106 | type=str, 107 | help="Path of checkpoint file", 108 | ) 109 | parser.add_argument( 110 | "--max_sequence", 111 | dest="max_sequence", 112 | default=230, 113 | type=int, 114 | help="maximun sequence when doing inference", 115 | ) 116 | parser.add_argument( 117 | "--batch_size", 118 | dest="batch_size", 119 | default=8, 120 | type=int, 121 | help="batch size when doing inference", 122 | ) 123 | 124 | eval_dir = os.environ.get('SM_CHANNEL_EVAL', '/opt/ml/input/data/') 125 | file_path = os.path.join(eval_dir, 'eval_dataset/input.txt') 126 | parser.add_argument( 127 | "--file_path", 128 | dest="file_path", 129 | default=file_path, 130 | type=str, 131 | help="file path when doing inference", 132 | ) 133 | 134 | output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', 'submit') 135 | parser.add_argument( 136 | "--output_dir", 137 | dest="output_dir", 138 | default=output_dir, 139 | type=str, 140 | help="output directory", 141 | ) 142 | 143 | parser = parser.parse_args() 144 | main(parser) 145 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import editdistance 2 | import numpy as np 3 | 4 | def word_error_rate(predicted_outputs, ground_truths): 5 | """ Estimate Word_error_rate. 6 | Args: 7 | predicted_outputs(list) : result of model prediction 8 | ground_truths(list) : ground truth 9 | 10 | Returns: 11 | World Error rate(float) : Word error rate Estimated by Edit distance. 12 | """ 13 | sum_wer=0.0 14 | for output,ground_truth in zip(predicted_outputs,ground_truths): 15 | output=output.split(" ") 16 | ground_truth=ground_truth.split(" ") 17 | distance = editdistance.eval(output, ground_truth) 18 | length = max(len(output),len(ground_truth)) 19 | sum_wer+=(distance/length) 20 | return sum_wer/len(predicted_outputs) 21 | 22 | 23 | def sentence_acc(predicted_outputs, ground_truths): 24 | """ Estimate sentence_acc. 
25 | Args: 26 | predicted_outputs(list) : result of model prediction 27 | ground_truths(list) : ground truth 28 | 29 | Returns: 30 | sentence_acc(float) : Accuracy between predicted_outputs and ground_truths 31 | """ 32 | correct_sentences=0 33 | for output,ground_truth in zip(predicted_outputs,ground_truths): 34 | if np.array_equal(output,ground_truth): 35 | correct_sentences+=1 36 | return correct_sentences/len(predicted_outputs) 37 | 38 | 39 | def get_worst_wer_img_path(img_path_list, predicted_outputs, ground_truths): 40 | """ Return information about the image with the maximum word error rate 41 | Args: 42 | img_path_list(list) : list of image paths 43 | predicted_outputs(list) : result of model prediction 44 | ground_truths(list) : ground truth 45 | 46 | Returns: 47 | image path(str) : Image path with the worst word error rate 48 | word error rate(float) : max word error rate 49 | ground truth(str) : Ground truth of the max word error rate image 50 | predicted_output(str) : Prediction of the model 51 | """ 52 | max_wer_ind = 0 53 | max_wer = 0 54 | 55 | i = 0 56 | for output, ground_truth in zip(predicted_outputs,ground_truths): 57 | output=output.split(" ") 58 | ground_truth=ground_truth.split(" ") 59 | 60 | distance = editdistance.eval(output, ground_truth) 61 | length = max(len(output), len(ground_truth)) 62 | cur_wer = (distance / length) 63 | if max_wer < cur_wer: 64 | max_wer = cur_wer 65 | max_wer_ind = i 66 | i+=1 67 | 68 | return img_path_list[max_wer_ind], max_wer, ground_truths[max_wer_ind], predicted_outputs[max_wer_ind] 69 | -------------------------------------------------------------------------------- /networks/README.md: -------------------------------------------------------------------------------- 1 | # Supported Models 2 | 3 | ## SATRN 4 | A model that adapts the [Transformer](https://arxiv.org/abs/1706.03762) encoder-decoder architecture to the scene text recognition (STR) task. It was proposed in [On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention](https://arxiv.org/abs/1910.04396), and its main features are as follows. 5 | 6 | 
7 | *(figure: SATRN model architecture, see assets/SATRN.png)* 
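
The models below are selected and configured entirely through the YAML files under `configs/`. As a quick orientation, the snippet below (not part of the repository; it only uses `Flags` from `flags.py`, and assumes it is run from the repo root with `configs/SATRN.yaml` present) loads a config and inspects the SATRN options referenced in the following sections:

```python
# Minimal sketch: load a training config and look at the SATRN encoder options
# discussed below. Only Flags from flags.py is repo code; everything else here
# is an illustrative usage example.
from flags import Flags

options = Flags("configs/SATRN.yaml").get()   # YAML -> nested namedtuple

print(options.network)                        # "SATRN" or "Attention"
print(options.input_size.height, options.input_size.width)
print(options.SATRN.encoder.shallower_cnn)    # shallow CNN stem on/off
print(options.SATRN.encoder.adaptive_gate)    # adaptive 2D positional encoding
print(options.SATRN.encoder.conv_ff,          # locality-aware feedforward
      options.SATRN.encoder.separable_ff)     # separable 3x3 variant
```

`get_network(options.network, options, model_checkpoint, device, dataset)` in `utils.py` then builds the corresponding model from this namedtuple, as `inference.py` does.
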
9 | 10 | ### Shallow CNN 11 | By greatly reducing the depth of the CNN that produces the 2D feature map, the self-attention encoder blocks can capture spatial dependencies more effectively. It is implemented as two stacked `ShallowConvLayer` blocks and can be enabled by setting `SATRN.encoder.shallower_cnn` to `True` in the config file (when `False`, `DeepCNN300` is used). 12 | 13 | ### Adaptive 2D positional encoding 14 | An extension of the Transformer positional encoding to 2D. As presented in the paper above, it is implemented as a weighted sum of sinusoidal positional encodings along the two axes. The 'weights' of this sum come from an adaptive gate that adjusts the importance of the two directions per image and has a `Linear-ReLU-Linear-sigmoid` structure. In the code, the adaptive gate is implemented as `AdaptiveGate` and the positional encoding as `AdaptivePositionalEncoding2D` (a simplified sketch is given further below). It can be enabled by setting `SATRN.encoder.adaptive_gate` to `True` in the config file (when `False`, `PositionalEncoding2D`, a non-adaptive concatenation, is used). 15 | 16 | ### Locality-aware feedforward layer 17 | The point-wise feedforward layer of the Transformer is replaced with a structure built on 3x3 convolutional layers so that short-term dependencies are captured more effectively. Setting `SATRN.encoder.conv_ff` to `True` in the config file switches the feedforward layer of the self-attention blocks to `LocalityAwareFeedforward` (when `False`, `Feedforward` is used). When `SATRN.encoder.separable_ff` is `True` a separable convolution is used, and when `False` a plain convolution (see the figure below). 18 |
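
Below is a compact, self-contained sketch of such a locality-aware feedforward block. The name `LocalityAwareFFSketch` and the exact layer ordering are illustrative assumptions, not the repo's `LocalityAwareFeedforward` in `networks/SATRN.py`: a 1x1 expansion, a 3x3 convolution (depthwise-separable when `separable_ff` is on), and a 1x1 projection back to the model dimension.

```python
# Sketch only, not the repo's LocalityAwareFeedforward. Shows how the point-wise
# FFN (Linear -> ReLU -> Linear) becomes convolutional so each position can also
# mix information from its 3x3 neighbourhood.
import torch
import torch.nn as nn


class LocalityAwareFFSketch(nn.Module):
    def __init__(self, hidden_dim: int, filter_dim: int, separable: bool = True):
        super().__init__()
        if separable:
            # depthwise 3x3 followed by pointwise 1x1 (separable convolution)
            conv3x3 = nn.Sequential(
                nn.Conv2d(filter_dim, filter_dim, 3, padding=1, groups=filter_dim),
                nn.Conv2d(filter_dim, filter_dim, 1),
            )
        else:
            conv3x3 = nn.Conv2d(filter_dim, filter_dim, 3, padding=1)
        self.ff = nn.Sequential(
            nn.Conv2d(hidden_dim, filter_dim, 1),   # expansion (first "Linear")
            nn.ReLU(inplace=True),
            conv3x3,                                # local 3x3 mixing
            nn.ReLU(inplace=True),
            nn.Conv2d(filter_dim, hidden_dim, 1),   # projection (second "Linear")
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff(x)                           # x: (B, C, H, W) feature map


x = torch.randn(2, 300, 8, 32)                      # hidden_dim=300, filter_dim=1200 as in SATRN.yaml
print(LocalityAwareFFSketch(300, 1200)(x).shape)    # torch.Size([2, 300, 8, 32])
```
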
19 | *(figure: point-wise vs. locality-aware feedforward variants, see assets/SATRN_feedforward.png)* 
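
For the adaptive 2D positional encoding described above, the following condensed sketch shows the idea. The names `sinusoid` and `AdaptivePE2DSketch` are illustrative and this is not the repo's `AdaptiveGate`/`AdaptivePositionalEncoding2D` code: per-image weights from a `Linear-ReLU-Linear-sigmoid` gate blend the height-wise and width-wise sinusoidal encodings before they are added to the feature map.

```python
# Illustrative sketch of adaptive 2D positional encoding (A2DPE), not repo code.
import torch
import torch.nn as nn


def sinusoid(length: int, dim: int) -> torch.Tensor:
    """Standard 1D sinusoidal positional encoding, shape (length, dim)."""
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    i = torch.arange(dim, dtype=torch.float32).unsqueeze(0)
    angle = pos / torch.pow(10000.0, 2 * (i // 2) / dim)
    enc = torch.zeros(length, dim)
    enc[:, 0::2] = torch.sin(angle[:, 0::2])
    enc[:, 1::2] = torch.cos(angle[:, 1::2])
    return enc


class AdaptivePE2DSketch(nn.Module):
    def __init__(self, dim: int, max_h: int = 64, max_w: int = 256):
        super().__init__()
        self.register_buffer("pe_h", sinusoid(max_h, dim))  # (max_h, dim)
        self.register_buffer("pe_w", sinusoid(max_w, dim))  # (max_w, dim)
        # adaptive gate: pooled feature -> two weights in (0, 1)
        self.gate = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(inplace=True),
            nn.Linear(dim, 2), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape                                  # x: CNN feature map
        alpha = self.gate(x.mean(dim=(2, 3)))                 # (b, 2) per-image weights
        pe_h = self.pe_h[:h].t().reshape(1, c, h, 1)          # height-wise encoding
        pe_w = self.pe_w[:w].t().reshape(1, c, 1, w)          # width-wise encoding
        pe = alpha[:, 0].view(b, 1, 1, 1) * pe_h + alpha[:, 1].view(b, 1, 1, 1) * pe_w
        return x + pe


feats = torch.randn(2, 300, 8, 32)                 # hidden_dim=300 as in configs/SATRN.yaml
print(AdaptivePE2DSketch(dim=300)(feats).shape)    # torch.Size([2, 300, 8, 32])
```
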
21 | 22 | ## Attention 23 | A structure obtained by removing the Bi-LSTM from ASTER, which was proposed in [ASTER: An Attentional Scene Text Recognizer with Flexible Rectification](https://ieeexplore.ieee.org/document/8395027). It consists of a CNN encoder and an RNN+attention decoder, and the decoder RNN can be set to either LSTM or GRU (`Attention.cell_type` in the config file). 24 | 25 | 
26 | *(figure: Attention (ASTER-style) model, see assets/Attention.png)* 
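
A condensed sketch of one decoding step of this kind of model is shown below. It is illustrative only (the repo's `networks/Attention.py` differs in details such as the attention scoring and state initialization); its purpose is to show where `cell_type` plugs in and how the context vector from the CNN features feeds the recurrent cell.

```python
# Sketch only, not networks/Attention.py: one step of an RNN + attention decoder
# over flattened CNN features, with the recurrent cell selected by cell_type.
import torch
import torch.nn as nn


class AttentionDecoderStepSketch(nn.Module):
    def __init__(self, src_dim: int, hidden_dim: int, embedding_dim: int,
                 num_classes: int, cell_type: str = "LSTM"):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embedding_dim)
        self.attn_score = nn.Linear(src_dim + hidden_dim, 1)          # additive-style score
        cell_cls = {"LSTM": nn.LSTMCell, "GRU": nn.GRUCell}[cell_type]
        self.cell = cell_cls(embedding_dim + src_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, prev_token, state, enc_feats):
        # prev_token: (B,) previous symbol ids; enc_feats: (B, T, src_dim)
        # state: (h, c) for LSTM, h for GRU
        h = state[0] if isinstance(state, tuple) else state
        h_exp = h.unsqueeze(1).expand(-1, enc_feats.size(1), -1)
        scores = self.attn_score(torch.cat([enc_feats, h_exp], dim=-1))   # (B, T, 1)
        context = (torch.softmax(scores, dim=1) * enc_feats).sum(dim=1)   # (B, src_dim)
        rnn_in = torch.cat([self.embedding(prev_token), context], dim=-1)
        state = self.cell(rnn_in, state)
        h = state[0] if isinstance(state, tuple) else state
        return self.classifier(h), state                                  # logits over tokens


dec = AttentionDecoderStepSketch(src_dim=512, hidden_dim=128, embedding_dim=128,
                                 num_classes=245,                  # vocab size illustrative
                                 cell_type="LSTM")                 # dims as in Attention.yaml
feats = torch.randn(2, 64, 512)                        # flattened CNN feature map
state = (torch.zeros(2, 128), torch.zeros(2, 128))     # (h, c) initial state
logits, state = dec(torch.zeros(2, dtype=torch.long), state, feats)
print(logits.shape)                                    # torch.Size([2, 245])
```
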
28 | -------------------------------------------------------------------------------- /networks/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class LabelSmoothingCrossEntropy(nn.Module): 6 | """ 7 | Label Smoothing Cross Entropy 8 | 9 | A Cross Entropy with Label Smooothing 10 | 11 | """ 12 | 13 | def __init__(self, eps: float = 0.1, reduction='mean', ignore_index=-100): 14 | """ 15 | Args: 16 | eps(float) :Rate of Label Smoothing 17 | reduction(str) : The way of reduction [mean, sum] 18 | ignore_index(int) : Index wants to ignore 19 | """ 20 | 21 | super(LabelSmoothingCrossEntropy, self).__init__() 22 | self.eps, self.reduction = eps, reduction 23 | self.ignore_index = ignore_index 24 | 25 | def forward(self, output, target, *args): 26 | 27 | output = output.transpose(1,2) 28 | 29 | pred = output.contiguous().view(-1, output.shape[-1]) 30 | target = target.to(pred.device).contiguous().view(-1) 31 | c = pred.size()[-1] 32 | 33 | log_preds = F.log_softmax(pred, dim=-1) 34 | ignore_target = target != self.ignore_index 35 | log_preds = log_preds * ignore_target[:, None] 36 | 37 | if self.reduction == 'sum': 38 | loss = -log_preds.sum() 39 | else: 40 | loss = -log_preds.sum(dim=-1) 41 | if self.reduction == 'mean': 42 | loss = loss.mean() 43 | 44 | return ( 45 | loss * self.eps / c + (1 - self.eps) * 46 | F.nll_loss( 47 | log_preds, 48 | target, 49 | reduction=self.reduction, 50 | ignore_index=self.ignore_index, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit_image==0.14.1 2 | opencv_python==3.4.4.19 3 | tqdm==4.28.1 4 | --find-links=https://download.pytorch.org/whl/torch_stable.html 5 | torch==1.7.1+cu101 6 | torchvision==0.8.2+cu101 7 | scipy==1.2.0 8 | numpy==1.15.4 9 | pillow==8.2.0 10 | tensorboardX==1.5 11 | editdistance==0.5.3 12 | python-dotenv==0.17.1 13 | wandb==0.10.30 14 | adamp==0.3.0 15 | python-dotenv==0.17.1 -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class RotateByDistribution(nn.Module): 7 | """ Image Rotation 8 | Rotate Image angle between (-34 ~ 34) or input distribution 9 | """ 10 | def __init__(self, distribution=None): 11 | super(RotateByDistribution, self).__init__() 12 | if distribution is None: 13 | self.distribution = torch.distributions.normal.Normal(0, 34) 14 | else: 15 | self.distribution = distribution 16 | 17 | def forward(self, img): 18 | degree = self.distribution.sample().item() 19 | return F.rotate(img, degree) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from copy import deepcopy 3 | import torch.optim as optim 4 | from adamp import AdamP 5 | 6 | from networks.Attention import Attention 7 | from networks.SATRN import SATRN 8 | 9 | 10 | def get_network( 11 | model_type, 12 | FLAGS, 13 | model_checkpoint, 14 | device, 15 | train_dataset, 16 | ): 17 | """Get network 18 | 19 | Args: 20 | model_type (str): Model name that wants to use. 21 | FLAGS (Flag): Configs of model. 
22 | model_checkpoint (dict): model checkpoint (state dict; empty dict for random init). 23 | device (torch.device): Device type to use. 24 | train_dataset (Dataset): training dataset passed to the model constructor. 25 | 26 | Returns: 27 | model : model 28 | """ 29 | model = None 30 | 31 | if model_type == "SATRN": 32 | model = SATRN(FLAGS, train_dataset, model_checkpoint).to(device) 33 | elif model_type == "CRNN": 34 | model = CRNN() # NOTE: CRNN is not implemented/imported in this repo 35 | elif model_type == "Attention": 36 | model = Attention(FLAGS, train_dataset, model_checkpoint).to(device) 37 | else: 38 | raise NotImplementedError 39 | 40 | return model 41 | 42 | 43 | def get_optimizer(optimizer, params, lr, weight_decay=None): 44 | """Get Optimizer 45 | 46 | Args: 47 | optimizer (str): optimizer name ("AdamP", "Adam" or "Adadelta") 48 | params (iterable): model parameters to optimize 49 | lr (float): learning rate 50 | weight_decay (float, optional): weight decay (L2 penalty). Defaults to None. 51 | 52 | Returns: 53 | optimizer: optimizer 54 | """ 55 | if optimizer == "AdamP": 56 | optimizer = AdamP(params, lr=lr) 57 | elif optimizer == "Adam": 58 | optimizer = optim.Adam(params, lr=lr) 59 | elif optimizer == "Adadelta": 60 | optimizer = optim.Adadelta(params, lr=lr, weight_decay=weight_decay) 61 | else: 62 | raise NotImplementedError 63 | return optimizer 64 | 65 | 66 | def get_wandb_config(config_file): 67 | """Get Wandb config from config_file 68 | 69 | Args: 70 | config_file (str): config_file path 71 | Returns: 72 | config (dict): original config 73 | """ 74 | # load config file 75 | with open(config_file, 'r') as f: 76 | option = yaml.safe_load(f) 77 | config = deepcopy(option) 78 | 79 | # remove all except network 80 | keys = ["checkpoint", "input_size", "data", "optimizer", "wandb", "prefix"] 81 | for key in keys: 82 | del config[key] 83 | 84 | # modify some config key-value 85 | new_config = { 86 | "log_path": option['prefix'], 87 | "dataset_proportions": option['data']['dataset_proportions'], 88 | "test_proportions": option['data']['test_proportions'], 89 | "crop": option['data']['crop'], 90 | "rgb": "grayscale" if option['data']['rgb']==1 else "color", 91 | "input_size": (option['input_size']['height'], option['input_size']['width']), 92 | "optimizer": option['optimizer']['optimizer'], 93 | "learning_rate": option['optimizer']['lr'], 94 | "weight_decay": option['optimizer']['weight_decay'], 95 | "is_cycle": option['optimizer']['is_cycle'], 96 | } 97 | 98 | # merge 99 | config.update(new_config) 100 | 101 | # print log 102 | print("wandb save configs below:\n", list(config.keys())) 103 | 104 | return config -------------------------------------------------------------------------------- /vedastr_cstr/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | addict 4 | six 5 | torch>=1.6.* 6 | torchvision>=0.7.* 7 | Pillow>=8.2.0 8 | lmdb 9 | nltk 10 | terminaltables 11 | albumentations 12 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/deploy/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) 6 | 7 | import cv2 # noqa 402 8 | from volksdep.benchmark import benchmark # noqa 402 9 | 10 | from tools.deploy.utils import CALIBRATORS, CalibDataset, Metric, MetricDataset # noqa 402 11 | from vedastr.runners import TestRunner # noqa 402 12 | from vedastr.utils import Config # noqa 402 13 | 14 | 15 | def parse_args(): 16 | parser = 
argparse.ArgumentParser(description='Inference') 17 | parser.add_argument('config', type=str, help='config file path') 18 | parser.add_argument('checkpoint', type=str, help='checkpoint file path') 19 | parser.add_argument('image', type=str, help='sample image path') 20 | parser.add_argument( 21 | '--dtypes', 22 | default=('fp32', 'fp16', 'int8'), 23 | nargs='+', 24 | type=str, 25 | choices=['fp32', 'fp16', 'int8'], 26 | help='dtypes for benchmark') 27 | parser.add_argument( 28 | '--iters', default=100, type=int, help='iters for benchmark') 29 | parser.add_argument( 30 | '--calibration_images', 31 | default=None, 32 | type=str, 33 | help='images dir used when int8 in dtypes') 34 | parser.add_argument( 35 | '--calibration_modes', 36 | nargs='+', 37 | default=['entropy', 'entropy_2', 'minmax'], 38 | type=str, 39 | choices=['entropy_2', 'entropy', 'minmax'], 40 | help='calibration modes for benchmark') 41 | args = parser.parse_args() 42 | 43 | return args 44 | 45 | 46 | def main(): 47 | args = parse_args() 48 | 49 | cfg_path = args.config 50 | cfg = Config.fromfile(cfg_path) 51 | 52 | test_cfg = cfg['test'] 53 | infer_cfg = cfg['inference'] 54 | common_cfg = cfg['common'] 55 | 56 | runner = TestRunner(test_cfg, infer_cfg, common_cfg) 57 | assert runner.use_gpu, 'Please use gpu for benchmark.' 58 | runner.load_checkpoint(args.checkpoint) 59 | 60 | image = cv2.imread(args.image) 61 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 62 | aug = runner.transform(image=image, label='') 63 | image, dummy_label = aug['image'], aug['label'] # noqa 841 64 | image = image.unsqueeze(0) 65 | input_len = runner.converter.test_encode(1)[0] 66 | model = runner.model 67 | need_text = runner.need_text 68 | if need_text: 69 | shape = tuple(image.shape), tuple(input_len.shape) 70 | else: 71 | shape = tuple(image.shape) 72 | 73 | dtypes = args.dtypes 74 | iters = args.iters 75 | int8_calibrator = None 76 | if args.calibration_images: 77 | calib_dataset = CalibDataset(args.calibration_images, runner.converter, 78 | runner.transform, need_text) 79 | int8_calibrator = [ 80 | CALIBRATORS[mode](dataset=calib_dataset) 81 | for mode in args.calibration_modes 82 | ] 83 | dataset = runner.test_dataloader.dataset 84 | dataset = MetricDataset(dataset, runner.converter, need_text) 85 | metric = Metric(runner.metric, runner.converter) 86 | benchmark( 87 | model, 88 | shape, 89 | dtypes=dtypes, 90 | iters=iters, 91 | int8_calibrator=int8_calibrator, 92 | dataset=dataset, 93 | metric=metric) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/deploy/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) 6 | 7 | import cv2 # noqa 402 8 | from volksdep.converters import save, torch2onnx, torch2trt # noqa 402 9 | 10 | from tools.deploy.utils import CALIBRATORS, CalibDataset # noqa 402 11 | from vedastr.runners import InferenceRunner # noqa 402 12 | from vedastr.utils import Config # noqa 402 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Inference') 17 | parser.add_argument('config', type=str, help='config file path') 18 | parser.add_argument('checkpoint', type=str, help='checkpoint file path') 19 | parser.add_argument('image', type=str, help='sample image path') 20 | parser.add_argument('out', type=str, help='output model file 
name') 21 | parser.add_argument( 22 | '--onnx', default=False, action='store_true', help='convert to onnx') 23 | parser.add_argument( 24 | '--max_batch_size', 25 | default=1, 26 | type=int, 27 | help='max batch size for trt engine execution') 28 | parser.add_argument( 29 | '--max_workspace_size', 30 | default=1, 31 | type=int, 32 | help='max workspace size for building trt engine') 33 | parser.add_argument( 34 | '--fp16', 35 | default=False, 36 | action='store_true', 37 | help='convert to trt engine with fp16 mode') 38 | parser.add_argument( 39 | '--int8', 40 | default=False, 41 | action='store_true', 42 | help='convert to trt engine with int8 mode') 43 | parser.add_argument( 44 | '--calibration_mode', 45 | default='entropy_2', 46 | type=str, 47 | choices=['entropy_2', 'entropy', 'minmax']) 48 | parser.add_argument( 49 | '--calibration_images', 50 | default=None, 51 | type=str, 52 | help='images dir used when int8 mode is True') 53 | args = parser.parse_args() 54 | 55 | return args 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | out_name = args.out 61 | 62 | cfg_path = args.config 63 | cfg = Config.fromfile(cfg_path) 64 | 65 | infer_cfg = cfg['inference'] 66 | common_cfg = cfg.get('common') 67 | 68 | runner = InferenceRunner(infer_cfg, common_cfg) 69 | assert runner.use_gpu, 'Please use valid gpu to export model.' 70 | runner.load_checkpoint(args.checkpoint) 71 | 72 | image = cv2.imread(args.image) 73 | 74 | aug = runner.transform(image=image, label='') 75 | image, label = aug['image'], aug['label'] # noqa 841 76 | image = image.unsqueeze(0).cuda() 77 | dummy_input = (image, runner.converter.test_encode([''])[0]) 78 | model = runner.model.cuda().eval() 79 | need_text = runner.need_text 80 | if not need_text: 81 | dummy_input = dummy_input[0] 82 | 83 | if args.onnx: 84 | runner.logger.info('Convert to onnx model') 85 | torch2onnx(model, dummy_input, out_name) 86 | else: 87 | max_batch_size = args.max_batch_size 88 | max_workspace_size = args.max_workspace_size 89 | fp16_mode = args.fp16 90 | int8_mode = args.int8 91 | int8_calibrator = None 92 | if int8_mode: 93 | runner.logger.info('Convert to trt engine with int8') 94 | if args.calibration_images: 95 | runner.logger.info( 96 | 'Use calibration with mode {} and data {}'.format( 97 | args.calibration_mode, args.calibration_images)) 98 | dataset = CalibDataset(args.calibration_images, 99 | runner.converter, runner.transform, 100 | need_text) 101 | int8_calibrator = CALIBRATORS[args.calibration_mode]( 102 | dataset=dataset) 103 | else: 104 | runner.logger.info('Use default calibration mode and data') 105 | elif fp16_mode: 106 | runner.logger.info('Convert to trt engine with fp16') 107 | else: 108 | runner.logger.info('Convert to trt engine with fp32') 109 | trt_model = torch2trt( 110 | model, 111 | dummy_input, 112 | max_batch_size=max_batch_size, 113 | max_workspace_size=max_workspace_size, 114 | fp16_mode=fp16_mode, 115 | int8_mode=int8_mode, 116 | int8_calibrator=int8_calibrator) 117 | save(trt_model, out_name) 118 | runner.logger.info( 119 | 'Convert successfully, save model to {}'.format(out_name)) 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/deploy/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import CALIBRATORS, CalibDataset, Metric, MetricDataset # noqa 401 2 | 
-------------------------------------------------------------------------------- /vedastr_cstr/tools/deploy/utils/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | from volksdep.calibrators import (EntropyCalibrator, EntropyCalibrator2, 5 | MinMaxCalibrator) 6 | from volksdep.datasets import Dataset 7 | from volksdep.metrics import Metric as BaseMetric 8 | 9 | CALIBRATORS = { 10 | 'entropy': EntropyCalibrator, 11 | 'entropy_2': EntropyCalibrator2, 12 | 'minmax': MinMaxCalibrator, 13 | } 14 | 15 | 16 | class CalibDataset(Dataset): 17 | 18 | def __init__(self, images_dir, converter, transform=None, need_text=False): 19 | super(CalibDataset, self).__init__() 20 | 21 | self.root = images_dir 22 | self.samples = os.listdir(images_dir) 23 | self.converter = converter 24 | self.transform = transform 25 | self.need_text = need_text 26 | 27 | def __getitem__(self, idx): 28 | image_file = os.path.join(self.root, self.samples[idx]) 29 | image = Image.open(image_file) 30 | if self.transform: 31 | image, _ = self.transform(image=image, label='') 32 | label = self.converter.test_encode(1) 33 | if self.need_text: 34 | return image, label 35 | else: 36 | return image 37 | 38 | def __len__(self): 39 | return len(self.samples) 40 | 41 | 42 | class MetricDataset(Dataset): 43 | 44 | def __init__(self, dataset, converter, need_text): 45 | super(MetricDataset, self).__init__() 46 | self.dataset = dataset 47 | self.converter = converter 48 | self.need_text = need_text 49 | 50 | def __getitem__(self, idx): 51 | image, label = self.dataset[idx] 52 | label_input, _, _ = self.converter.test_encode(1) 53 | _, _, label_target = self.converter.train_encode([label]) 54 | 55 | if self.need_text: 56 | return (image, label_input[0]), label 57 | else: 58 | return image, label 59 | 60 | def __len__(self): 61 | return len(self.dataset) 62 | 63 | 64 | class Metric(BaseMetric): 65 | 66 | def __init__(self, metric, converter): 67 | self.metric = metric 68 | self.converter = converter 69 | 70 | def decode(self, preds): 71 | indexes = np.argmax(preds, 2) 72 | pred_str = self.converter.decode(indexes) 73 | 74 | return pred_str 75 | 76 | def __call__(self, preds, targets): 77 | self.metric.reset() 78 | pred_str = self.decode(preds) 79 | 80 | self.metric.measure(pred_str, None, targets) 81 | res = self.metric.result 82 | 83 | return ', '.join(['{}: {:.4f}'.format(k, v) for k, v in res.items()]) 84 | 85 | def __str__(self): 86 | return self.metric.__class__.__name__.lower() 87 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if (($# < 3)); then 4 | echo "Uasage: bash tools/dist_test.sh config_file checkpoint gpus_to_use" 5 | exit 1 6 | fi 7 | 8 | CONFIG="$1" 9 | CHECKPOINT="$2" 10 | GPUS="$3" 11 | 12 | IFS=', ' read -r -a gpus <<<"${GPUS}" 13 | NGPUS="${#gpus[@]}" 14 | PORT="$((29400 + RANDOM % 100))" 15 | 16 | export CUDA_VISIBLE_DEVICES=${GPUS} 17 | 18 | PYTHONPATH="$(dirname "$0")/..":${PYTHONPATH} \ 19 | python -m torch.distributed.launch \ 20 | --nproc_per_node="${NGPUS}" \ 21 | --master_port=${PORT} \ 22 | "$(dirname "$0")"/test.py \ 23 | "$CONFIG" \ 24 | "$CHECKPOINT" \ 25 | --distribute \ 26 | "${@:4}" 27 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/dist_train.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if (($# < 2)); then 4 | echo "Uasage: bash tools/dist_train.sh config_file gpus_to_use" 5 | exit 1 6 | fi 7 | CONFIG="$1" 8 | GPUS="$2" 9 | 10 | IFS=', ' read -r -a gpus <<<"${GPUS}" 11 | NGPUS="${#gpus[@]}" 12 | PORT="$((29400 + RANDOM % 100))" 13 | 14 | export CUDA_VISIBLE_DEVICES=${GPUS} 15 | 16 | PYTHONPATH="$(dirname "$0")/..":${PYTHONPATH} \ 17 | python -m torch.distributed.launch \ 18 | --nproc_per_node="${NGPUS}" \ 19 | --master_port=${PORT} \ 20 | "$(dirname "$0")"/train.py \ 21 | "$CONFIG" \ 22 | --distribute \ 23 | "${@:3}" 24 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) 6 | 7 | import cv2 # noqa 402 8 | 9 | from vedastr.runners import InferenceRunner # noqa 402 10 | from vedastr.utils import Config # noqa 402 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='Inference') 15 | parser.add_argument('config', type=str, help='Config file path') 16 | parser.add_argument('checkpoint', type=str, help='Checkpoint file path') 17 | parser.add_argument('image', type=str, help='input image path') 18 | args = parser.parse_args() 19 | 20 | return args 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | 26 | cfg_path = args.config 27 | cfg = Config.fromfile(cfg_path) 28 | 29 | inference_cfg = cfg['inference'] 30 | common_cfg = cfg.get('common') 31 | 32 | runner = InferenceRunner(inference_cfg, common_cfg) 33 | runner.load_checkpoint(args.checkpoint) 34 | if os.path.isfile(args.image): 35 | images = [args.image] 36 | else: 37 | images = [ 38 | os.path.join(args.image, name) for name in os.listdir(args.image) 39 | ] 40 | for img in images: 41 | image = cv2.imread(img) 42 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 43 | pred_str, probs = runner(image) 44 | runner.logger.info('Text in {} is:\t {} '.format(pred_str, img)) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) 6 | 7 | from vedastr.runners import TestRunner # noqa 402 8 | from vedastr.utils import Config # noqa 402 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description='Test.') 13 | parser.add_argument('config', type=str, help='Config file path') 14 | parser.add_argument('checkpoint', type=str, help='Checkpoint file path') 15 | parser.add_argument('--distribute', default=False, action='store_true') 16 | parser.add_argument('--local_rank', type=int, default=0) 17 | args = parser.parse_args() 18 | if 'LOCAL_RANK' not in os.environ: 19 | os.environ['LOCAL_RANK'] = str(args.local_rank) 20 | return args 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | 26 | cfg_path = args.config 27 | cfg = Config.fromfile(cfg_path) 28 | 29 | _, fullname = os.path.split(cfg_path) 30 | fname, ext = os.path.splitext(fullname) 31 | 32 | root_workdir = cfg.pop('root_workdir') 33 | workdir = os.path.join(root_workdir, fname) 34 | os.makedirs(workdir, exist_ok=True) 35 | 36 | test_cfg = 
cfg['test'] 37 | inference_cfg = cfg['inference'] 38 | common_cfg = cfg['common'] 39 | common_cfg['workdir'] = workdir 40 | common_cfg['distribute'] = args.distribute 41 | 42 | runner = TestRunner(test_cfg, inference_cfg, common_cfg) 43 | runner.load_checkpoint(args.checkpoint) 44 | runner() 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /vedastr_cstr/tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import sys 5 | 6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../")) 7 | 8 | from vedastr.runners import TrainRunner # noqa 402 9 | from vedastr.utils import Config # noqa 402 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description="Train.") 14 | parser.add_argument("config", type=str, help="config file path") 15 | parser.add_argument("--distribute", default=False, action="store_true") 16 | parser.add_argument("--local_rank", type=int, default=0) 17 | args = parser.parse_args() 18 | if "LOCAL_RANK" not in os.environ: 19 | os.environ["LOCAL_RANK"] = str(args.local_rank) 20 | return args 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | 26 | cfg_path = args.config 27 | cfg = Config.fromfile(cfg_path) 28 | 29 | _, fullname = os.path.split(cfg_path) 30 | fname, ext = os.path.splitext(fullname) 31 | 32 | root_workdir = cfg.pop("root_workdir") 33 | workdir = os.path.join(root_workdir, fname) 34 | os.makedirs(workdir, exist_ok=True) 35 | # copy corresponding cfg to workdir 36 | shutil.copy(cfg_path, os.path.join(workdir, os.path.basename(cfg_path))) 37 | 38 | train_cfg = cfg["train"] 39 | inference_cfg = cfg["inference"] 40 | common_cfg = cfg["common"] 41 | common_cfg["workdir"] = workdir 42 | common_cfg["distribute"] = args.distribute 43 | 44 | runner = TrainRunner(train_cfg, inference_cfg, common_cfg) 45 | runner() 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pstage-ocr-team6/ocr-teamcode/86d5070e8f907571a47967d64facaee246d92a35/vedastr_cstr/vedastr/__init__.py -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/__init__.py: -------------------------------------------------------------------------------- 1 | from .attn_converter import AttnConverter # noqa 401 2 | from .builder import build_converter # noqa 401 3 | from .ctc_converter import CTCConverter # noqa 401 4 | from .fc_converter import FCConverter # noqa 401 5 | from .custom_converter import CustomConverter -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/attn_converter.py: -------------------------------------------------------------------------------- 1 | # modify from clovaai 2 | 3 | import torch 4 | 5 | from .base_convert import BaseConverter 6 | from .registry import CONVERTERS 7 | 8 | 9 | @CONVERTERS.register_module 10 | class AttnConverter(BaseConverter): 11 | 12 | def __init__(self, character, batch_max_length, go_last=False): 13 | list_character = list(character) 14 | self.batch_max_length = batch_max_length + 1 15 | if go_last: 16 | list_token = ['[s]', '[GO]'] 17 | character = list_character + list_token 18 | else: 19 | 
list_token = ['[GO]', '[s]'] 20 | character = list_token + list_character 21 | super(AttnConverter, self).__init__(character=character) 22 | self.ignore_index = self.dict['[GO]'] 23 | 24 | def train_encode(self, text): 25 | length = [len(s) + 1 for s in text] 26 | batch_text = torch.LongTensor(len(text), self.batch_max_length + 1).fill_(self.ignore_index) # noqa 501 27 | for idx, t in enumerate(text): 28 | text = list(t) 29 | text.append('[s]') 30 | text = [self.dict[char] for char in text] 31 | batch_text[idx][1:1 + len(text)] = torch.LongTensor(text) 32 | batch_text_input = batch_text[:, :-1] 33 | batch_text_target = batch_text[:, 1:] 34 | 35 | return batch_text_input, torch.IntTensor(length), batch_text_target 36 | 37 | def decode(self, text_index): 38 | texts = [] 39 | batch_size = text_index.shape[0] 40 | for index in range(batch_size): 41 | text = ''.join([self.character[i] for i in text_index[index, :]]) 42 | text = text[:text.find('[s]')] 43 | texts.append(text) 44 | 45 | return texts 46 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/base_convert.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | 5 | from .registry import CONVERTERS 6 | 7 | 8 | @CONVERTERS.register_module 9 | class BaseConverter(object): 10 | def __init__(self, character): 11 | self.character = list(character) 12 | self.dict = {} 13 | for i, char in enumerate(self.character): 14 | self.dict[char] = i 15 | self.ignore_index = None 16 | 17 | @abc.abstractmethod 18 | def train_encode(self, *args, **kwargs): 19 | '''encode text in train phase''' 20 | 21 | def test_encode(self, text): 22 | if isinstance(text, (list, tuple)): 23 | num = len(text) 24 | elif isinstance(text, int): 25 | num = text 26 | else: 27 | raise TypeError( 28 | f'Type of text should in (list, tuple, int) ' 29 | f'but got {type(text)}' 30 | ) 31 | ignore_index = self.ignore_index 32 | if ignore_index is None: 33 | ignore_index = 0 34 | batch_text = torch.LongTensor(num, 1).fill_(ignore_index) 35 | length = [1 for i in range(num)] 36 | 37 | return batch_text, torch.IntTensor(length), batch_text 38 | 39 | @abc.abstractmethod 40 | def decode(self, *args, **kwargs): 41 | '''decode label to text in train and test phase''' 42 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import CONVERTERS 3 | 4 | 5 | def build_converter(cfg, default_args=None): 6 | converter = build_from_cfg(cfg, CONVERTERS, default_args=default_args) 7 | 8 | return converter 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/ctc_converter.py: -------------------------------------------------------------------------------- 1 | # modify from clovaai 2 | 3 | import torch 4 | 5 | from .base_convert import BaseConverter 6 | from .registry import CONVERTERS 7 | 8 | 9 | @CONVERTERS.register_module 10 | class CTCConverter(BaseConverter): 11 | def __init__(self, character: str, batch_max_length: int): 12 | list_token = ['[blank]'] 13 | list_character = list(character) 14 | self.batch_max_length = batch_max_length 15 | super(CTCConverter, 16 | self).__init__(character=list_token + list_character) 17 | 18 | def train_encode(self, text: list): 19 | length = [len(s) for s in 
text] 20 | batch_text = torch.LongTensor(len(text), 21 | self.batch_max_length).fill_(0) 22 | for i, t in enumerate(text): 23 | text = list(t) 24 | text = [self.dict[char] for char in text] 25 | batch_text[i][:len(text)] = torch.LongTensor(text) 26 | 27 | return batch_text, torch.IntTensor(length), batch_text 28 | 29 | def decode(self, text_index): 30 | texts = [] 31 | batch_size = text_index.shape[0] 32 | length = text_index.shape[1] 33 | for i in range(batch_size): 34 | t = text_index[i] 35 | char_list = [] 36 | for idx in range(length): 37 | if t[idx] != 0 and (not (idx > 0 and t[idx - 1] == t[idx])): 38 | char_list.append(self.character[t[idx]]) 39 | text = ''.join(char_list) 40 | texts.append(text) 41 | return texts 42 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/custom_converter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .registry import CONVERTERS 4 | from .base_convert import BaseConverter 5 | 6 | # special token definitions 7 | START = "<SOS>"  # start of a sequence 8 | END = "<EOS>"  # end of a sequence 9 | PAD = "<PAD>"  # padding 10 | SPECIAL_TOKENS = [START, END, PAD] 11 | 12 | def load_vocab(tokens_paths): 13 | ''' 14 | Load the token txt files and build the token-id dictionaries. 15 | - `tokens_paths`: paths of the token txt files 16 | ''' 17 | tokens = []  # token storage 18 | tokens.extend(SPECIAL_TOKENS)  # add the special tokens first 19 | 20 | # read each file and collect its tokens 21 | for tokens_file in tokens_paths: 22 | with open(tokens_file, "r") as fd: 23 | reader = fd.read() 24 | for token in reader.split("\n"): 25 | if token not in tokens: 26 | tokens.append(token) 27 | tokens.remove('')  # drop the empty string 28 | token_to_id = {tok: i for i, tok in enumerate(tokens)}  # token to id 29 | id_to_token = {i: tok for i, tok in enumerate(tokens)}  # id to token 30 | return tokens, token_to_id, id_to_token 31 | 32 | @CONVERTERS.register_module 33 | class CustomConverter(BaseConverter): 34 | 35 | def __init__(self, character, batch_max_length=25): 36 | token_path = ["/opt/ml/input/data/train_dataset/tokens.txt"]  # token file path 37 | self.character, self.dict, self.id_to_token = load_vocab(token_path)  # parse the tokens 38 | self.batch_max_length = batch_max_length  # max sequence length 39 | self.ignore_index = self.dict[PAD]  # token id to ignore 40 | 41 | 42 | def train_encode(self, text_list): 43 | ''' 44 | Encode a batch of sequences. 45 | - `text_list`: list of sequences in the batch 46 | ''' 47 | length = [len(s) + 1 for s in text_list]  # length of each ground truth in the batch (+1 for EOS) 48 | batch_text = torch.LongTensor(len(text_list), self.batch_max_length+1).fill_(self.ignore_index)  # encoding matrix of shape (b, max_len+1) 49 | for i, t in enumerate(text_list): 50 | text = t.split(' ') 51 | text.append(END)  # append EOS 52 | text = [self.dict[char] for char in text]  # convert tokens to ids 53 | batch_text[i][:len(text)] = torch.LongTensor(text)  # store in the matrix 54 | batch_text_input = batch_text 55 | batch_text_target = batch_text 56 | 57 | return batch_text_input, torch.IntTensor(length), batch_text_target 58 | 59 | 60 | def decode(self, text_index): 61 | ''' 62 | Decode a batch of id sequences back into token sequences. 63 | - `text_index`: batch of id sequences 64 | ''' 65 | texts = [] 66 | batch_size = text_index.shape[0] 67 | for index in range(batch_size): 68 | text = ' '.join([self.id_to_token[int(i)] for i in text_index[index, :]]) 69 | text = text[:text.find(END)]  # keep only the part before EOS 70 | texts.append(text) 71 | 72 | return texts 73 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/fc_converter.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_convert import BaseConverter 4 | from .registry import CONVERTERS 5 | 6 | 7 | @CONVERTERS.register_module 8 | class FCConverter(BaseConverter): 9 | 10 | def __init__(self, character, batch_max_length=25): 11 | 12 | list_token = ['[s]'] 13 | ignore_token = ['[ignore]'] 14 | list_character = list(character) 15 | self.batch_max_length = batch_max_length + 1 16 | super(FCConverter, self).__init__(character=list_token + list_character + ignore_token) # noqa 501 17 | self.ignore_index = self.dict[ignore_token[0]] 18 | 19 | def train_encode(self, text): 20 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 21 | batch_text = torch.LongTensor(len(text), self.batch_max_length).fill_(self.ignore_index) # noqa 501 22 | for i, t in enumerate(text): 23 | text = list(t) 24 | text.append('[s]') 25 | text = [self.dict[char] for char in text] 26 | batch_text[i][:len(text)] = torch.LongTensor(text) 27 | batch_text_input = batch_text 28 | batch_text_target = batch_text 29 | 30 | return batch_text_input, torch.IntTensor(length), batch_text_target 31 | 32 | def decode(self, text_index): 33 | texts = [] 34 | batch_size = text_index.shape[0] 35 | for index in range(batch_size): 36 | text = ''.join([self.character[i] for i in text_index[index, :]]) 37 | text = text[:text.find('[s]')] 38 | texts.append(text) 39 | 40 | return texts 41 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/converter/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | CONVERTERS = Registry('convert') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/criteria/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_criterion # noqa 401 2 | from .cross_entropy_loss import CrossEntropyLoss # noqa 401 3 | from .ctc_loss import CTCLoss # noqa 401 4 | from .label_smooth_cross_entropy_loss import LabelSmoothingCrossEntropy # noqa 401 5 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/criteria/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import CRITERIA 3 | 4 | 5 | def build_criterion(cfg): 6 | criterion = build_from_cfg(cfg, CRITERIA, src='registry') 7 | 8 | return criterion 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/criteria/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .registry import CRITERIA 4 | 5 | 6 | @CRITERIA.register_module 7 | class CrossEntropyLoss(nn.Module): 8 | 9 | def __init__(self, 10 | weight=None, 11 | size_average=None, 12 | ignore_index=-100, 13 | reduce=None, 14 | reduction='mean'): 15 | super(CrossEntropyLoss, self).__init__() 16 | self.criteron = nn.CrossEntropyLoss( 17 | weight=weight, 18 | size_average=size_average, 19 | ignore_index=ignore_index, 20 | reduce=reduce, 21 | reduction=reduction) 22 | 23 | def forward(self, pred, target, *args): 24 | return self.criteron(pred.contiguous().view(-1, pred.shape[-1]), 25 | target.to(pred.device).contiguous().view(-1)) 26 | 
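The criteria package above follows the registry pattern used throughout vedastr: build_criterion looks up the 'type' key of a cfg dict in the CRITERIA registry and instantiates the registered class with the remaining keys. Below is a minimal sketch of how a loss would be built and called; the cfg values and tensor shapes are illustrative only (not taken from the repo's configs), and it assumes the vedastr_cstr package root is on sys.path, as the tools scripts arrange.

    import torch
    from vedastr.criteria import build_criterion

    # illustrative cfg: any class registered in CRITERIA can be named in 'type'
    criterion_cfg = dict(type='LabelSmoothingCrossEntropy', eps=0.1, ignore_index=-100)
    criterion = build_criterion(criterion_cfg)

    pred = torch.randn(2, 26, 40)            # (batch, steps, num_classes) logits
    target = torch.randint(0, 40, (2, 26))   # (batch, steps) label ids
    loss = criterion(pred, target)           # flattened internally to (batch*steps, num_classes)
    print(loss.item())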
-------------------------------------------------------------------------------- /vedastr_cstr/vedastr/criteria/ctc_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .registry import CRITERIA 5 | 6 | 7 | @CRITERIA.register_module 8 | class CTCLoss(nn.Module): 9 | 10 | def __init__(self, zero_infinity=True, blank=0, reduction='mean'): 11 | super(CTCLoss, self).__init__() 12 | self.criterion = nn.CTCLoss( 13 | zero_infinity=zero_infinity, blank=blank, reduction=reduction) 14 | 15 | def forward(self, pred, target, target_length, batch_size): 16 | pred = pred.log_softmax(2) 17 | input_lengths = torch.full( 18 | size=(batch_size, ), fill_value=pred.size(1), dtype=torch.long) 19 | pred_ = pred.permute(1, 0, 2) 20 | cost = self.criterion( 21 | log_probs=pred_, 22 | targets=target.to(pred.device), 23 | input_lengths=input_lengths.to(pred.device), 24 | target_lengths=target_length.to(pred.device)) 25 | return cost 26 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/criteria/label_smooth_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | # modify from fast.ai 2 | # https://github.com/fastai/fastai/blob/8013797e05f0ae0d771d60ecf7cf524da591503c/fastai/layers.py#L300 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import Counter 6 | from .registry import CRITERIA 7 | 8 | 9 | @CRITERIA.register_module 10 | class LabelSmoothingCrossEntropy(nn.Module): 11 | 12 | def __init__(self, eps: float = 0.1, reduction='mean', ignore_index=-100): 13 | super(LabelSmoothingCrossEntropy, self).__init__() 14 | self.eps, self.reduction = eps, reduction 15 | self.ignore_index = ignore_index 16 | 17 | def forward(self, output, target, *args): 18 | pred = output.contiguous().view(-1, output.shape[-1]) 19 | target = target.to(pred.device).contiguous().view(-1) 20 | 21 | c = pred.size()[-1] 22 | log_preds = F.log_softmax(pred, dim=-1) 23 | 24 | # ignore index for smooth label 25 | ignore_target = target != self.ignore_index 26 | log_preds = log_preds * ignore_target[:, None] 27 | 28 | if self.reduction == 'sum': 29 | loss = -log_preds.sum() 30 | else: 31 | loss = -log_preds.sum(dim=-1) 32 | if self.reduction == 'mean': 33 | loss = loss.mean() 34 | return loss * self.eps / c + (1 - self.eps) * \ 35 | F.nll_loss(log_preds, target, reduction=self.reduction, 36 | ignore_index=self.ignore_index) 37 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/criteria/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | CRITERIA = Registry('criterion') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_dataloader # noqa 401 2 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/builder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as tud 2 | 3 | from vedastr.utils import WorkerInit, build_from_cfg, get_dist_info 4 | from .registry import DATALOADERS 5 | 6 | 7 | def build_dataloader(distributed, 8 | num_gpus, 9 | cfg, 10 | default_args: dict = 
None, 11 | seed=None): 12 | cfg_ = cfg.copy() 13 | 14 | samples_per_gpu = cfg_.pop('samples_per_gpu') 15 | workers_per_gpu = cfg_.pop('workers_per_gpu') 16 | 17 | if distributed: 18 | batch_size = samples_per_gpu 19 | num_workers = workers_per_gpu 20 | else: 21 | batch_size = num_gpus * samples_per_gpu 22 | num_workers = num_gpus * workers_per_gpu 23 | 24 | cfg_.update({'batch_size': batch_size, 'num_workers': num_workers}) 25 | 26 | dataloaders = {} 27 | 28 | # TODO, other implementations 29 | if DATALOADERS.get(cfg['type']): 30 | packages = DATALOADERS 31 | src = 'registry' 32 | else: 33 | packages = tud 34 | src = 'module' 35 | 36 | # build different dataloaders for different datasets 37 | if isinstance(default_args.get('dataset'), list): 38 | for idx, ds in enumerate(default_args['dataset']): 39 | assert isinstance(ds, tud.Dataset) 40 | if default_args.get('sampler'): 41 | sp = default_args['sampler'][idx] 42 | else: 43 | sp = None 44 | dataloader = build_from_cfg( 45 | cfg_, packages, dict(dataset=ds, sampler=sp), src=src) 46 | if hasattr(ds, 'root'): 47 | name = getattr(ds, 'root') 48 | else: 49 | name = str(idx) 50 | dataloaders[name] = dataloader 51 | else: 52 | rank, _ = get_dist_info() 53 | worker_init_fn = WorkerInit( 54 | num_workers=num_workers, rank=rank, seed=seed, epoch=0) 55 | default_args['worker_init_fn'] = worker_init_fn 56 | dataloaders = build_from_cfg( 57 | cfg_, packages, default_args, src=src) # build a single dataloader 58 | 59 | return dataloaders 60 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | DATALOADERS = Registry('dataloader') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .balance_sampler import BalanceSampler # noqa 401 2 | from .builder import build_sampler # noqa 401 3 | from .default_sampler import DefaultSampler # noqa 401 4 | from .dist_balance_sampler import BalanceSampler # noqa 401, 811 5 | from .dist_default_sampler import DefaultSampler # noqa 401, 811 6 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/samplers/builder.py: -------------------------------------------------------------------------------- 1 | from .registry import DISTSAMPLER, SAMPLER 2 | from ...utils import build_from_cfg 3 | 4 | 5 | def build_sampler(distributed, cfg, default_args=None): 6 | if distributed: 7 | sampler = build_from_cfg(cfg, DISTSAMPLER, default_args) 8 | else: 9 | sampler = build_from_cfg(cfg, SAMPLER, default_args) 10 | 11 | return sampler 12 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/samplers/default_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler 3 | 4 | from .registry import SAMPLER 5 | 6 | 7 | @SAMPLER.register_module 8 | class DefaultSampler(Sampler): 9 | """Default non-distributed sampler.""" 10 | 11 | def __init__(self, 12 | dataset, 13 | shuffle: bool = True, 14 | **kwargs 15 | ): 16 | self.dataset = dataset 17 | self.shuffle = shuffle 18 | 19 | def __iter__(self): 20 | if self.shuffle: 21 | return 
iter(torch.randperm(len(self.dataset)).tolist()) 22 | else: 23 | return iter(range(len(self.dataset))) 24 | 25 | def __len__(self): 26 | return len(self.dataset) 27 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/samplers/dist_default_sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DistributedSampler 2 | 3 | from .registry import DISTSAMPLER 4 | from ...utils import get_dist_info 5 | 6 | 7 | @DISTSAMPLER.register_module 8 | class DefaultSampler(DistributedSampler): 9 | """Default distributed sampler.""" 10 | 11 | def __init__(self, 12 | dataset, 13 | shuffle: bool = True, 14 | seed=0, 15 | drop_last=False, 16 | **kwargs 17 | ): 18 | if seed is None: 19 | seed = 0 20 | rank, num_replicas = get_dist_info() 21 | super().__init__(dataset=dataset, 22 | num_replicas=num_replicas, 23 | rank=rank, 24 | shuffle=shuffle, 25 | seed=seed, 26 | drop_last=drop_last, 27 | ) 28 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/dataloaders/samplers/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | SAMPLER = Registry('sampler') 4 | DISTSAMPLER = Registry('dist_sampler') 5 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_datasets # noqa 401 2 | from .concat_dataset import ConcatDatasets # noqa 401 3 | from .fold_dataset import FolderDataset # noqa 401 4 | from .lmdb_dataset import LmdbDataset # noqa 401 5 | from .paste_dataset import PasteDataset # noqa 401 6 | from .txt_datasets import TxtDataset # noqa 401 7 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/base.py: -------------------------------------------------------------------------------- 1 | # modify from clovaai 2 | 3 | import logging 4 | import os 5 | import re 6 | 7 | import cv2 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class BaseDataset(Dataset): 12 | """ 13 | Args: 14 | root (str): The dir path of image files. 15 | transform: Transformation for images, which will be passed 16 | automatically if you set transform cfg correctly in 17 | configure file. 18 | character (str): The character will be used. We will filter the 19 | sample based on the charatcer. 20 | batch_max_length (int): The max allowed length of the text 21 | after filter. 22 | data_filter (bool): If true, we will filter sample based on the 23 | character. Otherwise not filter. 
24 | 25 | """ 26 | 27 | def __init__(self, 28 | root: str, 29 | transform=None, 30 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789', 31 | batch_max_length: int = 100000, 32 | data_filter: bool = True): 33 | 34 | assert type( 35 | root 36 | ) == str, f'The type of root should be str but got {type(root)}' 37 | 38 | self.root = os.path.abspath(root) 39 | self.character = character 40 | self.batch_max_length = batch_max_length 41 | self.data_filter = data_filter 42 | 43 | if transform is not None: 44 | self.transforms = transform 45 | self.samples = 0 46 | self.img_names = [] 47 | self.gt_texts = [] 48 | self.get_name_list() 49 | 50 | self.logger = logging.getLogger() 51 | self.logger.info( 52 | f'current dataset length is {self.samples} in {self.root}') 53 | 54 | def get_name_list(self): 55 | raise NotImplementedError 56 | 57 | def filter(self, label, retrun_len=False): 58 | if not self.data_filter: 59 | if not retrun_len: 60 | return False 61 | return False, len(label) 62 | """We will filter those samples whose length is larger 63 | than defined max_length by default.""" 64 | character = "".join(sorted(self.character, key=lambda x: ord(x))) 65 | out_of_char = f'[^{character}]' 66 | # replace those character not in self.character with '' 67 | label = re.sub(out_of_char, '', label.lower()) 68 | # filter whose label larger than batch_max_length 69 | if len(label) > self.batch_max_length: 70 | if not retrun_len: 71 | return True 72 | return True, len(label) 73 | if not retrun_len: 74 | return False 75 | return False, len(label) 76 | 77 | def __getitem__(self, index): 78 | # default img channel is rgb 79 | img = cv2.imread(self.img_names[index]) 80 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 81 | label = self.gt_texts[index] 82 | 83 | if self.transforms: 84 | aug = self.transforms(image=img, label=label) 85 | img, label = aug['image'], aug['label'] 86 | 87 | return img, label 88 | 89 | def __len__(self): 90 | return self.samples 91 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from vedastr.utils import build_from_cfg 4 | from .registry import DATASETS 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def build_datasets(cfg, default_args=None): 10 | if isinstance(cfg, list): 11 | datasets = [] 12 | for icfg in cfg: 13 | ds = build_from_cfg(icfg, DATASETS, default_args) 14 | datasets.append(ds) 15 | else: 16 | datasets = build_from_cfg(cfg, DATASETS, default_args) 17 | 18 | return datasets 19 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import ConcatDataset as _ConcatDataset 2 | 3 | from .builder import build_datasets 4 | from .registry import DATASETS 5 | 6 | 7 | @DATASETS.register_module 8 | class ConcatDatasets(_ConcatDataset): 9 | """ Concat different datasets. 10 | 11 | Args: 12 | datasets (list[dict]): A list of which each elements is a dataset cfg. 13 | batch_ratio (list[float]): Ratio of corresponding dataset will be used 14 | in constructing a batch. It makes effect only with balance 15 | sampler. 
16 | **kwargs: 17 | 18 | """ 19 | 20 | def __init__(self, 21 | datasets: list, 22 | batch_ratio: list = None, 23 | **kwargs): 24 | assert isinstance(datasets, list) 25 | datasets = build_datasets(datasets, default_args=kwargs) 26 | self.root = ''.join([ds.root for ds in datasets]) 27 | data_range = [len(dataset) for dataset in datasets] 28 | self.data_range = [ 29 | sum(data_range[:i]) for i in range(1, 30 | len(data_range) + 1) 31 | ] 32 | self.batch_ratio = batch_ratio 33 | if self.batch_ratio is not None: 34 | assert len(self.batch_ratio) == len(datasets), \ 35 | 'The length of batch_ratio and datasets should be equal. ' \ 36 | f'But got {len(self.batch_ratio)} batch_ratio and ' \ 37 | f'{len(datasets)} datasets.' 38 | super(ConcatDatasets, self).__init__(datasets=datasets) 39 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/fold_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .base import BaseDataset 4 | from .registry import DATASETS 5 | 6 | 7 | @DATASETS.register_module 8 | class FolderDataset(BaseDataset): 9 | """ Read images in a folder. The format of image filename should be 10 | same as follows: 11 | 'name_gt.extension', where name represents arbitrary string, 12 | gt represents the ground-truth of the image, and extension 13 | represents the postfix (png, jpg, etc.). 14 | 15 | """ 16 | 17 | def __init__(self, 18 | root: str, 19 | transform=None, 20 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789', 21 | batch_max_length: int = 100000, 22 | data_filter: bool = True, 23 | extension_names: tuple = ('.jpg', '.png', '.bmp', '.jpeg')): 24 | super(FolderDataset, self).__init__( 25 | root=root, 26 | transform=transform, 27 | character=character, 28 | batch_max_length=batch_max_length, 29 | data_filter=data_filter) 30 | self.extension_names = extension_names 31 | 32 | @staticmethod 33 | def parse_filename(text): 34 | return text.split('_')[-1] 35 | 36 | def get_name_list(self): 37 | for item in os.listdir(self.root): 38 | file_name, file_extension = os.path.splitext(item) 39 | if file_extension in self.extension_names: 40 | label = self.parse_filename(file_name) 41 | if self.filter(label): 42 | continue 43 | else: 44 | self.img_names.append(os.path.join(self.root, item)) 45 | self.gt_texts.append(label) 46 | self.samples = len(self.gt_texts) 47 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # modify from clovaai 2 | 3 | import logging 4 | 5 | import lmdb 6 | import numpy as np 7 | import six 8 | from PIL import Image 9 | 10 | from .base import BaseDataset 11 | from .registry import DATASETS 12 | 13 | logger = logging.getLogger() 14 | 15 | 16 | @DATASETS.register_module 17 | class LmdbDataset(BaseDataset): 18 | """ Read the data of lmdb format. 19 | Please refer to https://github.com/Media-Smart/vedastr/issues/27#issuecomment-691793593 # noqa 501 20 | if you have problems with creating lmdb format file. 
21 | 22 | """ 23 | 24 | def __init__(self, 25 | root: str, 26 | transform=None, 27 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789', 28 | batch_max_length: int = 100000, 29 | data_filter: bool = True): 30 | self.index_list = [] 31 | super(LmdbDataset, self).__init__( 32 | root=root, 33 | transform=transform, 34 | character=character, 35 | batch_max_length=batch_max_length, 36 | data_filter=data_filter) 37 | 38 | def get_name_list(self): 39 | self.env = lmdb.open( 40 | self.root, 41 | max_readers=32, 42 | readonly=True, 43 | lock=False, 44 | readahead=False, 45 | meminit=False) 46 | with self.env.begin(write=False) as txn: 47 | n_samples = int(txn.get('num-samples'.encode())) 48 | for index in range(n_samples): 49 | idx = index + 1 # lmdb starts with 1 50 | label_key = 'label-%09d'.encode() % idx 51 | label = txn.get(label_key).decode('utf-8') 52 | if self.filter( 53 | label 54 | ): # if length of label larger than max_len, drop this sample 55 | continue 56 | else: 57 | self.index_list.append(idx) 58 | self.samples = len(self.index_list) 59 | 60 | def read_data(self, index): 61 | assert index <= len(self), 'index range error' 62 | index = self.index_list[index] 63 | with self.env.begin(write=False) as txn: 64 | label_key = 'label-%09d'.encode() % index 65 | label = txn.get(label_key).decode('utf-8') 66 | img_key = 'image-%09d'.encode() % index 67 | imgbuf = txn.get(img_key) 68 | 69 | buf = six.BytesIO() 70 | buf.write(imgbuf) 71 | buf.seek(0) 72 | img = Image.open(buf).convert('RGB') # for color image 73 | img = np.array(img) 74 | 75 | return img, label 76 | 77 | def __getitem__(self, index): 78 | 79 | img, label = self.read_data(index) 80 | if self.transforms: 81 | aug = self.transforms(image=img, label=label) 82 | img, label = aug['image'], aug['label'] 83 | 84 | return img, label 85 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/paste_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import lmdb 5 | 6 | from .lmdb_dataset import LmdbDataset 7 | from .registry import DATASETS 8 | 9 | 10 | @DATASETS.register_module 11 | class PasteDataset(LmdbDataset): 12 | """ Concat two images and the combined label should satisfy the constrains. 13 | 14 | Args: 15 | p: The probability of pasting operation. 16 | 17 | Warnings:: We will create a new transform operation to 18 | replace this dataset. 
19 | """ 20 | 21 | def __init__(self, p: float = 0.1, *args, **kwargs): 22 | self.len_sample = dict() 23 | self.len_lists = list() 24 | self.p = p 25 | super(PasteDataset, self).__init__(*args, **kwargs) 26 | 27 | def get_name_list(self): 28 | self.env = lmdb.open( 29 | self.root, 30 | max_readers=32, 31 | readonly=True, 32 | lock=False, 33 | readahead=False, 34 | meminit=False) 35 | with self.env.begin(write=False) as txn: 36 | n_samples = int(txn.get('num-samples'.encode())) 37 | for index in range(n_samples): 38 | idx = index + 1 # lmdb starts with 1 39 | label_key = 'label-%09d'.encode() % idx 40 | label = txn.get(label_key).decode('utf-8') 41 | flag, length = self.filter(label, retrun_len=True) 42 | if flag: 43 | continue 44 | else: 45 | self.index_list.append(idx) 46 | self.len_lists.append(length) 47 | if length not in self.len_sample: 48 | self.len_sample[length] = [len(self.index_list) - 1] 49 | else: 50 | self.len_sample[length].append( 51 | len(self.index_list) - 1) 52 | 53 | self.samples = len(self.index_list) 54 | 55 | def __getitem__(self, index): 56 | img, label = self.read_data(index) 57 | th, tw = img.shape[:2] 58 | c_len = self.len_lists[index] 59 | max_need_len = self.batch_max_length - c_len 60 | if max_need_len > 2 and random.random() < self.p: 61 | p_len = random.randint(1, max_need_len - 1) 62 | p_idx = random.choice(self.len_sample[p_len]) 63 | p_img, p_label = self.read_data(p_idx) 64 | p_img = cv2.resize(p_img, (tw, th)) 65 | img = cv2.hconcat([img, p_img]) 66 | label = label + p_label 67 | 68 | if self.transforms: 69 | aug = self.transforms(image=img, label=label) 70 | img, label = aug['image'], aug['label'] 71 | 72 | return img, label 73 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | DATASETS = Registry('datasets') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/datasets/txt_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .base import BaseDataset 4 | from .registry import DATASETS 5 | 6 | 7 | @DATASETS.register_module 8 | class TxtDataset(BaseDataset): 9 | """ Read images based on the txt file. 10 | The format of lines in txt should be same as follows: 11 | image_path label 12 | 13 | The image path and label should be split with '\t'. 
14 | 15 | """ 16 | 17 | def __init__( 18 | self, 19 | root: str, 20 | gt_txt: str, 21 | transform=None, 22 | character: str = 'abcdefghijklmnopqrstuvwxyz0123456789', 23 | batch_max_length: int = 25, 24 | data_filter: bool = True, 25 | ): 26 | # check that the ground truth file exists 27 | if gt_txt is not None: 28 | assert os.path.isfile(gt_txt) 29 | self.gt_txt = gt_txt 30 | super(TxtDataset, self).__init__( 31 | root=root, 32 | transform=transform, 33 | character=character, 34 | batch_max_length=batch_max_length, 35 | data_filter=data_filter, 36 | ) 37 | 38 | def get_name_list(self): 39 | with open(self.gt_txt, 'r') as gt: 40 | for line in gt.readlines(): 41 | img_name, label = line.strip().split('\t') 42 | if self.filter(label): 43 | continue 44 | else: 45 | self.img_names.append(os.path.join(self.root, img_name)) 46 | self.gt_texts.append(label) 47 | 48 | self.samples = len(self.gt_texts) 49 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_logger  # noqa 401 2 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/logger/builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import time 5 | 6 | import torch.distributed as dist 7 | 8 | 9 | def build_logger(cfg, default_args): 10 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 11 | format_ = '%(asctime)s - %(levelname)s - %(message)s' 12 | 13 | formatter = logging.Formatter(format_) 14 | logger = logging.getLogger() 15 | logger.setLevel(logging.DEBUG) 16 | if logger.parent is not None: 17 | logger.parent.handlers.clear() 18 | else: 19 | logger.handlers.clear() 20 | if dist.is_available() and dist.is_initialized(): 21 | rank = dist.get_rank() 22 | else: 23 | rank = 0 24 | for handler in cfg['handlers']: 25 | if handler['type'] == 'StreamHandler': 26 | instance = logging.StreamHandler(sys.stdout) 27 | elif handler['type'] == 'FileHandler': 28 | # only rank 0 will add a FileHandler 29 | if default_args.get('workdir') and rank == 0: 30 | fp = os.path.join(default_args['workdir'], 31 | '%s.log' % timestamp) 32 | instance = logging.FileHandler(fp, 'w') 33 | else: 34 | continue 35 | else: 36 | instance = logging.StreamHandler(sys.stdout) 37 | 38 | level = getattr(logging, handler['level']) 39 | 40 | instance.setFormatter(formatter) 41 | if rank == 0: 42 | instance.setLevel(level) 43 | else: 44 | logger.setLevel(logging.ERROR) 45 | 46 | logger.addHandler(instance) 47 | 48 | return logger 49 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_lr_scheduler  # noqa 401 2 | from .constant_lr import ConstantLR  # noqa 401 3 | from .cosine_lr import CosineLR  # noqa 401 4 | from .exponential_lr import ExponentialLR  # noqa 401 5 | from .poly_lr import PolyLR  # noqa 401 6 | from .step_lr import StepLR  # noqa 401 7 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import LR_SCHEDULERS 3 | 4 | 5 | def build_lr_scheduler(cfg,
default_args=None): 6 | scheduler = build_from_cfg(cfg, LR_SCHEDULERS, default_args, 'registry') 7 | return scheduler 8 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/constant_lr.py: -------------------------------------------------------------------------------- 1 | from .base import _Iter_LRScheduler 2 | from .registry import LR_SCHEDULERS 3 | 4 | 5 | @LR_SCHEDULERS.register_module 6 | class ConstantLR(_Iter_LRScheduler): 7 | """ConstantLR 8 | """ 9 | 10 | def __init__(self, 11 | optimizer, 12 | niter_per_epoch, 13 | last_iter=-1, 14 | warmup_epochs=0, 15 | iter_based=True, 16 | **kwargs): 17 | self.warmup_iters = niter_per_epoch * warmup_epochs 18 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based) 19 | 20 | def get_lr(self): 21 | if self.last_iter < self.warmup_iters: 22 | multiplier = self.last_iter / float(self.warmup_iters) 23 | else: 24 | multiplier = 1.0 25 | return [base_lr * multiplier for base_lr in self.base_lrs] 26 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/cosine_lr.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi 2 | 3 | from .base import _Iter_LRScheduler 4 | from .registry import LR_SCHEDULERS 5 | 6 | 7 | @LR_SCHEDULERS.register_module 8 | class CosineLR(_Iter_LRScheduler): 9 | """CosineLR 10 | """ 11 | 12 | def __init__(self, 13 | optimizer, 14 | niter_per_epoch, 15 | max_epochs, 16 | last_iter=-1, 17 | warmup_epochs=0, 18 | iter_based=True): 19 | self.max_iters = niter_per_epoch * max_epochs 20 | self.warmup_iters = niter_per_epoch * warmup_epochs 21 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based) 22 | 23 | def get_lr(self): 24 | if self.last_iter < self.warmup_iters: 25 | multiplier = 0.5 * ( 26 | 1 - cos(pi * (self.last_iter / float(self.warmup_iters)))) 27 | else: 28 | multiplier = 0.5 * (1 + cos(pi * ( 29 | (self.last_iter - self.warmup_iters) / 30 | float(self.max_iters - self.warmup_iters)))) 31 | return [base_lr * multiplier for base_lr in self.base_lrs] 32 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/exponential_lr.py: -------------------------------------------------------------------------------- 1 | from .base import _Iter_LRScheduler 2 | from .registry import LR_SCHEDULERS 3 | 4 | 5 | @LR_SCHEDULERS.register_module 6 | class ExponentialLR(_Iter_LRScheduler): 7 | """ExponentialLR 8 | """ 9 | 10 | def __init__(self, 11 | optimizer, 12 | niter_per_epoch, 13 | max_epochs, 14 | gamma, 15 | step, 16 | last_iter=-1, 17 | warmup_epochs=0, 18 | iter_based=True): 19 | self.max_iters = niter_per_epoch * max_epochs 20 | self.gamma = gamma 21 | self.step_iters = niter_per_epoch * step 22 | self.warmup_iters = int(niter_per_epoch * warmup_epochs) 23 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based) 24 | 25 | def get_lr(self): 26 | if self.last_iter < self.warmup_iters: 27 | multiplier = self.last_iter / float(self.warmup_iters) 28 | else: 29 | multiplier = self.gamma**((self.last_iter - self.warmup_iters) / 30 | float(self.step_iters)) 31 | return [base_lr * multiplier for base_lr in self.base_lrs] 32 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/poly_lr.py: -------------------------------------------------------------------------------- 1 | from .base import 
_Iter_LRScheduler 2 | from .registry import LR_SCHEDULERS 3 | 4 | 5 | @LR_SCHEDULERS.register_module 6 | class PolyLR(_Iter_LRScheduler): 7 | """PolyLR 8 | """ 9 | 10 | def __init__(self, 11 | optimizer, 12 | niter_per_epoch, 13 | max_epochs, 14 | power=0.9, 15 | last_iter=-1, 16 | warmup_epochs=0, 17 | iter_based=True): 18 | self.max_iters = niter_per_epoch * max_epochs 19 | self.power = power 20 | self.warmup_iters = niter_per_epoch * warmup_epochs 21 | super().__init__(optimizer, niter_per_epoch, last_iter, iter_based) 22 | 23 | def get_lr(self): 24 | if self.last_iter < self.warmup_iters: 25 | multiplier = (self.last_iter / 26 | float(self.warmup_iters))**self.power 27 | else: 28 | multiplier = ( 29 | 1 - (self.last_iter - self.warmup_iters) / 30 | float(self.max_iters - self.warmup_iters))**self.power 31 | return [base_lr * multiplier for base_lr in self.base_lrs] 32 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | LR_SCHEDULERS = Registry('lr_scheduler') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/lr_schedulers/step_lr.py: -------------------------------------------------------------------------------- 1 | from .base import _Iter_LRScheduler 2 | from .registry import LR_SCHEDULERS 3 | 4 | 5 | @LR_SCHEDULERS.register_module 6 | class StepLR(_Iter_LRScheduler): 7 | 8 | def __init__(self, 9 | optimizer, 10 | niter_per_epoch, 11 | max_epochs, 12 | milestones, 13 | gamma=0.1, 14 | last_iter=-1, 15 | warmup_epochs=0, 16 | iter_based=True): 17 | self.max_iters = niter_per_epoch * max_epochs 18 | self.milestones = milestones 19 | self.count = 0 20 | self.gamma = gamma 21 | self.warmup_iters = int(niter_per_epoch * warmup_epochs) 22 | super(StepLR, self).__init__(optimizer, niter_per_epoch, last_iter, 23 | iter_based) 24 | 25 | def get_lr(self): 26 | if self._iter_based and self.last_iter in self.milestones: 27 | self.count += 1 28 | elif not self._iter_based and self.last_epoch in self.milestones: 29 | self.count += 1 30 | 31 | if self.last_iter < self.warmup_iters: 32 | multiplier = self.last_iter / float(self.warmup_iters) 33 | else: 34 | multiplier = self.gamma**self.count 35 | return [base_lr * multiplier for base_lr in self.base_lrs] 36 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy # noqa 401 2 | from .builder import build_metric # noqa 401 3 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | # modify from clovaai 2 | import random 3 | 4 | import torch 5 | from nltk.metrics.distance import edit_distance 6 | 7 | from .registry import METRICS 8 | from ..utils import gather_tensor, get_dist_info 9 | 10 | 11 | @METRICS.register_module 12 | class Accuracy(object): 13 | 14 | def __init__(self): 15 | self.reset() 16 | self.predict_example_log = None 17 | 18 | @property 19 | def result(self): 20 | res = { 21 | 'acc': self.avg['acc']['true'], 22 | 'edit_distance': self.avg['edit'], 23 | } 24 | return res 25 | 26 | def measure(self, preds, preds_prob, gts, exclude_num=0): 27 | 
batch_size = len(gts) 28 | true_nums = [] 29 | norm_EDs = [] 30 | r, w = get_dist_info() 31 | for pstr, gstr in zip(preds, gts): 32 | if pstr == gstr: 33 | true_nums.append(1.) 34 | else: 35 | true_nums.append(0.) 36 | if len(pstr) == 0 or len(gstr) == 0: 37 | norm_EDs.append(0) 38 | elif len(gstr) > len(pstr): 39 | norm_EDs.append(1 - edit_distance(pstr, gstr) / len(gstr)) 40 | else: 41 | norm_EDs.append(1 - edit_distance(pstr, gstr) / len(pstr)) 42 | # gather batch_size, true_num, norm_ED from different workers 43 | batch_sizes = gather_tensor(torch.tensor(batch_size)[None].cuda()) 44 | true_nums = gather_tensor( 45 | torch.tensor(true_nums)[None].cuda()).flatten() 46 | norm_EDs = gather_tensor(torch.tensor(norm_EDs)[None].cuda()).flatten() 47 | 48 | # remove exclude data 49 | if exclude_num != 0: 50 | batch_size = torch.sum(batch_sizes).cpu().numpy() - exclude_num 51 | true_nums = list(true_nums.split(true_nums.shape[0] // w)) 52 | for i in range(1, exclude_num + 1): 53 | true_nums[-i] = true_nums[-i][:-1] 54 | true_nums = torch.cat(true_nums) 55 | # true_nums = true_nums.flatten()[:-exclude_num] 56 | norm_EDs = list(norm_EDs.split(true_nums.shape[0] // w)) 57 | for i in range(1, exclude_num + 1): 58 | norm_EDs[-i] = norm_EDs[-i][:-1] 59 | norm_EDs = torch.cat(norm_EDs) 60 | # norm_EDs = norm_EDs.flatten()[:-exclude_num] 61 | else: 62 | batch_size = torch.sum(batch_sizes).cpu().numpy() 63 | 64 | true_num = torch.sum(true_nums).cpu().numpy() 65 | norm_ED = torch.sum(norm_EDs).cpu().numpy() 66 | 67 | if preds_prob is not None: 68 | self.show_example(preds, preds_prob, gts) 69 | self.all['acc']['true'] += true_num 70 | self.all['acc']['false'] += (batch_size - true_num) 71 | self.all['edit'] += norm_ED 72 | self.count += batch_size 73 | for key, value in self.all['acc'].items(): 74 | self.avg['acc'][key] = self.all['acc'][key] / self.count 75 | self.avg['edit'] = self.all['edit'] / self.count 76 | 77 | def reset(self): 78 | self.all = dict(acc=dict(true=0, false=0), edit=0) 79 | self.avg = dict(acc=dict(true=0, false=0), edit=0) 80 | self.count = 0 81 | 82 | def show_example(self, preds, preds_prob, gts): 83 | count = 0 84 | self.predict_example_log = None 85 | dashed_line = '-' * 80 86 | 87 | output = dashed_line + '\n' 88 | show_inds = list(range(len(gts))) 89 | random.shuffle(show_inds) 90 | show_inds = show_inds[:5] 91 | show_gts = [gts[i] for i in show_inds] 92 | show_preds = [preds[i] for i in show_inds] 93 | show_prob = [preds_prob[i] for i in show_inds] 94 | for gt, pred, prob in zip(show_gts, show_preds, show_prob): 95 | output += f'{"Ground Truth : ":18s} {gt:40s} \n' \ 96 | f'{"Prediction : ":18s} {pred:40s} \n' \ 97 | f'{"Confidence : ":18s} {prob:0.4f} \n' \ 98 | f'{"Success(T/F) : ":18s} {str(pred == gt)} \n' \ 99 | f'{dashed_line}\n' 100 | count += 1 101 | if count > 4: 102 | break 103 | 104 | output += dashed_line 105 | 106 | self.predict_example_log = output 107 | 108 | return self.predict_example_log 109 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/metrics/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import METRICS 3 | 4 | 5 | def build_metric(cfg, default_args=None): 6 | metric = build_from_cfg(cfg, METRICS, default_args) 7 | 8 | return metric 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/metrics/registry.py: 
-------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | METRICS = Registry('metrics') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_model # noqa 401 2 | from .model import GModel # noqa 401 3 | from .registry import MODELS # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/__init__.py: -------------------------------------------------------------------------------- 1 | from .body import GBody # noqa 401 2 | from .builder import build_body, build_component # noqa 401 3 | from .feature_extractors import build_brick, build_feature_extractor # noqa 401 4 | from .rectificators import build_rectificator # noqa 401 5 | from .sequences import build_sequence_decoder, build_sequence_encoder # noqa 401 6 | from .component import (BrickComponent, FeatureExtractorComponent, # noqa 401 7 | PlugComponent, RectificatorComponent, # noqa 401 8 | SequenceEncoderComponent) # noqa 401 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/body.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .builder import build_component 4 | from .feature_extractors import build_brick 5 | from .registry import BODIES 6 | 7 | 8 | @BODIES.register_module 9 | class GBody(nn.Module): 10 | 11 | def __init__(self, pipelines, collect=None): 12 | super(GBody, self).__init__() 13 | 14 | self.input_to_layer = 'input' 15 | self.components = nn.ModuleList( 16 | [build_component(component) for component in pipelines]) 17 | 18 | if collect is not None: 19 | self.collect = build_brick(collect) 20 | 21 | @property 22 | def with_collect(self): 23 | return hasattr(self, 'collect') and self.collect is not None 24 | 25 | def forward(self, x): 26 | feats = {self.input_to_layer: x} 27 | 28 | for component in self.components: 29 | component_from = component.from_layer 30 | component_to = component.to_layer 31 | 32 | if isinstance(component_from, list): 33 | inp = {key: feats[key] for key in component_from} 34 | out = component(**inp) 35 | else: 36 | inp = feats[component_from] 37 | out = component(inp) 38 | feats[component_to] = out 39 | 40 | if self.with_collect: 41 | return self.collect(feats) 42 | else: 43 | return feats 44 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import BODIES, COMPONENT 3 | 4 | 5 | def build_component(cfg, default_args=None): 6 | component = build_from_cfg(cfg, COMPONENT, default_args) 7 | 8 | return component 9 | 10 | 11 | def build_body(cfg, default_args=None): 12 | body = build_from_cfg(cfg, BODIES, default_args) 13 | 14 | return body 15 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/component.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .feature_extractors import build_brick, build_feature_extractor 4 | from .rectificators import build_rectificator 5 | from 
.registry import COMPONENT 6 | from .sequences import build_sequence_encoder 7 | from ..utils import build_module 8 | 9 | 10 | class BaseComponent(nn.Module): 11 | 12 | def __init__(self, from_layer, to_layer, component): 13 | super(BaseComponent, self).__init__() 14 | 15 | self.from_layer = from_layer 16 | self.to_layer = to_layer 17 | self.component = component 18 | 19 | def forward(self, x): 20 | return self.component(x) 21 | 22 | 23 | @COMPONENT.register_module 24 | class FeatureExtractorComponent(BaseComponent): 25 | 26 | def __init__(self, from_layer, to_layer, arch): 27 | super(FeatureExtractorComponent, 28 | self).__init__(from_layer, to_layer, 29 | build_feature_extractor(arch)) 30 | 31 | 32 | @COMPONENT.register_module 33 | class RectificatorComponent(BaseComponent): 34 | 35 | def __init__(self, from_layer, to_layer, arch): 36 | super(RectificatorComponent, self).__init__(from_layer, to_layer, 37 | build_rectificator(arch)) 38 | 39 | 40 | @COMPONENT.register_module 41 | class SequenceEncoderComponent(BaseComponent): 42 | 43 | def __init__(self, from_layer, to_layer, arch): 44 | super(SequenceEncoderComponent, 45 | self).__init__(from_layer, to_layer, 46 | build_sequence_encoder(arch)) 47 | 48 | 49 | @COMPONENT.register_module 50 | class BrickComponent(BaseComponent): 51 | 52 | def __init__(self, from_layer, to_layer, arch): 53 | super(BrickComponent, self).__init__(from_layer, to_layer, 54 | build_brick(arch)) 55 | 56 | 57 | @COMPONENT.register_module 58 | class PlugComponent(BaseComponent): 59 | 60 | def __init__(self, from_layer, to_layer, arch): 61 | super(PlugComponent, self).__init__(from_layer, to_layer, 62 | build_module(arch)) 63 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_feature_extractor # noqa 401 2 | from .decoders import build_brick, build_bricks, build_decoder # noqa 401 3 | from .encoders import build_backbone, build_encoder, build_enhance_module # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/builder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .decoders import build_brick, build_decoder 4 | from .encoders import build_encoder 5 | 6 | 7 | def build_feature_extractor(cfg): 8 | encoder = build_encoder(cfg.get('encoder')) 9 | 10 | if cfg.get('decoder'): 11 | middle = build_decoder(cfg.get('decoder')) 12 | if 'collect' in cfg: 13 | final = build_brick(cfg.get('collect')) 14 | feature_extractor = nn.Sequential(encoder, middle, final) 15 | else: 16 | feature_extractor = nn.Sequential(encoder, middle) 17 | # assert 'collect' not in cfg 18 | else: 19 | assert 'collect' in cfg 20 | middle = build_brick(cfg.get('collect')) 21 | feature_extractor = nn.Sequential(encoder, middle) 22 | 23 | return feature_extractor 24 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .bricks import build_brick, build_bricks # noqa 401 2 | from .builder import build_decoder # noqa 401 3 | from .gfpn import GFPN # noqa 401 4 | 
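Together, GBody and the component classes above implement a small dataflow graph: GBody.forward keeps a feats dict keyed by layer name, each component reads feats[from_layer], writes its output to feats[to_layer], and an optional collect brick picks the final entry. The toy sketch below mirrors only that routing logic with plain torch modules; the component names, channel sizes, and shapes are made up for illustration and are not part of the repo.

    import torch
    import torch.nn as nn

    class ToyComponent(nn.Module):
        """Minimal stand-in for BaseComponent: route feats[from_layer] -> feats[to_layer]."""
        def __init__(self, from_layer, to_layer, module):
            super().__init__()
            self.from_layer, self.to_layer, self.component = from_layer, to_layer, module

        def forward(self, x):
            return self.component(x)

    components = nn.ModuleList([
        ToyComponent('input', 'c0', nn.Conv2d(3, 8, 3, padding=1)),
        ToyComponent('c0', 'c1', nn.Conv2d(8, 16, 3, stride=2, padding=1)),
    ])

    feats = {'input': torch.randn(1, 3, 32, 100)}   # (b, c, h, w) input image tensor
    for component in components:
        # same routing rule as GBody.forward for a single from_layer
        feats[component.to_layer] = component(feats[component.from_layer])
    print({name: tuple(f.shape) for name, f in feats.items()})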
-------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/__init__.py: -------------------------------------------------------------------------------- 1 | from .bricks import (CellAttentionBlock, CollectBlock, FusionBlock, # noqa 401 2 | JunctionBlock) # noqa 401 3 | from .builder import build_brick, build_bricks # noqa 401 4 | from .pva import PVABlock # noqa 401 5 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/builder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from vedastr.utils import build_from_cfg 4 | from .registry import BRICKS 5 | 6 | 7 | def build_brick(cfg, default_args=None): 8 | brick = build_from_cfg(cfg, BRICKS, default_args) 9 | return brick 10 | 11 | 12 | def build_bricks(cfgs): 13 | bricks = nn.ModuleList() 14 | for brick_cfg in cfgs: 15 | bricks.append(build_brick(brick_cfg)) 16 | return bricks 17 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/pva.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from vedastr.models.weight_init import init_weights 5 | from .registry import BRICKS 6 | 7 | 8 | @BRICKS.register_module 9 | class PVABlock(nn.Module): 10 | 11 | def __init__(self, 12 | num_steps, 13 | in_channels, 14 | embedding_channels=512, 15 | inner_channels=512): 16 | super(PVABlock, self).__init__() 17 | 18 | self.num_steps = num_steps 19 | self.in_channels = in_channels 20 | self.inner_channels = inner_channels 21 | self.embedding_channels = embedding_channels 22 | 23 | self.order_embeddings = nn.Parameter( 24 | torch.randn(self.num_steps, self.embedding_channels), 25 | requires_grad=True) 26 | 27 | self.v_linear = nn.Linear( 28 | self.in_channels, self.inner_channels, bias=False) 29 | self.o_linear = nn.Linear( 30 | self.embedding_channels, self.inner_channels, bias=False) 31 | self.e_linear = nn.Linear(self.inner_channels, 1, bias=False) 32 | 33 | init_weights(self.modules()) 34 | 35 | def forward(self, x): 36 | b, c, h, w = x.size() 37 | 38 | x = x.reshape(b, c, h * w).permute(0, 2, 1) 39 | 40 | o_out = self.o_linear(self.order_embeddings).view( 41 | 1, self.num_steps, 1, self.inner_channels) 42 | v_out = self.v_linear(x).unsqueeze(1) 43 | att = self.e_linear(torch.tanh(o_out + v_out)).squeeze(3) 44 | att = torch.softmax(att, dim=2) 45 | 46 | out = torch.bmm(att, x) 47 | 48 | return out 49 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/bricks/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | BRICKS = Registry('brick') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import DECODERS 3 | 4 | 5 | def build_decoder(cfg, default_args=None): 6 | decoder = build_from_cfg(cfg, DECODERS, default_args) 7 | return decoder 8 | 
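PVABlock above is the parallel visual attention brick: it flattens the feature map to (b, h*w, c), scores every spatial position against a learned order embedding for each of the num_steps reading steps, softmaxes over the h*w positions, and returns one attended glimpse per step, giving an output of shape (b, num_steps, c). A quick shape check under illustrative sizes (not taken from a repo config), assuming the vedastr package is importable:

    import torch
    from vedastr.models.bodies.feature_extractors.decoders.bricks import PVABlock

    block = PVABlock(num_steps=25, in_channels=128, embedding_channels=64, inner_channels=64)
    x = torch.randn(2, 128, 8, 32)   # (b, c, h, w) feature map from the encoder
    out = block(x)                   # attention over the 8*32 positions for each of the 25 steps
    print(out.shape)                 # torch.Size([2, 25, 128])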
-------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/gfpn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from vedastr.models.weight_init import init_weights 6 | from .bricks import build_brick, build_bricks 7 | from .registry import DECODERS 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | @DECODERS.register_module 13 | class GFPN(nn.Module): 14 | 15 | def __init__(self, neck: list, fusion: dict = None): 16 | super().__init__() 17 | self.neck = build_bricks(neck) 18 | if fusion: 19 | self.fusion = build_brick(fusion) 20 | else: 21 | self.fusion = None 22 | logger.info('GFPN init weights') 23 | init_weights(self.modules()) 24 | 25 | def forward(self, bottom_up): 26 | 27 | x = None 28 | feats = {} 29 | for ii, layer in enumerate(self.neck): 30 | top_down_from_layer = layer.from_layer.get('top_down') 31 | lateral_from_layer = layer.from_layer.get('lateral') 32 | 33 | if lateral_from_layer: 34 | ll = bottom_up[lateral_from_layer] 35 | else: 36 | ll = None 37 | if top_down_from_layer is None: 38 | td = None 39 | elif 'c' in top_down_from_layer: 40 | td = bottom_up[top_down_from_layer] 41 | elif 'p' in top_down_from_layer: 42 | td = feats[top_down_from_layer] 43 | else: 44 | raise ValueError('Key error') 45 | 46 | x = layer(td, ll) 47 | feats[layer.to_layer] = x 48 | bottom_up[layer.to_layer] = x 49 | 50 | if self.fusion: 51 | x = self.fusion(feats) 52 | bottom_up['fusion'] = x 53 | return bottom_up 54 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/decoders/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | DECODERS = Registry('decoder') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import build_backbone # noqa 401 2 | from .builder import build_encoder # noqa 401 3 | from .enhance_modules import build_enhance_module # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_backbone # noqa 401 2 | from .general_backbone import GBackbone # noqa 401 3 | from .resnet import GResNet, ResNet # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import BACKBONES 3 | 4 | 5 | def build_backbone(cfg, default_args=None): 6 | backbone = build_from_cfg(cfg, BACKBONES, default_args) 7 | 8 | return backbone 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/general_backbone.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from 
vedastr.models.utils import build_module, build_torch_nn 6 | from vedastr.models.weight_init import init_weights 7 | from .registry import BACKBONES 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | @BACKBONES.register_module 13 | class GBackbone(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | layers: list, 18 | ): 19 | super(GBackbone, self).__init__() 20 | 21 | self.layers = nn.ModuleList() 22 | stage_layers = [] 23 | for layer_cfg in layers: 24 | type_name = layer_cfg['type'] 25 | if hasattr(nn, type_name): 26 | layer = build_torch_nn(layer_cfg) 27 | else: 28 | layer = build_module(layer_cfg) 29 | stride = layer_cfg.get('stride', 1) 30 | max_stride = stride if isinstance(stride, int) else max(stride) 31 | if max_stride > 1: 32 | self.layers.append(nn.Sequential(*stage_layers)) 33 | stage_layers = [] 34 | stage_layers.append(layer) 35 | self.layers.append(nn.Sequential(*stage_layers)) 36 | logger.info('GBackbone init weights') 37 | init_weights(self.modules()) 38 | 39 | def forward(self, x): 40 | feats = {} 41 | 42 | for i, layer in enumerate(self.layers): 43 | x = layer(x) 44 | feats['c{}'.format(i)] = x 45 | 46 | return feats 47 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/backbones/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | BACKBONES = Registry('backbone') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/builder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .backbones import build_backbone 4 | from .enhance_modules import build_enhance_module 5 | 6 | 7 | def build_encoder(cfg, default_args=None): 8 | backbone = build_backbone(cfg['backbone']) 9 | 10 | enhance_cfg = cfg.get('enhance') 11 | if enhance_cfg: 12 | enhance_module = build_enhance_module(enhance_cfg) 13 | encoder = nn.Sequential(backbone, enhance_module) 14 | else: 15 | encoder = backbone 16 | 17 | return encoder 18 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aspp import ASPP # noqa 401 2 | from .builder import build_enhance_module # noqa 401 3 | from .ppm import PPM # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/aspp.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/pytorch/vision/tree/master/torchvision/models/segmentation/deeplabv3.py # noqa 501 2 | 3 | import logging 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from vedastr.models.weight_init import init_weights 10 | from .registry import ENHANCE_MODULES 11 | 12 | logger = logging.getLogger() 13 | 14 | 15 | class ASPPConv(nn.Sequential): 16 | 17 | def __init__(self, in_channels, out_channels, dilation): 18 | modules = [ 19 | nn.Conv2d( 20 | in_channels, 21 | out_channels, 22 | 3, 23 | padding=dilation, 24 | dilation=dilation, 25 | bias=False), 26 | nn.BatchNorm2d(out_channels), 27 | nn.ReLU(inplace=True) 28 | ] 29 | super(ASPPConv, 
self).__init__(*modules) 30 | 31 | 32 | class ASPPPooling(nn.Sequential): 33 | 34 | def __init__(self, in_channels, out_channels): 35 | super(ASPPPooling, self).__init__( 36 | nn.AdaptiveAvgPool2d(1), 37 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 38 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) 39 | 40 | def forward(self, x): 41 | size = x.shape[-2:] 42 | x = super(ASPPPooling, self).forward(x) 43 | return F.interpolate(x, size=size, mode='nearest') 44 | 45 | 46 | @ENHANCE_MODULES.register_module 47 | class ASPP(nn.Module): 48 | 49 | def __init__(self, 50 | in_channels: int, 51 | out_channels: int, 52 | atrous_rates: tuple, 53 | from_layer: str, 54 | to_layer: str, 55 | dropout=None): 56 | super(ASPP, self).__init__() 57 | self.from_layer = from_layer 58 | self.to_layer = to_layer 59 | 60 | modules = [] 61 | modules.append( 62 | nn.Sequential( 63 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 64 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))) 65 | 66 | rate1, rate2, rate3 = tuple(atrous_rates) 67 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 68 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 69 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 70 | modules.append(ASPPPooling(in_channels, out_channels)) 71 | 72 | self.convs = nn.ModuleList(modules) 73 | 74 | self.project = nn.Sequential( 75 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 76 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) 77 | self.with_dropout = dropout is not None 78 | if self.with_dropout: 79 | self.dropout = nn.Dropout(p=dropout) 80 | 81 | logger.info('ASPP init weights') 82 | init_weights(self.modules()) 83 | 84 | def forward(self, feats): 85 | feats_ = feats.copy() 86 | x = feats_[self.from_layer] 87 | res = [] 88 | for conv in self.convs: 89 | res.append(conv(x)) 90 | res = torch.cat(res, dim=1) 91 | res = self.project(res) 92 | if self.with_dropout: 93 | res = self.dropout(res) 94 | feats_[self.to_layer] = res 95 | return feats_ 96 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import ENHANCE_MODULES 3 | 4 | 5 | def build_enhance_module(cfg, default_args=None): 6 | enhance_module = build_from_cfg(cfg, ENHANCE_MODULES, default_args) 7 | 8 | return enhance_module 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/ppm.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/hszhao/semseg/blob/master/model/pspnet.py 2 | 3 | import logging 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from vedastr.models.weight_init import init_weights 10 | from .registry import ENHANCE_MODULES 11 | 12 | logger = logging.getLogger() 13 | 14 | 15 | @ENHANCE_MODULES.register_module 16 | class PPM(nn.Module): 17 | 18 | def __init__(self, in_channels, out_channels, bins, from_layer, to_layer): 19 | super(PPM, self).__init__() 20 | self.from_layer = from_layer 21 | self.to_layer = to_layer 22 | 23 | self.blocks = nn.ModuleList() 24 | for bin_ in bins: 25 | self.blocks.append( 26 | nn.Sequential( 27 | nn.AdaptiveAvgPool2d(bin_), 28 | nn.Conv2d(in_channels, 
out_channels, 1, bias=False), 29 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))) 30 | logger.info('PPM init weights') 31 | init_weights(self.modules()) 32 | 33 | def forward(self, feats): 34 | feats_ = feats.copy() 35 | x = feats_[self.from_layer] 36 | h, w = x.shape[2:] 37 | out = [x] 38 | for block in self.blocks: 39 | feat = F.interpolate( 40 | block(x), (h, w), mode='bilinear', align_corners=True) 41 | out.append(feat) 42 | out = torch.cat(out, 1) 43 | feats_[self.to_layer] = out 44 | 45 | return feats_ 46 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | ENHANCE_MODULES = Registry('enhance_module') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/rectificators/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_rectificator # noqa 401 2 | from .spin import SPIN # noqa 401 3 | from .tps_stn import TPS_STN # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/rectificators/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import RECTIFICATORS 3 | 4 | 5 | def build_rectificator(cfg, default_args=None): 6 | rectificator = build_from_cfg(cfg, RECTIFICATORS, default_args) 7 | 8 | return rectificator 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/rectificators/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | RECTIFICATORS = Registry('Rectificator') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/rectificators/spin.py: -------------------------------------------------------------------------------- 1 | # [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117) # noqa 501 2 | # Not fully implemented yet. SPN has tested successfully. 
3 | import copy 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from vedastr.models.bodies.feature_extractors import build_feature_extractor 10 | from vedastr.models.utils import build_module, build_torch_nn 11 | from vedastr.models.weight_init import init_weights 12 | from .registry import RECTIFICATORS 13 | 14 | 15 | class SPN(nn.Module): 16 | 17 | def __init__(self, cfg): 18 | super(SPN, self).__init__() 19 | self.body = build_feature_extractor(cfg['feature_extractor']) 20 | self.pool = build_torch_nn(cfg['pool']) 21 | heads = [] 22 | for head in cfg['head']: 23 | heads.append(build_module(head)) 24 | self.head = nn.Sequential(*heads) 25 | 26 | def forward(self, x): 27 | batch_size = x.size(0) 28 | x = self.body(x) 29 | x = self.pool(x).view(batch_size, -1) 30 | x = self.head(x) 31 | return x 32 | 33 | 34 | class AIN(nn.Module): 35 | 36 | def __init__(self, cfg): 37 | super(AIN, self).__init__() 38 | self.body = build_feature_extractor(cfg['feature_extractor']) 39 | 40 | def forward(self, x): 41 | x = self.body(x) 42 | 43 | return x 44 | 45 | 46 | @RECTIFICATORS.register_module 47 | class SPIN(nn.Module): 48 | 49 | def __init__(self, spin: dict, k: int): 50 | super(SPIN, self).__init__() 51 | self.body = build_feature_extractor(spin['feature_extractor']) 52 | self.spn = SPN(spin['spn']) 53 | self.betas = generate_beta(k) 54 | init_weights(self.modules()) 55 | 56 | def forward(self, x): 57 | b, c, h, w = x.size() 58 | init_img = copy.copy(x) 59 | # shared parameters 60 | x = self.body(x) 61 | 62 | spn_out = self.spn(x) # 2k+2 63 | omega = spn_out[:, :-1] 64 | g_out = init_img.requires_grad_(True) 65 | 66 | gamma_out = [g_out**beta for beta in self.betas] 67 | gamma_out = torch.stack(gamma_out, axis=1).requires_grad_(True) 68 | 69 | fusion_img = omega[:, :, None, None, None] * gamma_out 70 | fusion_img = torch.sigmoid(fusion_img.sum(dim=1)) 71 | return fusion_img 72 | 73 | 74 | def generate_beta(k): 75 | betas = [] 76 | for i in range(1, k + 2): 77 | p = i / (2 * (k + 1)) 78 | beta = round(np.log(1 - p) / np.log(p), 2) 79 | betas.append(beta) 80 | for i in range(k + 2, 2 * k + 2): 81 | beta = round(1 / betas[(i - (k + 1))], 2) 82 | betas.append(beta) 83 | 84 | return betas 85 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/rectificators/sspin.py: -------------------------------------------------------------------------------- 1 | # We implement a new module which has same property like spin to some extent. 2 | # We think this manner can replace the GA-SPIN by enlarging output features 3 | # of se layer, but we didn't do further experiments. 
4 | 5 | import torch.nn as nn 6 | 7 | from vedastr.models.bodies.feature_extractors import build_feature_extractor 8 | from vedastr.models.utils import SE 9 | from vedastr.models.weight_init import init_weights 10 | from .registry import RECTIFICATORS 11 | 12 | 13 | @RECTIFICATORS.register_module 14 | class SSPIN(nn.Module): 15 | 16 | def __init__(self, feature_cfg, se_cfgs): 17 | super(SSPIN, self).__init__() 18 | self.body = build_feature_extractor(feature_cfg) 19 | self.se = SE(**se_cfgs) 20 | init_weights(self.modules()) 21 | 22 | def forward(self, x): 23 | x = self.body(x) 24 | x = self.se(x) 25 | 26 | return x 27 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | COMPONENT = Registry('component') 4 | BODIES = Registry('body') 5 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_sequence_decoder, build_sequence_encoder # noqa 401 2 | from .rnn import RNN, GRUCell # noqa 401 3 | from .transformer import TransformerEncoder # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import SEQUENCE_DECODERS, SEQUENCE_ENCODERS 3 | 4 | 5 | def build_sequence_encoder(cfg, default_args=None): 6 | sequence_encoder = build_from_cfg(cfg, SEQUENCE_ENCODERS, default_args) 7 | 8 | return sequence_encoder 9 | 10 | 11 | def build_sequence_decoder(cfg, default_args=None): 12 | sequence_encoder = build_from_cfg(cfg, SEQUENCE_DECODERS, default_args) 13 | 14 | return sequence_encoder 15 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | SEQUENCE_ENCODERS = Registry('sequence_encoder') 4 | SEQUENCE_DECODERS = Registry('sequence_decoder') 5 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/rnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import GRUCell, LSTMCell # noqa 401 2 | from .encoder import RNN # noqa 401 3 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/rnn/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from vedastr.models.weight_init import init_weights 5 | from ..registry import SEQUENCE_DECODERS 6 | 7 | 8 | class BaseCell(nn.Module): 9 | 10 | def __init__(self, 11 | basic_cell, 12 | input_size, 13 | hidden_size, 14 | bias=True, 15 | num_layers=1): 16 | super(BaseCell, self).__init__() 17 | 18 | self.input_size = input_size 19 | self.hidden_size = hidden_size 20 | self.bias = bias 21 | self.num_layers = num_layers 22 | 23 | self.cells = nn.ModuleList() 24 | for i in range(num_layers): 25 | if i == 0: 26 | self.cells.append( 
27 | basic_cell( 28 | input_size=input_size, 29 | hidden_size=hidden_size, 30 | bias=bias)) 31 | else: 32 | self.cells.append( 33 | basic_cell( 34 | input_size=hidden_size, 35 | hidden_size=hidden_size, 36 | bias=bias)) 37 | init_weights(self.modules()) 38 | 39 | def init_hidden(self, batch_size, device=None, value=0): 40 | raise NotImplementedError() 41 | 42 | def get_output(self, hiddens): 43 | raise NotImplementedError() 44 | 45 | def get_hidden_state(self, hidden): 46 | raise NotImplementedError() 47 | 48 | def forward(self, x, pre_hiddens): 49 | next_hiddens = [] 50 | 51 | hidden = None 52 | for i, cell in enumerate(self.cells): 53 | if i == 0: 54 | hidden = cell(x, pre_hiddens[i]) 55 | else: 56 | hidden = cell(self.get_hidden_state(hidden), pre_hiddens[i]) 57 | 58 | next_hiddens.append(hidden) 59 | 60 | return next_hiddens 61 | 62 | 63 | @SEQUENCE_DECODERS.register_module 64 | class LSTMCell(BaseCell): 65 | 66 | def __init__(self, input_size, hidden_size, bias=True, num_layers=1): 67 | super(LSTMCell, self).__init__(nn.LSTMCell, input_size, hidden_size, 68 | bias, num_layers) 69 | 70 | def init_hidden(self, batch_size, device=None, value=0): 71 | hiddens = [] 72 | for _ in range(self.num_layers): 73 | hidden = ( 74 | torch.FloatTensor(batch_size, 75 | self.hidden_size).fill_(value).to(device), 76 | torch.FloatTensor(batch_size, 77 | self.hidden_size).fill_(value).to(device), 78 | ) 79 | hiddens.append(hidden) 80 | 81 | return hiddens 82 | 83 | def get_output(self, hiddens): 84 | return hiddens[-1][0] 85 | 86 | def get_hidden_state(self, hidden): 87 | return hidden[0] 88 | 89 | 90 | @SEQUENCE_DECODERS.register_module 91 | class GRUCell(BaseCell): 92 | 93 | def __init__(self, input_size, hidden_size, bias=True, num_layers=1): 94 | super(GRUCell, self).__init__(nn.GRUCell, input_size, hidden_size, 95 | bias, num_layers) 96 | 97 | def init_hidden(self, batch_size, device=None, value=0): 98 | hiddens = [] 99 | for i in range(self.num_layers): 100 | hidden = torch.FloatTensor( 101 | batch_size, self.hidden_size).fill_(value).to(device) 102 | hiddens.append(hidden) 103 | 104 | return hiddens 105 | 106 | def get_output(self, hiddens): 107 | return hiddens[-1] 108 | 109 | def get_hidden_state(self, hidden): 110 | return hidden 111 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/rnn/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from vedastr.models.utils import build_torch_nn 4 | from vedastr.models.weight_init import init_weights 5 | from ..registry import SEQUENCE_ENCODERS 6 | 7 | 8 | @SEQUENCE_ENCODERS.register_module 9 | class RNN(nn.Module): 10 | 11 | def __init__(self, input_pool, layers, keep_order=False): 12 | super(RNN, self).__init__() 13 | self.keep_order = keep_order 14 | 15 | if input_pool: 16 | self.input_pool = build_torch_nn(input_pool) 17 | 18 | self.layers = nn.ModuleList() 19 | for i, (layer_name, layer_cfg) in enumerate(layers): 20 | if layer_name in ['rnn', 'fc']: 21 | self.layers.add_module('{}_{}'.format(i, layer_name), 22 | build_torch_nn(layer_cfg)) 23 | else: 24 | raise ValueError('Unknown layer name {}'.format(layer_name)) 25 | 26 | init_weights(self.modules()) 27 | 28 | @property 29 | def with_input_pool(self): 30 | return hasattr(self, 'input_pool') and self.input_pool 31 | 32 | def forward(self, x): 33 | if self.with_input_pool: 34 | out = self.input_pool(x).squeeze(2) 35 | else: 36 | out = x 37 | # input 
order (B, C, T) -> (B, T, C) 38 | out = out.permute(0, 2, 1) 39 | for layer_name, layer in self.layers.named_children(): 40 | 41 | if 'rnn' in layer_name: 42 | layer.flatten_parameters() 43 | out, _ = layer(out) 44 | else: 45 | out = layer(out) 46 | if not self.keep_order: 47 | out = out.permute(0, 2, 1).unsqueeze(2) 48 | 49 | return out.contiguous() 50 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import TransformerDecoder # noqa 401 2 | from .encoder import TransformerEncoder # noqa 401 3 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/decoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from vedastr.models.weight_init import init_weights 6 | from .position_encoder import build_position_encoder 7 | from .unit import build_decoder_layer 8 | from ..registry import SEQUENCE_DECODERS 9 | 10 | logger = logging.getLogger() 11 | 12 | 13 | @SEQUENCE_DECODERS.register_module 14 | class TransformerDecoder(nn.Module): 15 | 16 | def __init__(self, 17 | decoder_layer: dict, 18 | num_layers: int, 19 | position_encoder: dict = None): 20 | super(TransformerDecoder, self).__init__() 21 | 22 | if position_encoder is not None: 23 | self.pos_encoder = build_position_encoder(position_encoder) 24 | 25 | self.layers = nn.ModuleList( 26 | [build_decoder_layer(decoder_layer) for _ in range(num_layers)]) 27 | 28 | logger.info('TransformerDecoder init weights') 29 | init_weights(self.modules()) 30 | 31 | @property 32 | def with_position_encoder(self): 33 | return hasattr(self, 'pos_encoder') and self.pos_encoder is not None 34 | 35 | def forward(self, tgt, src, tgt_mask=None, src_mask=None): 36 | if self.with_position_encoder: 37 | tgt = self.pos_encoder(tgt) 38 | 39 | for layer in self.layers: 40 | tgt = layer(tgt, src, tgt_mask, src_mask) 41 | 42 | return tgt 43 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/encoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from vedastr.models.weight_init import init_weights 6 | from .position_encoder import build_position_encoder 7 | from .unit import build_encoder_layer 8 | from ..registry import SEQUENCE_ENCODERS 9 | 10 | logger = logging.getLogger() 11 | 12 | 13 | @SEQUENCE_ENCODERS.register_module 14 | class TransformerEncoder(nn.Module): 15 | 16 | def __init__(self, 17 | encoder_layer: dict, 18 | num_layers: int, 19 | position_encoder: dict = None): 20 | super(TransformerEncoder, self).__init__() 21 | 22 | if position_encoder is not None: 23 | self.pos_encoder = build_position_encoder(position_encoder) 24 | 25 | self.layers = nn.ModuleList( 26 | [build_encoder_layer(encoder_layer) for _ in range(num_layers)]) 27 | 28 | logger.info('TransformerEncoder init weights') 29 | init_weights(self.modules()) 30 | 31 | @property 32 | def with_position_encoder(self): 33 | return hasattr(self, 'pos_encoder') and self.pos_encoder is not None 34 | 35 | def forward(self, src, src_mask=None): 36 | if self.with_position_encoder: 37 | src = self.pos_encoder(src) 38 | 39 | for layer in self.layers: 40 | src = 
layer(src, src_mask) 41 | 42 | return src 43 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptive_2d_encoder import Adaptive2DPositionEncoder # noqa 401 2 | from .builder import build_position_encoder # noqa 401 3 | from .encoder import PositionEncoder1D # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/adaptive_2d_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .registry import POSITION_ENCODERS 4 | from .utils import generate_encoder 5 | 6 | 7 | @POSITION_ENCODERS.register_module 8 | class Adaptive2DPositionEncoder(nn.Module): 9 | 10 | def __init__(self, in_channels, max_h=200, max_w=200, dropout=0.1): 11 | super(Adaptive2DPositionEncoder, self).__init__() 12 | 13 | h_position_encoder = generate_encoder(in_channels, max_h) 14 | h_position_encoder = h_position_encoder.transpose(0, 1).view( 15 | 1, in_channels, max_h, 1) 16 | 17 | w_position_encoder = generate_encoder(in_channels, max_w) 18 | w_position_encoder = w_position_encoder.transpose(0, 1).view( 19 | 1, in_channels, 1, max_w) 20 | 21 | self.register_buffer('h_position_encoder', h_position_encoder) 22 | self.register_buffer('w_position_encoder', w_position_encoder) 23 | 24 | self.h_scale = self.scale_factor_generate(in_channels) 25 | self.w_scale = self.scale_factor_generate(in_channels) 26 | self.pool = nn.AdaptiveAvgPool2d(1) 27 | self.dropout = nn.Dropout(p=dropout) 28 | 29 | def scale_factor_generate(self, in_channels): 30 | scale_factor = nn.Sequential( 31 | nn.Conv2d(in_channels, in_channels, kernel_size=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.Sigmoid()) 34 | 35 | return scale_factor 36 | 37 | def forward(self, x): 38 | b, c, h, w = x.size() 39 | 40 | avg_pool = self.pool(x) 41 | 42 | h_pos_encoding = self.h_scale( 43 | avg_pool) * self.h_position_encoder[:, :, :h, :] 44 | w_pos_encoding = self.w_scale( 45 | avg_pool) * self.w_position_encoder[:, :, :, :w] 46 | 47 | out = x + h_pos_encoding + w_pos_encoding 48 | 49 | out = self.dropout(out) 50 | 51 | return out 52 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import POSITION_ENCODERS 3 | 4 | 5 | def build_position_encoder(cfg, default_args=None): 6 | position_encoder = build_from_cfg(cfg, POSITION_ENCODERS, default_args) 7 | 8 | return position_encoder 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .registry import POSITION_ENCODERS 4 | from .utils import generate_encoder 5 | 6 | 7 | @POSITION_ENCODERS.register_module 8 | class PositionEncoder1D(nn.Module): 9 | 10 | def __init__(self, in_channels, max_len=2000, dropout=0.1): 11 | super(PositionEncoder1D, self).__init__() 12 | 13 | position_encoder = 
generate_encoder(in_channels, max_len) 14 | position_encoder = position_encoder.unsqueeze(0) 15 | self.register_buffer('position_encoder', position_encoder) 16 | self.dropout = nn.Dropout(p=dropout) 17 | 18 | def forward(self, x): 19 | out = x + self.position_encoder[:, :x.size(1), :] 20 | out = self.dropout(out) 21 | 22 | return out 23 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | POSITION_ENCODERS = Registry('position_encoder') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/position_encoder/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_encoder(in_channels, max_len): 5 | pos = torch.arange(max_len).float().unsqueeze(1) 6 | 7 | i = torch.arange(in_channels).float().unsqueeze(0) 8 | angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / in_channels) 9 | 10 | position_encoder = pos * angle_rates 11 | position_encoder[:, 0::2] = torch.sin(position_encoder[:, 0::2]) 12 | position_encoder[:, 1::2] = torch.cos(position_encoder[:, 1::2]) 13 | 14 | return position_encoder 15 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_decoder_layer, build_encoder_layer # noqa 401 2 | from .decoder import TransformerDecoderLayer1D # noqa 401 3 | from .encoder import TransformerEncoderLayer1D, TransformerEncoderLayer2D # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_attention # noqa 401 2 | from .multihead_attention import MultiHeadAttention # noqa 401 3 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import TRANSFORMER_ATTENTIONS 3 | 4 | 5 | def build_attention(cfg, default_args=None): 6 | attention = build_from_cfg(cfg, TRANSFORMER_ATTENTIONS, default_args) 7 | 8 | return attention 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .registry import TRANSFORMER_ATTENTIONS 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | 9 | def __init__(self, temperature, dropout=0.1): 10 | super(ScaledDotProductAttention, self).__init__() 11 | 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(p=dropout) 14 | 15 | def forward(self, q, k, v, mask=None): 16 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature 17 | 18 | if mask is not None: 19 | attn = attn.masked_fill(mask=mask, 
value=float('-inf')) 20 | 21 | attn = torch.softmax(attn, dim=-1) 22 | attn = self.dropout(attn) 23 | 24 | out = torch.matmul(attn, v) 25 | 26 | return out, attn 27 | 28 | 29 | @TRANSFORMER_ATTENTIONS.register_module 30 | class MultiHeadAttention(nn.Module): 31 | 32 | def __init__(self, 33 | in_channels: int, 34 | k_channels: int, 35 | v_channels: int, 36 | n_head: int = 8, 37 | dropout: float = 0.1): 38 | super(MultiHeadAttention, self).__init__() 39 | 40 | self.in_channels = in_channels 41 | self.k_channels = k_channels 42 | self.v_channels = v_channels 43 | self.n_head = n_head 44 | 45 | self.q_linear = nn.Linear(in_channels, n_head * k_channels) 46 | self.k_linear = nn.Linear(in_channels, n_head * k_channels) 47 | self.v_linear = nn.Linear(in_channels, n_head * v_channels) 48 | self.attention = ScaledDotProductAttention( 49 | temperature=k_channels**0.5, dropout=dropout) 50 | self.out_linear = nn.Linear(n_head * v_channels, in_channels) 51 | 52 | self.dropout = nn.Dropout(p=dropout) 53 | 54 | def forward(self, q, k, v, mask=None): 55 | b, q_len, k_len, v_len = q.size(0), q.size(1), k.size(1), v.size(1) 56 | 57 | q = self.q_linear(q).view(b, q_len, self.n_head, 58 | self.k_channels).transpose(1, 2) 59 | k = self.k_linear(k).view(b, k_len, self.n_head, 60 | self.k_channels).transpose(1, 2) 61 | v = self.v_linear(v).view(b, v_len, self.n_head, 62 | self.v_channels).transpose(1, 2) 63 | 64 | if mask is not None: 65 | mask = mask.unsqueeze(1) 66 | 67 | out, attn = self.attention(q, k, v, mask=mask) 68 | 69 | out = out.transpose(1, 70 | 2).contiguous().view(b, q_len, 71 | self.n_head * self.v_channels) 72 | out = self.out_linear(out) 73 | out = self.dropout(out) 74 | 75 | return out, attn 76 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/attention/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | TRANSFORMER_ATTENTIONS = Registry('transformer_attention') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import TRANSFORMER_DECODER_LAYERS, TRANSFORMER_ENCODER_LAYERS 3 | 4 | 5 | def build_encoder_layer(cfg, default_args=None): 6 | encoder_layer = build_from_cfg(cfg, TRANSFORMER_ENCODER_LAYERS, 7 | default_args) 8 | 9 | return encoder_layer 10 | 11 | 12 | def build_decoder_layer(cfg, default_args=None): 13 | decoder_layer = build_from_cfg(cfg, TRANSFORMER_DECODER_LAYERS, 14 | default_args) 15 | 16 | return decoder_layer 17 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from vedastr.models.utils import build_torch_nn 4 | from .attention import build_attention 5 | from .feedforward import build_feedforward 6 | from .registry import TRANSFORMER_DECODER_LAYERS 7 | 8 | 9 | @TRANSFORMER_DECODER_LAYERS.register_module 10 | class TransformerDecoderLayer1D(nn.Module): 11 | 12 | def __init__(self, self_attention, self_attention_norm, attention, 13 | attention_norm, feedforward, feedforward_norm): 14 | super(TransformerDecoderLayer1D, 
self).__init__() 15 | 16 | self.self_attention = build_attention(self_attention) 17 | self.self_attention_norm = build_torch_nn(self_attention_norm) 18 | 19 | self.attention = build_attention(attention) 20 | self.attention_norm = build_torch_nn(attention_norm) 21 | 22 | self.feedforward = build_feedforward(feedforward) 23 | self.feedforward_norm = build_torch_nn(feedforward_norm) 24 | 25 | def forward(self, tgt, src, tgt_mask=None, src_mask=None): 26 | attn1, _ = self.self_attention(tgt, tgt, tgt, tgt_mask) 27 | out1 = self.self_attention_norm(tgt + attn1) 28 | 29 | size = src.size() 30 | if len(size) == 4: 31 | b, c, h, w = size 32 | src = src.view(b, c, h * w).transpose(1, 2) 33 | if src_mask is not None: 34 | src_mask = src_mask.view(b, 1, h * w) 35 | 36 | attn2, _ = self.attention(out1, src, src, src_mask) 37 | out2 = self.attention_norm(out1 + attn2) 38 | 39 | ffn_out = self.feedforward(out2) 40 | out3 = self.feedforward_norm(out2 + ffn_out) 41 | 42 | return out3 43 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from vedastr.models.utils import build_torch_nn 4 | from .attention import build_attention 5 | from .feedforward import build_feedforward 6 | from .registry import TRANSFORMER_ENCODER_LAYERS 7 | 8 | 9 | class _TransformerEncoderLayer(nn.Module): 10 | 11 | def __init__(self, attention, attention_norm, feedforward, 12 | feedforward_norm): 13 | super(_TransformerEncoderLayer, self).__init__() 14 | self.attention = build_attention(attention) 15 | self.attention_norm = build_torch_nn(attention_norm) 16 | 17 | self.feedforward = build_feedforward(feedforward) 18 | self.feedforward_norm = build_torch_nn(feedforward_norm) 19 | 20 | 21 | @TRANSFORMER_ENCODER_LAYERS.register_module 22 | class TransformerEncoderLayer1D(_TransformerEncoderLayer): 23 | 24 | def __init__(self, attention, attention_norm, feedforward, 25 | feedforward_norm): 26 | super(TransformerEncoderLayer1D, 27 | self).__init__(attention, attention_norm, feedforward, 28 | feedforward_norm) 29 | 30 | def forward(self, src, src_mask=None): 31 | attn_out, _ = self.attention(src, src, src, src_mask) 32 | out1 = self.attention_norm(src + attn_out) 33 | 34 | ffn_out = self.feedforward(out1) 35 | out2 = self.feedforward_norm(out1 + ffn_out) 36 | 37 | return out2 38 | 39 | 40 | @TRANSFORMER_ENCODER_LAYERS.register_module 41 | class TransformerEncoderLayer2D(_TransformerEncoderLayer): 42 | 43 | def __init__(self, attention, attention_norm, feedforward, 44 | feedforward_norm): 45 | super(TransformerEncoderLayer2D, 46 | self).__init__(attention, attention_norm, feedforward, 47 | feedforward_norm) 48 | 49 | def norm(self, norm_layer, x): 50 | b, c, h, w = x.size() 51 | 52 | if isinstance(norm_layer, nn.LayerNorm): 53 | out = x.view(b, c, h * w).transpose(1, 2) 54 | out = norm_layer(out) 55 | out = out.transpose(1, 2).contiguous().view(b, c, h, w) 56 | else: 57 | out = norm_layer(x) 58 | 59 | return out 60 | 61 | def forward(self, src, src_mask=None): 62 | b, c, h, w = src.size() 63 | 64 | src = src.view(b, c, h * w).transpose(1, 2) 65 | if src_mask is not None: 66 | src_mask = src_mask.view(b, 1, h * w) 67 | 68 | attn_out, _ = self.attention(src, src, src, src_mask) 69 | out1 = src + attn_out 70 | out1 = out1.transpose(1, 2).contiguous().view(b, c, h, w) 71 | out1 = self.norm(self.attention_norm, out1) 72 | 73 | 
ffn_out = self.feedforward(out1) 74 | out2 = self.norm(self.feedforward_norm, out1 + ffn_out) 75 | 76 | return out2 77 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_feedforward # noqa 401 2 | from .feedforward import Feedforward # noqa 401 3 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import TRANSFORMER_FEEDFORWARDS 3 | 4 | 5 | def build_feedforward(cfg, default_args=None): 6 | feedforward = build_from_cfg(cfg, TRANSFORMER_FEEDFORWARDS, default_args) 7 | 8 | return feedforward 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/feedforward.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from vedastr.models.utils import build_module 4 | from .registry import TRANSFORMER_FEEDFORWARDS 5 | 6 | 7 | @TRANSFORMER_FEEDFORWARDS.register_module 8 | class Feedforward(nn.Module): 9 | 10 | def __init__(self, layers: dict): 11 | super(Feedforward, self).__init__() 12 | 13 | self.layers = [build_module(layer) for layer in layers] 14 | self.layers = nn.Sequential(*self.layers) 15 | 16 | def forward(self, x): 17 | out = self.layers(x) 18 | 19 | return out 20 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/feedforward/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | TRANSFORMER_FEEDFORWARDS = Registry('transformer_feedforward') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/bodies/sequences/transformer/unit/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | TRANSFORMER_ENCODER_LAYERS = Registry('transformer_encoder_layer') 4 | TRANSFORMER_DECODER_LAYERS = Registry('transformer_decoder_layer') 5 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/builder.py: -------------------------------------------------------------------------------- 1 | from ..utils import build_from_cfg 2 | from .registry import MODELS 3 | 4 | 5 | def build_model(cfg, default_args=None): 6 | model = build_from_cfg(cfg, MODELS, default_args) 7 | 8 | return model 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .att_head import AttHead # noqa 401 2 | from .builder import build_head # noqa 401 3 | from .conv_head import ConvHead # noqa 401 4 | from .ctc_head import CTCHead # noqa 401 5 | from .fc_head import FCHead # noqa 401 6 | from .head import Head # noqa 401 7 | from .multi_head import MultiHead # noqa 401 8 | from .transformer_head import TransformerHead # noqa 401 9 | 
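[Editor's note] `TransformerEncoderLayer2D` above treats a `(B, C, H, W)` feature map as a length-`H*W` token sequence for multi-head attention and restores the spatial layout before applying a non-LayerNorm norm. The snippet below is only a standalone demonstration of that reshape round-trip with made-up sizes; it is not code from the repository.

```python
# Standalone sketch (hypothetical sizes) of the (B, C, H, W) <-> (B, H*W, C)
# reshaping used by TransformerEncoderLayer2D above.
import torch

b, c, h, w = 2, 8, 4, 16
src = torch.randn(b, c, h, w)

# flatten spatial dims into a token sequence for attention: (B, H*W, C)
tokens = src.view(b, c, h * w).transpose(1, 2)

# ... attention / feedforward would operate on `tokens` here ...

# restore the spatial layout afterwards: (B, C, H, W)
restored = tokens.transpose(1, 2).contiguous().view(b, c, h, w)
assert torch.equal(src, restored)
```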
-------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/att_head.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from vedastr.models.bodies import build_brick, build_sequence_decoder 7 | from vedastr.models.utils import build_torch_nn 8 | from vedastr.models.weight_init import init_weights 9 | from .registry import HEADS 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | @HEADS.register_module 15 | class AttHead(nn.Module): 16 | 17 | def __init__(self, 18 | cell, 19 | generator, 20 | num_steps, 21 | num_class, 22 | input_attention_block=None, 23 | output_attention_block=None, 24 | text_transform=None, 25 | holistic_input_from=None): 26 | super(AttHead, self).__init__() 27 | 28 | if input_attention_block is not None: 29 | self.input_attention_block = build_brick(input_attention_block) 30 | 31 | self.cell = build_sequence_decoder(cell) 32 | self.generator = build_torch_nn(generator) 33 | self.num_steps = num_steps 34 | self.num_class = num_class 35 | 36 | if output_attention_block is not None: 37 | self.output_attention_block = build_brick(output_attention_block) 38 | 39 | if text_transform is not None: 40 | self.text_transform = build_torch_nn(text_transform) 41 | 42 | if holistic_input_from is not None: 43 | self.holistic_input_from = holistic_input_from 44 | 45 | self.register_buffer('embeddings', 46 | torch.diag(torch.ones(self.num_class))) 47 | logger.info('AttHead init weights') 48 | init_weights(self.modules()) 49 | 50 | @property 51 | def with_holistic_input(self): 52 | return hasattr(self, 53 | 'holistic_input_from') and self.holistic_input_from 54 | 55 | @property 56 | def with_input_attention(self): 57 | return hasattr( 58 | self, 59 | 'input_attention_block') and self.input_attention_block is not None 60 | 61 | @property 62 | def with_output_attention(self): 63 | return hasattr(self, 'output_attention_block' 64 | ) and self.output_attention_block is not None 65 | 66 | @property 67 | def with_text_transform(self): 68 | return hasattr(self, 'text_transform') and self.text_transform 69 | 70 | def forward(self, feats, texts): 71 | batch_size = texts.size(0) 72 | 73 | hidden = self.cell.init_hidden(batch_size, device=texts.device) 74 | if self.with_holistic_input: 75 | holistic_input = feats[self.holistic_input_from][:, :, 0, -1] 76 | hidden = self.cell(holistic_input, hidden) 77 | 78 | out = [] 79 | 80 | if self.training: 81 | use_gt = True 82 | assert self.num_steps == texts.size(1) 83 | else: 84 | use_gt = False 85 | assert texts.size(1) == 1 86 | 87 | for i in range(self.num_steps): 88 | if i == 0: 89 | indexes = texts[:, i] 90 | else: 91 | if use_gt: 92 | indexes = texts[:, i] 93 | else: 94 | _, indexes = out[-1].max(1) 95 | text_feat = self.embeddings.index_select(0, indexes) 96 | 97 | if self.with_text_transform: 98 | text_feat = self.text_transform(text_feat) 99 | 100 | if self.with_input_attention: 101 | attention_feat = self.input_attention_block( 102 | feats, 103 | self.cell.get_output(hidden).unsqueeze(-1).unsqueeze(-1)) 104 | cell_input = torch.cat([attention_feat, text_feat], dim=1) 105 | else: 106 | cell_input = text_feat 107 | hidden = self.cell(cell_input, hidden) 108 | out_feat = self.cell.get_output(hidden) 109 | 110 | if self.with_output_attention: 111 | attention_feat = self.output_attention_block( 112 | feats, 113 | self.cell.get_output(hidden).unsqueeze(-1).unsqueeze(-1)) 114 | out_feat = 
torch.cat( 115 | [self.cell.get_output(hidden), attention_feat], dim=1) 116 | 117 | out.append(self.generator(out_feat)) 118 | 119 | out = torch.stack(out, dim=1) 120 | 121 | return out 122 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/builder.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import build_from_cfg 2 | from .registry import HEADS 3 | 4 | 5 | def build_head(cfg, default_args=None): 6 | head = build_from_cfg(cfg, HEADS, default_args) 7 | return head 8 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/conv_head.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from vedastr.models.utils import ConvModules 6 | from vedastr.models.weight_init import init_weights 7 | from .registry import HEADS 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | @HEADS.register_module 13 | class ConvHead(nn.Module): 14 | """FCHead 15 | 16 | Args: 17 | """ 18 | 19 | def __init__(self, 20 | in_channels, 21 | num_class, 22 | from_layer, 23 | num_convs=0, 24 | inner_channels=None, 25 | kernel_size=1, 26 | padding=0, 27 | **kwargs): 28 | super(ConvHead, self).__init__() 29 | 30 | self.from_layer = from_layer 31 | 32 | self.conv = [] 33 | if num_convs > 0: 34 | out_channels = inner_channels 35 | self.conv.append( 36 | ConvModules( 37 | in_channels, 38 | out_channels, 39 | num_convs=num_convs, 40 | kernel_size=kernel_size, 41 | padding=padding, 42 | **kwargs)) 43 | else: 44 | out_channels = in_channels 45 | self.conv.append( 46 | nn.Conv2d( 47 | out_channels, 48 | num_class, 49 | kernel_size=kernel_size, 50 | padding=padding)) 51 | self.conv = nn.Sequential(*self.conv) 52 | 53 | logger.info('ConvHead init weights') 54 | init_weights(self.modules()) 55 | 56 | def forward(self, x_input): 57 | x = x_input[self.from_layer] 58 | assert x.size(2) == 1 59 | 60 | out = self.conv(x).mean(2).permute(0, 2, 1) 61 | 62 | return out 63 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/ctc_head.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from vedastr.models.utils import build_torch_nn 6 | from vedastr.models.weight_init import init_weights 7 | from .registry import HEADS 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | @HEADS.register_module 13 | class CTCHead(nn.Module): 14 | """CTCHead 15 | 16 | """ 17 | 18 | def __init__( 19 | self, 20 | in_channels, 21 | num_class, 22 | from_layer, 23 | pool=None, 24 | export=False, 25 | ): 26 | super(CTCHead, self).__init__() 27 | 28 | self.num_class = num_class 29 | self.from_layer = from_layer 30 | fc = nn.Linear(in_channels, num_class) 31 | if pool is not None: 32 | self.pool = build_torch_nn(pool) 33 | self.fc = fc 34 | self.export = export 35 | 36 | logger.info('CTCHead init weights') 37 | init_weights(self.modules()) 38 | 39 | @property 40 | def with_pool(self): 41 | return hasattr(self, 'pool') and self.pool is not None 42 | 43 | def forward(self, x_input): 44 | x = x_input[self.from_layer] 45 | if self.export: 46 | x = x.mean(2).permute(0, 2, 1) 47 | elif self.with_pool: 48 | x = self.pool(x).permute(0, 3, 1, 2).squeeze(3) 49 | out = self.fc(x) 50 | 51 | return out 52 | 
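[Editor's note] `CTCHead` above emits per-timestep class logits of shape `(batch, T, num_class)` after collapsing the height axis. Such logits are normally consumed by a CTC criterion (the repo ships a `ctc_loss` wrapper under `vedastr/criteria`, not shown in this section). The sketch below only illustrates, with assumed sizes and blank index, how logits in that layout would feed PyTorch's `nn.CTCLoss`, which expects log-probabilities shaped `(T, batch, num_class)`; it is not the repo's loss code.

```python
# Illustrative only: feeding CTCHead-style logits (batch, T, num_class) into
# torch.nn.CTCLoss. Sizes and the blank index are assumptions, not taken from
# the repo's configs or its ctc_loss wrapper.
import torch
import torch.nn as nn

batch, time_steps, num_class = 2, 26, 37   # hypothetical; index 0 = CTC blank
logits = torch.randn(batch, time_steps, num_class)

# nn.CTCLoss wants log-probabilities shaped (T, batch, num_class)
log_probs = logits.log_softmax(dim=2).permute(1, 0, 2)

targets = torch.randint(1, num_class, (batch, 7), dtype=torch.long)
input_lengths = torch.full((batch,), time_steps, dtype=torch.long)
target_lengths = torch.full((batch,), 7, dtype=torch.long)

criterion = nn.CTCLoss(blank=0, zero_infinity=True)
loss = criterion(log_probs, targets, input_lengths, target_lengths)
```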
-------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/fc_head.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from vedastr.models.utils import FCModules, build_torch_nn 6 | from vedastr.models.weight_init import init_weights 7 | from .registry import HEADS 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | @HEADS.register_module 13 | class FCHead(nn.Module): 14 | """FCHead 15 | 16 | Args: 17 | """ 18 | 19 | def __init__(self, 20 | in_channels, 21 | out_channels, 22 | num_class, 23 | batch_max_length, 24 | from_layer, 25 | inner_channels=None, 26 | bias=True, 27 | activation='relu', 28 | inplace=True, 29 | dropouts=None, 30 | num_fcs=0, 31 | pool=None): 32 | super(FCHead, self).__init__() 33 | 34 | self.num_class = num_class 35 | self.batch_max_length = batch_max_length 36 | self.from_layer = from_layer 37 | 38 | if num_fcs > 0: 39 | inter_fc = FCModules(in_channels, inner_channels, bias, activation, 40 | inplace, dropouts, num_fcs) 41 | fc = nn.Linear(inner_channels, out_channels) 42 | else: 43 | inter_fc = nn.Sequential() 44 | fc = nn.Linear(in_channels, out_channels) 45 | 46 | if pool is not None: 47 | self.pool = build_torch_nn(pool) 48 | 49 | self.inter_fc = inter_fc 50 | self.fc = fc 51 | 52 | logger.info('FCHead init weights') 53 | init_weights(self.modules()) 54 | 55 | @property 56 | def with_pool(self): 57 | return hasattr(self, 'pool') and self.pool is not None 58 | 59 | def forward(self, x_input): 60 | x = x_input[self.from_layer] 61 | batch_size = x.size(0) 62 | 63 | if self.with_pool: 64 | x = self.pool(x) 65 | 66 | x = x.contiguous().view(batch_size, -1) 67 | 68 | out = self.inter_fc(x) 69 | out = self.fc(out) 70 | 71 | return out.reshape(-1, self.batch_max_length + 1, self.num_class) 72 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/head.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from vedastr.models.utils import build_module 6 | from vedastr.models.weight_init import init_weights 7 | from .registry import HEADS 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | @HEADS.register_module 13 | class Head(nn.Module): 14 | """Head 15 | 16 | Args: 17 | """ 18 | 19 | def __init__( 20 | self, 21 | from_layer, 22 | generator, 23 | ): 24 | super(Head, self).__init__() 25 | 26 | self.from_layer = from_layer 27 | self.generator = build_module(generator) 28 | 29 | logger.info('Head init weights') 30 | init_weights(self.modules()) 31 | 32 | def forward(self, feats): 33 | x = feats[self.from_layer] 34 | out = self.generator(x) 35 | 36 | return out 37 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/multi_head.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from vedastr.models.utils import FCModules, build_module, build_torch_nn 7 | from vedastr.models.weight_init import init_weights 8 | from .registry import HEADS 9 | 10 | logger = logging.getLogger() 11 | 12 | 13 | @HEADS.register_module 14 | class MultiHead(nn.Module): 15 | """MultiHead 16 | 17 | Args: 18 | """ 19 | 20 | def __init__(self, 21 | in_channels, 22 | num_class, 23 | batch_max_length, 24 | from_layer, 25 | inners=None, 26 | 
skip_connections=None, 27 | inner_channels=None, 28 | bias=True, 29 | activation='relu', 30 | inplace=True, 31 | dropouts=None, 32 | embedding=False, 33 | num_fcs=0, 34 | pool=None): 35 | super(MultiHead, self).__init__() 36 | 37 | self.num_class = num_class 38 | self.batch_max_length = batch_max_length 39 | self.embedding = embedding 40 | self.from_layer = from_layer 41 | 42 | if inners is not None: 43 | self.inners = [] 44 | for inner_cfg in inners: 45 | self.inners.append(build_module(inner_cfg)) 46 | self.inners = nn.Sequential(*self.inners) 47 | 48 | if skip_connections is not None: 49 | self.skip_layer = [] 50 | for skip_cfg in skip_connections: 51 | self.skip_layer.append(build_module(skip_cfg)) 52 | self.skip_layer = nn.Sequential(*self.skip_layer) 53 | 54 | self.fcs = nn.ModuleList() 55 | for i in range(batch_max_length + 1): 56 | if num_fcs > 0: 57 | inter_fc = FCModules(in_channels, inner_channels, bias, 58 | activation, inplace, dropouts, num_fcs) 59 | else: 60 | inter_fc = nn.Sequential() 61 | fc = nn.Linear(in_channels, num_class) 62 | self.fcs.append(nn.Sequential(inter_fc, fc)) 63 | 64 | if pool is not None: 65 | self.pool = build_torch_nn(pool) 66 | 67 | logger.info('MultiHead init weights') 68 | init_weights(self.modules()) 69 | 70 | @property 71 | def with_pool(self): 72 | return hasattr(self, 'pool') and self.pool is not None 73 | 74 | @property 75 | def with_inners(self): 76 | return hasattr(self, 'inners') and self.inners is not None 77 | 78 | @property 79 | def with_skip_layer(self): 80 | return hasattr(self, 'skip_layer') and self.skip_layer is not None 81 | 82 | def forward(self, x_input): 83 | x = x_input[self.from_layer] 84 | batch_size = x.size(0) 85 | 86 | if self.with_pool: 87 | x = self.pool(x) 88 | if self.with_inners: 89 | inner_x = self.inners(x) 90 | if self.with_skip_layer: 91 | short_x = self.skip_layer(x) 92 | x = inner_x + short_x 93 | else: 94 | x = inner_x + x 95 | x = x.contiguous().view(batch_size, -1) 96 | else: 97 | x = x.squeeze() 98 | x = torch.split(x.squeeze(), (1, ) * x.size(2), dim=2) 99 | 100 | outs = [] 101 | for idx, layer in enumerate(self.fcs): 102 | out = layer(x) if not isinstance(x, tuple) else layer( 103 | x[idx].squeeze()) 104 | out = out.unsqueeze(1) 105 | outs.append(out) 106 | outs = torch.cat(outs, dim=1) 107 | 108 | return outs 109 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils.registry import Registry 2 | 3 | HEADS = Registry('head') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/heads/transformer_head.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from vedastr.models.bodies import build_sequence_decoder 8 | from vedastr.models.utils import build_torch_nn 9 | from vedastr.models.weight_init import init_weights 10 | from .registry import HEADS 11 | 12 | logger = logging.getLogger() 13 | 14 | 15 | @HEADS.register_module 16 | class TransformerHead(nn.Module): 17 | 18 | def __init__( 19 | self, 20 | decoder, 21 | generator, 22 | embedding, 23 | num_steps, 24 | pad_id, 25 | src_from, 26 | src_mask_from=None, 27 | ): 28 | super(TransformerHead, self).__init__() 29 | 30 | self.decoder = build_sequence_decoder(decoder) 31 | self.generator = 
build_torch_nn(generator) 32 | self.embedding = build_torch_nn(embedding) 33 | self.num_steps = num_steps 34 | self.pad_id = pad_id 35 | self.src_from = src_from 36 | self.src_mask_from = src_mask_from 37 | 38 | logger.info('TransformerHead init weights') 39 | init_weights(self.modules()) 40 | 41 | def pad_mask(self, text): 42 | pad_mask = (text == self.pad_id) 43 | pad_mask[:, 0] = False 44 | pad_mask = pad_mask.unsqueeze(1) 45 | 46 | return pad_mask 47 | 48 | def order_mask(self, text): 49 | t = text.size(1) 50 | order_mask = torch.triu(torch.ones(t, t), diagonal=1).bool() 51 | order_mask = order_mask.unsqueeze(0).to(text.device) 52 | 53 | return order_mask 54 | 55 | def text_embedding(self, texts): 56 | tgt = self.embedding(texts) 57 | tgt *= math.sqrt(tgt.size(2)) 58 | 59 | return tgt 60 | 61 | def forward(self, feats, texts): 62 | src = feats[self.src_from] 63 | if self.src_mask_from: 64 | src_mask = feats[self.src_mask_from] 65 | else: 66 | src_mask = None 67 | 68 | if self.training: 69 | tgt = self.text_embedding(texts) 70 | tgt_mask = (self.pad_mask(texts) | self.order_mask(texts)) 71 | 72 | out = self.decoder(tgt, src, tgt_mask, src_mask) 73 | out = self.generator(out) 74 | else: 75 | out = None 76 | for _ in range(self.num_steps): 77 | tgt = self.text_embedding(texts) 78 | tgt_mask = self.order_mask(texts) 79 | out = self.decoder(tgt, src, tgt_mask, src_mask) 80 | out = self.generator(out) 81 | next_text = torch.argmax(out[:, -1:, :], dim=-1) 82 | 83 | texts = torch.cat([texts, next_text], dim=-1) 84 | 85 | return out 86 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .bodies import build_body 4 | from .heads import build_head 5 | from .registry import MODELS 6 | 7 | 8 | @MODELS.register_module 9 | class GModel(nn.Module): 10 | 11 | def __init__(self, body, head, need_text=True): 12 | super(GModel, self).__init__() 13 | 14 | self.body = build_body(body) 15 | self.head = build_head(head) 16 | self.need_text = need_text 17 | 18 | def forward(self, inputs): 19 | if not isinstance(inputs, (tuple, list)): 20 | inputs = [inputs] 21 | x = self.body(inputs[0]) 22 | if self.need_text: 23 | out = self.head(x, inputs[1]) 24 | else: 25 | out = self.head(x) 26 | 27 | return out 28 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/registry.py: -------------------------------------------------------------------------------- 1 | from ..utils import Registry 2 | 3 | MODELS = Registry('model') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_module, build_torch_nn # noqa 401 2 | from .cbam import CBAM # noqa 401 3 | from .conv_module import ConvModule, ConvModules # noqa 401 4 | from .fc_module import FCModule, FCModules # noqa 401 5 | from .non_local import NonLocal2d # noqa 401 6 | from .norm import build_norm_layer # noqa 401 7 | from .residual_module import BasicBlock, Bottleneck # noqa 401 8 | from .upsample import Upsample # noqa 401 9 | from .squeeze_excitation_module import SE # noqa 401 10 | 11 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/builder.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from vedastr.utils import build_from_cfg 4 | from .registry import UTILS 5 | 6 | 7 | def build_module(cfg, default_args=None): 8 | util = build_from_cfg(cfg, UTILS, default_args) 9 | return util 10 | 11 | 12 | def build_torch_nn(cfg, default_args=None): 13 | module = build_from_cfg(cfg, nn, default_args, 'module') 14 | return module 15 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/cbam.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .conv_module import ConvModule 8 | from .registry import UTILS 9 | 10 | 11 | class Flatten(nn.Module): 12 | 13 | def forward(self, x): 14 | return x.view(x.size(0), -1) 15 | 16 | 17 | class ChannelGate(nn.Module): 18 | 19 | def __init__(self, 20 | gate_channels, 21 | reduction_ratio=16, 22 | pool_types=['avg', 'max']): 23 | super(ChannelGate, self).__init__() 24 | self.gate_channels = gate_channels 25 | self.mlp = nn.Sequential( 26 | Flatten(), 27 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 28 | nn.ReLU(), 29 | nn.Linear(gate_channels // reduction_ratio, gate_channels)) 30 | self.pool_types = pool_types 31 | 32 | def forward(self, x): 33 | channel_att_sum = None 34 | for pool_type in self.pool_types: 35 | if pool_type == 'avg': 36 | avg_pool = F.avg_pool2d( 37 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 38 | channel_att_raw = self.mlp(avg_pool) 39 | elif pool_type == 'max': 40 | max_pool = F.max_pool2d( 41 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 42 | channel_att_raw = self.mlp(max_pool) 43 | elif pool_type == 'lp': 44 | lp_pool = F.lp_pool2d( 45 | x, 46 | 2, (x.size(2), x.size(3)), 47 | stride=(x.size(2), x.size(3))) 48 | channel_att_raw = self.mlp(lp_pool) 49 | elif pool_type == 'lse': 50 | # LSE pool only 51 | lse_pool = logsumexp_2d(x) 52 | channel_att_raw = self.mlp(lse_pool) 53 | 54 | if channel_att_sum is None: 55 | channel_att_sum = channel_att_raw 56 | else: 57 | channel_att_sum = channel_att_sum + channel_att_raw 58 | 59 | scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as( 60 | x) 61 | return x * scale 62 | 63 | 64 | def logsumexp_2d(tensor): 65 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 66 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 67 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 68 | return outputs 69 | 70 | 71 | class ChannelPool(nn.Module): 72 | 73 | def forward(self, x): 74 | return torch.cat( 75 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), 76 | dim=1) 77 | 78 | 79 | class SpatialGate(nn.Module): 80 | 81 | def __init__(self, norm_cfg=None): 82 | super(SpatialGate, self).__init__() 83 | kernel_size = 7 84 | self.compress = ChannelPool() 85 | self.spatial = ConvModule( 86 | 2, 87 | 1, 88 | kernel_size, 89 | stride=1, 90 | padding=(kernel_size - 1) // 2, 91 | activation=None, 92 | norm_cfg=norm_cfg) 93 | 94 | def forward(self, x): 95 | x_compress = self.compress(x) 96 | x_out = self.spatial(x_compress) 97 | scale = F.sigmoid(x_out) # broadcasting 98 | return x * scale 99 | 100 | 101 | @UTILS.register_module 102 | class CBAM(nn.Module): 103 | 104 | def __init__( 105 | self, 106 | gate_channels, 
107 | reduction_ratio=16, 108 | pool_types=['avg', 'max'], 109 | no_spatial=False, 110 | norm_cfg=None, 111 | ): 112 | # TODO, default CBAM BN args 113 | # out_planes, eps=1e-5, momentum=0.01, affine=True 114 | super(CBAM, self).__init__() 115 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, 116 | pool_types) 117 | self.no_spatial = no_spatial 118 | if not no_spatial: 119 | self.SpatialGate = SpatialGate(norm_cfg) 120 | 121 | def forward(self, x): 122 | x_out = self.ChannelGate(x) 123 | if not self.no_spatial: 124 | x_out = self.SpatialGate(x_out) 125 | return x_out 126 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/fc_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .registry import UTILS 4 | 5 | 6 | @UTILS.register_module 7 | class FCModule(nn.Module): 8 | """FCModule 9 | 10 | Args: 11 | """ 12 | 13 | def __init__(self, 14 | in_channels, 15 | out_channels, 16 | bias=True, 17 | activation='relu', 18 | inplace=True, 19 | dropout=None, 20 | order=('fc', 'act')): 21 | super(FCModule, self).__init__() 22 | self.order = order 23 | self.activation = activation 24 | self.inplace = inplace 25 | 26 | self.with_activatation = activation is not None 27 | self.with_dropout = dropout is not None 28 | 29 | self.fc = nn.Linear(in_channels, out_channels, bias) 30 | 31 | # build activation layer 32 | if self.with_activatation: 33 | # TODO: introduce `activation` and supports more activation layers 34 | if self.activation not in ['relu', 'tanh', 'sigmoid']: 35 | raise ValueError('{} is currently not supported.'.format( 36 | self.activation)) 37 | if self.activation == 'relu': 38 | self.activate = nn.ReLU(inplace=inplace) 39 | elif self.activation == 'tanh': 40 | self.activate = nn.Tanh() 41 | elif self.activation == 'sigmoid': 42 | self.activate = nn.Sigmoid() 43 | 44 | if self.with_dropout: 45 | self.dropout = nn.Dropout(p=dropout) 46 | 47 | def forward(self, x): 48 | if self.order == ('fc', 'act'): 49 | x = self.fc(x) 50 | 51 | if self.with_activatation: 52 | x = self.activate(x) 53 | elif self.order == ('act', 'fc'): 54 | if self.with_activatation: 55 | x = self.activate(x) 56 | x = self.fc(x) 57 | 58 | if self.with_dropout: 59 | x = self.dropout(x) 60 | 61 | return x 62 | 63 | 64 | @UTILS.register_module 65 | class FCModules(nn.Module): 66 | """FCModules 67 | 68 | Args: 69 | """ 70 | 71 | def __init__(self, 72 | in_channels, 73 | out_channels, 74 | bias=True, 75 | activation='relu', 76 | inplace=True, 77 | dropouts=None, 78 | num_fcs=1): 79 | super().__init__() 80 | 81 | if dropouts is not None: 82 | assert num_fcs == len(dropouts) 83 | dropout = dropouts[0] 84 | else: 85 | dropout = None 86 | 87 | layers = [ 88 | FCModule(in_channels, out_channels, bias, activation, inplace, 89 | dropout) 90 | ] 91 | for ii in range(1, num_fcs): 92 | if dropouts is not None: 93 | dropout = dropouts[ii] 94 | else: 95 | dropout = None 96 | layers.append( 97 | FCModule(out_channels, out_channels, bias, activation, inplace, 98 | dropout)) 99 | 100 | self.block = nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | feat = self.block(x) 104 | 105 | return feat 106 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/norm.py: -------------------------------------------------------------------------------- 1 | # modify from mmcv and mmdetection 2 | 3 | import torch.nn as nn 4 | 5 | 
norm_cfg = { 6 | # format: layer_type: (abbreviation, module) 7 | 'BN': ('bn', nn.BatchNorm2d), 8 | 'SyncBN': ('bn', nn.SyncBatchNorm), 9 | 'GN': ('gn', nn.GroupNorm), 10 | # and potentially 'SN' 11 | } 12 | 13 | 14 | def build_norm_layer(cfg, num_features, postfix='', layer_only=False): 15 | """ Build normalization layer 16 | 17 | Args: 18 | cfg (dict): cfg should contain: 19 | type (str): identify norm layer type. 20 | layer args: args needed to instantiate a norm layer. 21 | requires_grad (bool): [optional] whether stop gradient updates 22 | num_features (int): number of channels from input. 23 | postfix (int, str): appended into norm abbreviation to 24 | create named layer. 25 | 26 | Returns: 27 | name (str): abbreviation + postfix 28 | layer (nn.Module): created norm layer 29 | """ 30 | assert isinstance(cfg, dict) and 'type' in cfg 31 | cfg_ = cfg.copy() 32 | 33 | layer_type = cfg_.pop('type') 34 | if layer_type not in norm_cfg: 35 | raise KeyError('Unrecognized norm type {}'.format(layer_type)) 36 | else: 37 | abbr, norm_layer = norm_cfg[layer_type] 38 | if norm_layer is None: 39 | raise NotImplementedError 40 | 41 | assert isinstance(postfix, (int, str)) 42 | name = abbr + str(postfix) 43 | 44 | requires_grad = cfg_.pop('requires_grad', True) 45 | if layer_type != 'GN': 46 | layer = norm_layer(num_features, **cfg_) 47 | if layer_type == 'SyncBN': 48 | layer._specify_ddp_gpu_num(1) 49 | else: 50 | assert 'num_groups' in cfg_ 51 | layer = norm_layer(num_channels=num_features, **cfg_) 52 | 53 | for param in layer.parameters(): 54 | param.requires_grad = requires_grad 55 | 56 | if layer_only: 57 | return layer 58 | 59 | return name, layer 60 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | UTILS = Registry('utils') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/squeeze_excitation_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .fc_module import FCModule 4 | from .registry import UTILS 5 | 6 | 7 | @UTILS.register_module 8 | class SE(nn.Module): 9 | 10 | def __init__(self, channel, reduction): 11 | # TODO, input channel should has same name with other modules 12 | 13 | super(SE, self).__init__() 14 | assert channel % reduction == 0, \ 15 | "Input_channel can't be evenly divided by reduction." 
16 | 17 | self.pool = nn.AdaptiveAvgPool2d(1) 18 | self.layer = nn.Sequential( 19 | FCModule(channel, channel // reduction, bias=False), 20 | FCModule( 21 | channel // reduction, 22 | channel, 23 | bias=False, 24 | activation='sigmoid'), 25 | ) 26 | 27 | def forward(self, x): 28 | b, c, _, _ = x.size() 29 | y = self.pool(x).view(b, c) 30 | y = self.layer(y).view(b, c, 1, 1) 31 | 32 | return x * y.expand_as(x) 33 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/utils/upsample.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .registry import UTILS 5 | 6 | 7 | @UTILS.register_module 8 | class Upsample(nn.Module): 9 | __constants__ = [ 10 | 'size', 'scale_factor', 'scale_bias', 'mode', 'align_corners', 'name' 11 | ] 12 | 13 | def __init__(self, 14 | size=None, 15 | scale_factor=None, 16 | scale_bias=0, 17 | mode='nearest', 18 | align_corners=None): 19 | super(Upsample, self).__init__() 20 | self.size = size 21 | self.scale_factor = scale_factor 22 | self.scale_bias = scale_bias 23 | self.mode = mode 24 | self.align_corners = align_corners 25 | 26 | assert (self.size is None) ^ (self.scale_factor is None) 27 | 28 | def forward(self, x): 29 | if self.size: 30 | size = self.size 31 | else: 32 | n, c, h, w = x.size() 33 | new_h = int(h * self.scale_factor + self.scale_bias) 34 | new_w = int(w * self.scale_factor + self.scale_bias) 35 | 36 | size = (new_h, new_w) 37 | 38 | return F.interpolate( 39 | x, size=size, mode=self.mode, align_corners=self.align_corners) 40 | 41 | def extra_repr(self): 42 | if self.size is not None: 43 | info = 'size=' + str(self.size) 44 | else: 45 | info = 'scale_factor=' + str(self.scale_factor) 46 | info += ', scale_bias=' + str(self.scale_bias) 47 | info += ', mode=' + self.mode 48 | return info 49 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/models/weight_init.py: -------------------------------------------------------------------------------- 1 | # modify from mmcv and mmdetection 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def constant_init(module, val, bias=0): 7 | nn.init.constant_(module.weight, val) 8 | if hasattr(module, 'bias') and module.bias is not None: 9 | nn.init.constant_(module.bias, bias) 10 | 11 | 12 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 13 | assert distribution in ['uniform', 'normal'] 14 | if distribution == 'uniform': 15 | nn.init.xavier_uniform_(module.weight, gain=gain) 16 | else: 17 | nn.init.xavier_normal_(module.weight, gain=gain) 18 | if hasattr(module, 'bias') and module.bias is not None: 19 | nn.init.constant_(module.bias, bias) 20 | 21 | 22 | def normal_init(module, mean=0, std=1, bias=0): 23 | nn.init.normal_(module.weight, mean, std) 24 | if hasattr(module, 'bias') and module.bias is not None: 25 | nn.init.constant_(module.bias, bias) 26 | 27 | 28 | def uniform_init(module, a=0, b=1, bias=0): 29 | nn.init.uniform_(module.weight, a, b) 30 | if hasattr(module, 'bias') and module.bias is not None: 31 | nn.init.constant_(module.bias, bias) 32 | 33 | 34 | def kaiming_init(module, 35 | a=0, 36 | is_rnn=False, 37 | mode='fan_in', 38 | nonlinearity='leaky_relu', 39 | bias=0, 40 | distribution='normal'): 41 | assert distribution in ['uniform', 'normal'] 42 | if distribution == 'uniform': 43 | if is_rnn: 44 | for name, param in module.named_parameters(): 45 | if 'bias' in name: 46 | 
nn.init.constant_(param, bias) 47 | elif 'weight' in name: 48 | nn.init.kaiming_uniform_( 49 | param, a=a, mode=mode, nonlinearity=nonlinearity) 50 | else: 51 | nn.init.kaiming_uniform_( 52 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 53 | 54 | else: 55 | if is_rnn: 56 | for name, param in module.named_parameters(): 57 | if 'bias' in name: 58 | nn.init.constant_(param, bias) 59 | elif 'weight' in name: 60 | nn.init.kaiming_normal_( 61 | param, a=a, mode=mode, nonlinearity=nonlinearity) 62 | else: 63 | nn.init.kaiming_normal_( 64 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 65 | 66 | if not is_rnn and hasattr(module, 'bias') and module.bias is not None: 67 | nn.init.constant_(module.bias, bias) 68 | 69 | 70 | def caffe2_xavier_init(module, bias=0): 71 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch 72 | # Acknowledgment to FAIR's internal code 73 | kaiming_init( 74 | module, 75 | a=1, 76 | mode='fan_in', 77 | nonlinearity='leaky_relu', 78 | distribution='uniform') 79 | 80 | 81 | def init_weights(modules): 82 | for m in modules: 83 | if isinstance(m, nn.Conv2d): 84 | kaiming_init(m) 85 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 86 | constant_init(m, 1) 87 | elif isinstance(m, nn.Linear): 88 | xavier_init(m) 89 | elif isinstance(m, (nn.LSTM, nn.LSTMCell)): 90 | kaiming_init(m, is_rnn=True) 91 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_optimizer # noqa 401 2 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/optimizers/builder.py: -------------------------------------------------------------------------------- 1 | import torch.optim as torch_optim 2 | 3 | from vedastr.utils import build_from_cfg 4 | 5 | 6 | def build_optimizer(cfg, default_args=None): 7 | optim = build_from_cfg(cfg, torch_optim, default_args, 'module') 8 | 9 | return optim 10 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/runners/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference_runner import InferenceRunner # noqa 401 2 | from .test_runner import TestRunner # noqa 401 3 | from .train_runner import TrainRunner # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/runners/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch.backends import cudnn 6 | 7 | from ..dataloaders import build_dataloader 8 | from ..dataloaders.samplers import build_sampler 9 | from ..datasets import build_datasets 10 | from ..logger import build_logger 11 | from ..metrics import build_metric 12 | from ..transforms import build_transform 13 | from ..utils import get_dist_info, init_dist_pytorch 14 | 15 | 16 | class Common(object): 17 | 18 | def __init__(self, cfg): 19 | super(Common, self).__init__() 20 | 21 | # build logger 22 | logger_cfg = cfg.get('logger') 23 | if logger_cfg is None: 24 | logger_cfg = dict( 25 | handlers=(dict(type='StreamHandler', level='INFO'), 26 | ), 27 | ) 28 | self.workdir = cfg.get('workdir') 29 | self.distribute = cfg.get('distribute', False) 30 | 31 | # set gpu devices 32 | self.use_gpu = self._set_device() 33 | 34 | # 
set distribute setting 35 | if self.distribute and self.use_gpu: 36 | init_dist_pytorch(**cfg.dist_params) 37 | 38 | self.rank, self.world_size = get_dist_info() 39 | self.logger = self._build_logger(logger_cfg) 40 | 41 | # set cudnn configuration 42 | self._set_cudnn( 43 | cfg.get('cudnn_deterministic', False), 44 | cfg.get('cudnn_benchmark', False)) 45 | 46 | # set seed 47 | self._set_seed(cfg.get('seed', None)) 48 | self.seed = cfg.get('seed', None) 49 | 50 | # build metric 51 | if 'metric' in cfg: 52 | self.metric = self._build_metric(cfg['metric']) 53 | self.backup_metric = self._build_metric(cfg['metric']) 54 | else: 55 | raise KeyError('Please set metric in config file.') 56 | 57 | # set need_text 58 | self.need_text = False 59 | 60 | def _build_logger(self, cfg): 61 | return build_logger(cfg, dict(workdir=self.workdir)) 62 | 63 | def _set_device(self): 64 | self.gpu_num = torch.cuda.device_count() 65 | if torch.cuda.is_available(): 66 | use_gpu = True 67 | else: 68 | use_gpu = False 69 | 70 | return use_gpu 71 | 72 | def _set_seed(self, seed): 73 | if seed is not None: 74 | self.logger.info('Set seed {}'.format(seed)) 75 | random.seed(seed) 76 | np.random.seed(seed) 77 | torch.manual_seed(seed) 78 | torch.cuda.manual_seed_all(seed) 79 | 80 | def _set_cudnn(self, deterministic, benchmark): 81 | self.logger.info('Set cudnn deterministic {}'.format(deterministic)) 82 | cudnn.deterministic = deterministic 83 | 84 | self.logger.info('Set cudnn benchmark {}'.format(benchmark)) 85 | cudnn.benchmark = benchmark 86 | 87 | def _build_metric(self, cfg): 88 | return build_metric(cfg) 89 | 90 | def _build_transform(self, cfg): 91 | return build_transform(cfg) 92 | 93 | def _build_dataloader(self, cfg): 94 | transform = build_transform(cfg['transform']) 95 | dataset = build_datasets(cfg['dataset'], dict(transform=transform)) 96 | 97 | # TODO, distributed sampler or not 98 | if not cfg.get('sampler'): 99 | sampler = None 100 | else: 101 | if isinstance(dataset, list): 102 | sampler = [ 103 | build_sampler(self.distribute, cfg['sampler'], 104 | dict(dataset=d)) for d in dataset 105 | ] 106 | else: 107 | sampler = build_sampler(self.distribute, 108 | cfg['sampler'], 109 | dict(dataset=dataset)) 110 | dataloader = build_dataloader( 111 | self.distribute, 112 | self.gpu_num, 113 | cfg['dataloader'], 114 | dict(dataset=dataset, sampler=sampler), 115 | seed=self.seed, 116 | ) 117 | return dataloader 118 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/runners/inference_runner.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from .base import Common 8 | from ..converter import build_converter 9 | from ..models import build_model 10 | from ..utils import load_checkpoint 11 | 12 | 13 | class InferenceRunner(Common): 14 | 15 | def __init__(self, inference_cfg, common_cfg=None): 16 | inference_cfg = inference_cfg.copy() 17 | common_cfg = {} if common_cfg is None else common_cfg.copy() 18 | super(InferenceRunner, self).__init__(common_cfg) 19 | 20 | # build test transform 21 | self.transform = self._build_transform(inference_cfg['transform']) 22 | # build converter 23 | self.converter = self._build_converter(inference_cfg['converter']) 24 | # build model 25 | self.model = self._build_model(inference_cfg['model']) 26 | self.logger.info(self.model) 27 | self.postprocess_cfg = inference_cfg.get('postprocess', None) 28 | 
self.model.eval() 29 | 30 | def _build_model(self, cfg): 31 | self.logger.info('Build model') 32 | 33 | model = build_model(cfg) 34 | params_num = [] 35 | for p in filter(lambda p: p.requires_grad, model.parameters()): 36 | params_num.append(np.prod(p.size())) 37 | self.logger.info('Trainable params num : %s' % (sum(params_num))) 38 | self.need_text = model.need_text 39 | 40 | if self.use_gpu: 41 | if self.distribute: 42 | model = torch.nn.parallel.DistributedDataParallel( 43 | model.cuda(), 44 | device_ids=[torch.cuda.current_device()], 45 | broadcast_buffers=True, 46 | ) 47 | self.logger.info('Using distributed training') 48 | else: 49 | if torch.cuda.device_count() > 1: 50 | model = torch.nn.DataParallel(model) 51 | model.cuda() 52 | return model 53 | 54 | def _build_converter(self, cfg): 55 | return build_converter(cfg) 56 | 57 | def load_checkpoint(self, filename, map_location='default', strict=True): 58 | self.logger.info('Load checkpoint from {}'.format(filename)) 59 | 60 | if map_location == 'default': 61 | if self.use_gpu: 62 | device_id = torch.cuda.current_device() 63 | map_location = lambda storage, loc: storage.cuda(device_id) 64 | else: 65 | map_location = 'cpu' 66 | 67 | return load_checkpoint(self.model, filename, map_location, strict) 68 | 69 | def postprocess(self, preds, cfg=None): 70 | if cfg is not None: 71 | sensitive = cfg.get('sensitive', True) 72 | character = cfg.get('character', '') 73 | else: 74 | sensitive = True 75 | character = '' 76 | 77 | probs = F.softmax(preds, dim=2) 78 | max_probs, indexes = probs.max(dim=2) 79 | preds_str = [] 80 | preds_prob = [] 81 | for i, pstr in enumerate(self.converter.decode(indexes)): 82 | str_len = len(pstr) 83 | if str_len == 0: 84 | prob = 0 85 | else: 86 | prob = max_probs[i, :str_len].cumprod(dim=0)[-1] 87 | preds_prob.append(prob) 88 | if not sensitive: 89 | pstr = pstr.lower() 90 | 91 | if character: 92 | pstr = re.sub('[^{}]'.format(character), '', pstr) 93 | 94 | preds_str.append(pstr) 95 | return preds_str, preds_prob 96 | 97 | def __call__(self, image): 98 | with torch.no_grad(): 99 | dummy_text = '' 100 | aug = self.transform(image=image, label=dummy_text) 101 | image, text = aug['image'], aug['label'] 102 | image = image.unsqueeze(0) 103 | label_input, label_length, label_target = self.converter.test_encode([text]) # noqa 501 104 | if self.use_gpu: 105 | image = image.cuda() 106 | label_input = label_input.cuda() 107 | 108 | if self.need_text: 109 | pred = self.model((image, label_input)) 110 | else: 111 | pred = self.model((image,)) 112 | 113 | pred, prob = self.postprocess(pred, self.postprocess_cfg) 114 | 115 | return pred, prob 116 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/runners/test_runner.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | from .inference_runner import InferenceRunner 6 | 7 | 8 | class TestRunner(InferenceRunner): 9 | 10 | def __init__(self, test_cfg, inference_cfg, common_cfg=None): 11 | super(TestRunner, self).__init__(inference_cfg, common_cfg) 12 | 13 | self.test_dataloader = self._build_dataloader(test_cfg['data']) 14 | if not isinstance(self.test_dataloader, dict): 15 | self.test_dataloader = dict(all=self.test_dataloader) 16 | self.test_exclude_num = dict() 17 | for k, v in self.test_dataloader.items(): 18 | extra_data = len(v.dataset) % self.world_size 19 | self.test_exclude_num[ 20 | k] = self.world_size - extra_data if 
extra_data != 0 else 0 21 | self.postprocess_cfg = test_cfg.get('postprocess_cfg', None) 22 | 23 | def test_batch(self, img, label, save_path=None, exclude_num=0): 24 | self.model.eval() 25 | with torch.no_grad(): 26 | label_input, label_length, label_target = self.converter.test_encode(label) # noqa 501 27 | if self.use_gpu: 28 | img = img.cuda() 29 | label_input = label_input.cuda() 30 | 31 | if self.need_text: 32 | pred = self.model((img, label_input)) 33 | else: 34 | pred = self.model((img,)) 35 | 36 | pred, prob = self.postprocess(pred, self.postprocess_cfg) 37 | 38 | if save_path is not None: 39 | for idx, (p, l) in enumerate(zip(pred, label)): 40 | if p == l: 41 | print(p, '\t', l) 42 | cimg = img[idx][0, :, :].cpu().numpy() 43 | cimg = (cimg * 0.5) + 0.5 44 | cv2.imwrite(save_path + f'/%s_{p}_{l}.png' % idx, 45 | (cimg * 255).astype(np.uint8)) 46 | self.metric.measure(pred, prob, label, exclude_num) 47 | self.backup_metric.measure(pred, prob, label, exclude_num) 48 | 49 | def __call__(self): 50 | self.logger.info('Start testing') 51 | self.logger.info('test info: %s' % self.postprocess_cfg) 52 | self.metric.reset() 53 | accs = [] 54 | for name, dataloader in self.test_dataloader.items(): 55 | test_exclude_num = self.test_exclude_num[name] 56 | save_path = None 57 | self.backup_metric.reset() 58 | for tidx, (img, label) in enumerate(dataloader): 59 | exclude_num = test_exclude_num if (tidx + 60 | 1) == len(dataloader) else 0 61 | 62 | self.test_batch(img, label, save_path, exclude_num) 63 | accs.append(self.backup_metric.avg['acc']['true']) 64 | self.logger.info( 65 | 'Test, current dataset root %s, acc %.4f, edit distance %.4f' % 66 | (name, self.backup_metric.avg['acc']['true'], 67 | self.backup_metric.avg['edit'])) 68 | self.logger.info( 69 | 'Test, average acc %.4f, edit distance %s' % 70 | (self.metric.avg['acc']['true'], self.metric.avg['edit'])) 71 | acc_str = ' '.join(list(map(lambda x: str(x)[:6], accs))) 72 | self.logger.info('For copy and record, %s' % acc_str) 73 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_transform # noqa 401 2 | from .transforms import (FactorScale, LongestMaxSize, PadIfNeeded, RandomScale, # noqa 401 3 | Sensitive, ToTensor) # noqa 401 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/transforms/builder.py: -------------------------------------------------------------------------------- 1 | import albumentations as albu 2 | 3 | from vedastr.utils import build_from_cfg 4 | from .registry import TRANSFORMS 5 | 6 | 7 | def build_transform(cfgs): 8 | tfs = [] 9 | for cfg in cfgs: 10 | if TRANSFORMS.get(cfg['type']): 11 | tf = build_from_cfg(cfg, TRANSFORMS) 12 | else: 13 | tf = build_from_cfg(cfg, albu, src='module') 14 | tfs.append(tf) 15 | aug = albu.Compose(tfs) 16 | 17 | return aug 18 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/transforms/registry.py: -------------------------------------------------------------------------------- 1 | from vedastr.utils import Registry 2 | 3 | TRANSFORMS = Registry('transforms') 4 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import 
(load_checkpoint, load_state_dict, load_url_dist, # noqa 401 2 | save_checkpoint) # noqa 401 3 | from .common import (WorkerInit, build_from_cfg, get_root_logger, # noqa 401 4 | set_random_seed) # noqa 401 5 | from .config import Config, ConfigDict # noqa 401 6 | from .dist_utils import (gather_tensor, get_dist_info, init_dist_pytorch, # noqa 401 7 | master_only, reduce_tensor) # noqa 401 8 | from .registry import Registry # noqa 401 9 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/utils/common.py: -------------------------------------------------------------------------------- 1 | # modify from mmcv and mmdetection 2 | 3 | import inspect 4 | import logging 5 | import random 6 | import sys 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def build_from_cfg(cfg, parent, default_args=None, src='registry'): 13 | if src == 'registry': 14 | return obj_from_dict_registry(cfg, parent, default_args) 15 | elif src == 'module': 16 | return obj_from_dict_module(cfg, parent, default_args) 17 | else: 18 | raise ValueError('Method %s is not supported' % src) 19 | 20 | 21 | def obj_from_dict_module(info, parent=None, default_args=None): 22 | """Initialize an object from dict. 23 | The dict must contain the key "type", which indicates the object type, it 24 | can be either a string or type, such as "list" or ``list``. Remaining 25 | fields are treated as the arguments for constructing the object. 26 | Args: 27 | info (dict): Object types and arguments. 28 | parent (:class:`module`): Module which may containing expected object 29 | classes. 30 | default_args (dict, optional): Default arguments for initializing the 31 | object. 32 | Returns: 33 | any type: Object built from the dict. 34 | """ 35 | assert isinstance(info, dict) and 'type' in info 36 | assert isinstance(default_args, dict) or default_args is None 37 | args = info.copy() 38 | obj_type = args.pop('type') 39 | if isinstance(obj_type, str): 40 | if parent is not None: 41 | obj_type = getattr(parent, obj_type) 42 | else: 43 | obj_type = sys.modules[obj_type] 44 | elif not isinstance(obj_type, type): 45 | raise TypeError('type must be a str or valid type, but got {}'.format( 46 | type(obj_type))) 47 | if default_args is not None: 48 | for name, value in default_args.items(): 49 | args.setdefault(name, value) 50 | return obj_type(**args) 51 | 52 | 53 | def obj_from_dict_registry(cfg, registry, default_args=None): 54 | """Build a module from config dict. 55 | Args: 56 | cfg (dict): Config dict. It should at least contain the key "type". 57 | registry (:obj:`Registry`): The registry to search the type from. 58 | default_args (dict, optional): Default initialization arguments. 59 | Returns: 60 | obj: The constructed object. 
61 | """ 62 | assert isinstance(cfg, dict) and 'type' in cfg 63 | assert isinstance(default_args, dict) or default_args is None 64 | args = cfg.copy() 65 | obj_type = args.pop('type') 66 | if isinstance(obj_type, str): 67 | obj_cls = registry.get(obj_type) 68 | if obj_cls is None: 69 | raise KeyError('{} is not in the {} registry'.format( 70 | obj_type, registry.name)) 71 | elif inspect.isclass(obj_type): 72 | obj_cls = obj_type 73 | else: 74 | raise TypeError('type must be a str or valid type, but got {}'.format( 75 | type(obj_type))) 76 | if default_args is not None: 77 | for name, value in default_args.items(): 78 | args.setdefault(name, value) 79 | return obj_cls(**args) 80 | 81 | 82 | def set_random_seed(seed): 83 | random.seed(seed) 84 | np.random.seed(seed) 85 | torch.manual_seed(seed) 86 | torch.cuda.manual_seed_all(seed) 87 | 88 | 89 | def get_root_logger(log_level=logging.INFO): 90 | logger = logging.getLogger() 91 | np.set_printoptions(precision=4) 92 | if not logger.hasHandlers(): 93 | logging.basicConfig( 94 | format='%(asctime)s - %(levelname)s - %(message)s', 95 | level=log_level) 96 | return logger 97 | 98 | 99 | class WorkerInit: 100 | 101 | def __init__(self, num_workers, rank, seed, epoch): 102 | self.num_workers = num_workers 103 | self.rank = rank 104 | self.seed = seed if seed is not None else 0 105 | self.epoch = epoch 106 | 107 | def __call__(self, worker_id): 108 | worker_seed = self.num_workers * self.rank + \ 109 | worker_id + self.seed + self.epoch 110 | np.random.seed(worker_seed) 111 | random.seed(worker_seed) 112 | 113 | def set_epoch(self, n): 114 | assert isinstance(n, int) 115 | self.epoch = n 116 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # adapted from mmcv and mmdetection 2 | 3 | import functools 4 | import os 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | 10 | def init_dist_pytorch(backend='nccl', **kwargs): 11 | rank = int(os.environ['RANK']) 12 | num_gpus = torch.cuda.device_count() 13 | torch.cuda.set_device(rank % num_gpus) 14 | dist.init_process_group(backend=backend, **kwargs) 15 | 16 | 17 | def get_dist_info(): 18 | if dist.is_available(): 19 | initialized = dist.is_initialized() 20 | else: 21 | initialized = False 22 | 23 | if initialized: 24 | rank = dist.get_rank() 25 | world_size = dist.get_world_size() 26 | else: 27 | rank = 0 28 | world_size = 1 29 | 30 | return rank, world_size 31 | 32 | 33 | def reduce_tensor(data, average=True): 34 | rank, world_size = get_dist_info() 35 | if world_size < 2: 36 | return data 37 | 38 | with torch.no_grad(): 39 | if not isinstance(data, torch.Tensor): 40 | data = torch.tensor(data).cuda() 41 | dist.reduce(data, dst=0) 42 | if rank == 0 and average: 43 | data /= world_size 44 | return data 45 | 46 | 47 | def gather_tensor(data): 48 | _, world_size = get_dist_info() 49 | if world_size < 2: 50 | return data 51 | 52 | with torch.no_grad(): 53 | if not isinstance(data, torch.Tensor): 54 | data = torch.tensor(data).cuda() 55 | if not data.size(): 56 | data = data.unsqueeze(0) 57 | 58 | gather_list = [torch.ones_like(data) for _ in range(world_size)] 59 | dist.all_gather(gather_list, data) 60 | gather_data = torch.cat(gather_list, 0) 61 | 62 | return gather_data 63 | 64 | 65 | def synchronize(): 66 | if not dist.is_available(): 67 | return 68 | if not dist.is_initialized(): 69 | return 70 | world_size = dist.get_world_size() 71 
| if world_size == 1: 72 | return 73 | dist.barrier() 74 | 75 | 76 | def master_only(func): 77 | """Don't use master_only to decorate function which have random state. 78 | 79 | """ 80 | 81 | @functools.wraps(func) 82 | def wrapper(*args, **kwargs): 83 | rank, _ = get_dist_info() 84 | if rank == 0: 85 | return func(*args, **kwargs) 86 | 87 | return wrapper 88 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/utils/path.py: -------------------------------------------------------------------------------- 1 | # modify from mmcv and mmdetection 2 | 3 | import os 4 | import os.path as osp 5 | import sys 6 | from pathlib import Path 7 | 8 | import six 9 | 10 | from .misc import is_str 11 | 12 | if sys.version_info <= (3, 3): 13 | FileNotFoundError = IOError 14 | else: 15 | FileNotFoundError = FileNotFoundError 16 | 17 | 18 | def is_filepath(x): 19 | if is_str(x) or isinstance(x, Path): 20 | return True 21 | else: 22 | return False 23 | 24 | 25 | def fopen(filepath, *args, **kwargs): 26 | if is_str(filepath): 27 | return open(filepath, *args, **kwargs) 28 | elif isinstance(filepath, Path): 29 | return filepath.open(*args, **kwargs) 30 | 31 | 32 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 33 | if not osp.isfile(filename): 34 | raise FileNotFoundError(msg_tmpl.format(filename)) 35 | 36 | 37 | def mkdir_or_exist(dir_name, mode=0o777): 38 | if dir_name == '': 39 | return 40 | dir_name = osp.expanduser(dir_name) 41 | if six.PY3: 42 | os.makedirs(dir_name, mode=mode, exist_ok=True) 43 | else: 44 | if not osp.isdir(dir_name): 45 | os.makedirs(dir_name, mode=mode) 46 | 47 | 48 | def symlink(src, dst, overwrite=True, **kwargs): 49 | if os.path.lexists(dst) and overwrite: 50 | os.remove(dst) 51 | os.symlink(src, dst, **kwargs) 52 | 53 | 54 | def _scandir_py35(dir_path, suffix=None): 55 | for entry in os.scandir(dir_path): 56 | if not entry.is_file(): 57 | continue 58 | filename = entry.name 59 | if suffix is None: 60 | yield filename 61 | elif filename.endswith(suffix): 62 | yield filename 63 | 64 | 65 | def _scandir_py(dir_path, suffix=None): 66 | for filename in os.listdir(dir_path): 67 | if not osp.isfile(osp.join(dir_path, filename)): 68 | continue 69 | if suffix is None: 70 | yield filename 71 | elif filename.endswith(suffix): 72 | yield filename 73 | 74 | 75 | def scandir(dir_path, suffix=None): 76 | if suffix is not None and not isinstance(suffix, (str, tuple)): 77 | raise TypeError('"suffix" must be a string or tuple of strings') 78 | if sys.version_info >= (3, 5): 79 | return _scandir_py35(dir_path, suffix) 80 | else: 81 | return _scandir_py(dir_path, suffix) 82 | 83 | 84 | def find_vcs_root(path, markers=('.git', )): 85 | """Finds the root directory (including itself) of specified markers. 86 | 87 | Args: 88 | path (str): Path of directory or file. 89 | markers (list[str], optional): List of file or directory names. 90 | 91 | Returns: 92 | The directory contained one of the markers or None if not found. 
93 | """ 94 | if osp.isfile(path): 95 | path = osp.dirname(path) 96 | 97 | prev, cur = None, osp.abspath(osp.expanduser(path)) 98 | while cur != prev: 99 | if any(osp.exists(osp.join(cur, marker)) for marker in markers): 100 | return cur 101 | prev, cur = cur, osp.split(cur)[0] 102 | return None 103 | -------------------------------------------------------------------------------- /vedastr_cstr/vedastr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # modify from mmcv and mmdetection 2 | import inspect 3 | 4 | 5 | class Registry(object): 6 | 7 | def __init__(self, name): 8 | self._name = name 9 | self._module_dict = dict() 10 | 11 | def __repr__(self): 12 | format_str = self.__class__.__name__ + '(name={}, items={})'.format( 13 | self._name, list(self._module_dict.keys())) 14 | return format_str 15 | 16 | @property 17 | def name(self): 18 | return self._name 19 | 20 | @property 21 | def module_dict(self): 22 | return self._module_dict 23 | 24 | def get(self, key): 25 | return self._module_dict.get(key, None) 26 | 27 | def _register_module(self, module_class): 28 | """Register a module. 29 | Args: 30 | module (:obj:`nn.Module`): Module to be registered. 31 | """ 32 | if not inspect.isclass(module_class): 33 | raise TypeError('module must be a class, but got {}'.format( 34 | type(module_class))) 35 | module_name = module_class.__name__ 36 | if module_name in self._module_dict: 37 | raise KeyError('{} is already registered in {}'.format( 38 | module_name, self.name)) 39 | self._module_dict[module_name] = module_class 40 | 41 | def register_module(self, cls): 42 | self._register_module(cls) 43 | return cls 44 | --------------------------------------------------------------------------------