├── Datasets.md ├── README.md ├── finetune ├── Datasets.md ├── LICENSE ├── Makefile ├── NOTICE ├── bench.py ├── config │ ├── config.yaml │ ├── hydra.yaml │ └── overrides.yaml ├── configs │ ├── bench.yaml │ ├── charset │ │ ├── 36_lowercase.yaml │ │ ├── 62_mixed-case.yaml │ │ ├── 6623_chinese.yaml │ │ └── 94_full.yaml │ ├── dataset │ │ ├── Union.yaml │ │ ├── ard.yaml │ │ ├── cvl.yaml │ │ ├── doczh.yaml │ │ ├── iam.yaml │ │ ├── icdar24.yaml │ │ ├── real.yaml │ │ ├── scenezh.yaml │ │ ├── synth.yaml │ │ └── webzh.yaml │ ├── experiment │ │ ├── abinet-sv.yaml │ │ ├── abinet-union.yaml │ │ ├── abinet.yaml │ │ ├── cluster-debug.yaml │ │ ├── crnn-syn.yaml │ │ ├── crnn-union.yaml │ │ ├── crnn.yaml │ │ ├── finetune-icdar-mdr.yaml │ │ ├── finetune-icdar.yaml │ │ ├── mdr-base-ard.yaml │ │ ├── mdr-base-synth.yaml │ │ ├── mdr-dec6-union.yaml │ │ ├── mdr-hybrid-large.yaml │ │ ├── mdr-pool-base.yaml │ │ ├── mdr-pool-baug.yaml │ │ ├── mdr-syn-dec4.yaml │ │ ├── mdr-syn-dec6.yaml │ │ ├── parseq-patch16-224.yaml │ │ ├── parseq-real.yaml │ │ ├── parseq-syn-mdr.yaml │ │ ├── parseq-syn.yaml │ │ ├── parseq-union-pre.yaml │ │ ├── parseq-union.yaml │ │ ├── parseq.yaml │ │ ├── rand-sdr18-moco-dec6-union.yaml │ │ ├── randsdr19-moco-dec6-union.yaml │ │ ├── sdr-dino-nocycle-dec6-union.yaml │ │ ├── sdr-moco-dec6-union.yaml │ │ ├── sdr-moco-dec6-union20e.yaml │ │ ├── sdr-only2.yaml │ │ ├── sdr-randaug.yaml │ │ ├── trans-cvl.yaml │ │ ├── transocr-base-pre.yaml │ │ ├── transocr-dev.yaml │ │ ├── transocr-dev2.yaml │ │ ├── transocr-dig.yaml │ │ ├── transocr-iam.yaml │ │ ├── transocr-jlw.yaml │ │ ├── transocr-maerec.yaml │ │ ├── transocr-mim.yaml │ │ ├── transocr-scenezh.yaml │ │ ├── transocr-tiny.yaml │ │ ├── transocr-webzh.yaml │ │ ├── trba.yaml │ │ ├── trbc.yaml │ │ ├── trstr-mdr.yaml │ │ ├── tune_abinet-lm.yaml │ │ └── vitstr.yaml │ ├── main.yaml │ ├── model │ │ ├── abinet.yaml │ │ ├── clusterseq.yaml │ │ ├── crnn.yaml │ │ ├── parseq.yaml │ │ ├── transocr.yaml │ │ ├── trba.yaml │ │ └── vitstr.yaml │ └── tune.yaml ├── hubconf.py ├── pyproject.toml ├── read.py ├── requirements │ ├── bench.in │ ├── bench.txt │ ├── constraints.txt │ ├── core.in │ ├── core.txt │ ├── test.in │ ├── test.txt │ ├── train.in │ ├── train.txt │ ├── tune.in │ └── tune.txt ├── strhub │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ └── __init__.cpython-39.pyc │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── aa_overrides.cpython-310.pyc │ │ │ ├── aa_overrides.cpython-38.pyc │ │ │ ├── augment.cpython-310.pyc │ │ │ ├── augment.cpython-38.pyc │ │ │ ├── dataset.cpython-310.pyc │ │ │ ├── dataset.cpython-38.pyc │ │ │ ├── module.cpython-310.pyc │ │ │ ├── module.cpython-38.pyc │ │ │ ├── module.cpython-39.pyc │ │ │ ├── utils.cpython-310.pyc │ │ │ └── utils.cpython-38.pyc │ │ ├── aa_overrides.py │ │ ├── augment.py │ │ ├── count_word.py │ │ ├── dataset.py │ │ ├── filter_unseen.py │ │ ├── module.py │ │ ├── read_from_lmdb.py │ │ ├── split_union.py │ │ └── utils.py │ └── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── base.cpython-310.pyc │ │ ├── base.cpython-38.pyc │ │ ├── modules.cpython-38.pyc │ │ ├── utils.cpython-310.pyc │ │ └── utils.cpython-38.pyc │ │ ├── abinet │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── attention.cpython-38.pyc │ │ │ ├── backbone.cpython-38.pyc │ │ │ 
├── model.cpython-38.pyc │ │ │ ├── model_abinet_iter.cpython-38.pyc │ │ │ ├── model_alignment.cpython-38.pyc │ │ │ ├── model_language.cpython-38.pyc │ │ │ ├── model_vision.cpython-38.pyc │ │ │ ├── resnet.cpython-38.pyc │ │ │ ├── system.cpython-38.pyc │ │ │ └── transformer.cpython-38.pyc │ │ ├── attention.py │ │ ├── backbone.py │ │ ├── model.py │ │ ├── model_abinet_iter.py │ │ ├── model_alignment.py │ │ ├── model_language.py │ │ ├── model_vision.py │ │ ├── resnet.py │ │ ├── system.py │ │ └── transformer.py │ │ ├── base.py │ │ ├── clusterseq │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── modules.cpython-37.pyc │ │ │ ├── modules.cpython-38.pyc │ │ │ ├── system.cpython-37.pyc │ │ │ └── system.cpython-38.pyc │ │ ├── modules.py │ │ └── system.py │ │ ├── crnn │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ └── system.cpython-38.pyc │ │ ├── model.py │ │ └── system.py │ │ ├── modules.py │ │ ├── parseq │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── modules.cpython-310.pyc │ │ │ ├── modules.cpython-38.pyc │ │ │ ├── system.cpython-310.pyc │ │ │ └── system.cpython-38.pyc │ │ ├── modules.py │ │ └── system.py │ │ ├── transocr │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── modules.cpython-37.pyc │ │ │ ├── modules.cpython-38.pyc │ │ │ ├── system.cpython-37.pyc │ │ │ └── system.cpython-38.pyc │ │ ├── modules.py │ │ └── system.py │ │ ├── trba │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── feature_extraction.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── prediction.cpython-38.pyc │ │ │ ├── system.cpython-38.pyc │ │ │ └── transformation.cpython-38.pyc │ │ ├── feature_extraction.py │ │ ├── model.py │ │ ├── prediction.py │ │ ├── system.py │ │ └── transformation.py │ │ ├── utils.py │ │ └── vitstr │ │ ├── __init__.py │ │ ├── model.py │ │ └── system.py ├── test.py ├── test_conf.py ├── test_custom.py ├── test_dig.py ├── test_wordart.py ├── test_zh.py ├── tools │ ├── art_converter.py │ ├── case_sensitive_str_datasets_converter.py │ ├── coco_2_converter.py │ ├── coco_text_converter.py │ ├── create_lmdb_dataset.py │ ├── filter_lmdb.py │ ├── lsvt_converter.py │ ├── mlt19_converter.py │ ├── openvino_converter.py │ ├── split_data.py │ ├── test_abinet_lm_acc.py │ └── textocr_converter.py ├── train.py ├── train.sh ├── train_crnn.py ├── train_zh.py └── tune.py └── pretrain ├── augmentation ├── __pycache__ │ ├── blur.cpython-37.pyc │ ├── blur.cpython-38.pyc │ ├── camera.cpython-37.pyc │ ├── camera.cpython-38.pyc │ ├── geometry.cpython-37.pyc │ ├── geometry.cpython-38.pyc │ ├── noise.cpython-37.pyc │ ├── noise.cpython-38.pyc │ ├── ops.cpython-37.pyc │ ├── ops.cpython-38.pyc │ ├── pattern.cpython-37.pyc │ ├── pattern.cpython-38.pyc │ ├── process.cpython-37.pyc │ ├── process.cpython-38.pyc │ ├── warp.cpython-37.pyc │ ├── warp.cpython-38.pyc │ ├── weather.cpython-37.pyc │ └── weather.cpython-38.pyc ├── blur.py ├── camera.py ├── frost │ ├── frost1.png │ ├── frost2.png │ ├── frost3.png │ ├── frost4.jpg │ ├── frost5.jpg │ └── frost6.jpg ├── geometry.py ├── images │ ├── delivery.png │ ├── education.png │ ├── manila.png │ ├── nokia.png │ └── telekom.png ├── noise.py ├── ops.py ├── pattern.py ├── process.py ├── test.py ├── warp.py └── weather.py ├── datasets.py ├── demo.py ├── engine_finetune.py ├── 
engine_pixel.py ├── engine_pretrain.py ├── infer_single.py ├── load_model.py ├── main_finetune.py ├── main_linprobe.py ├── main_pixel.py ├── main_pretrain.py ├── models_abi_re.py ├── models_mae.py ├── models_pixel.py ├── models_sim_re.py ├── models_vit.py ├── requirements.txt ├── scripts ├── .ipynb_checkpoints │ └── encoder-pre4-checkpoint.sh ├── encoder-pretrain.sh ├── eval-SR-compare.sh ├── eval-SR.sh ├── eval-Seg-compare.sh ├── eval-Seg.sh ├── eval.sh ├── eval_all.sh ├── finetune-SR-CCD.sh ├── finetune-SR-DiG.sh └── finetune-Seg.sh ├── sim_pretrain.py └── util ├── __pycache__ ├── lr_sched.cpython-37.pyc ├── lr_sched.cpython-38.pyc ├── metric_iou.cpython-38.pyc ├── misc.cpython-37.pyc ├── misc.cpython-38.pyc ├── pos_embed.cpython-37.pyc ├── pos_embed.cpython-38.pyc └── transforms.cpython-38.pyc ├── crop.py ├── datasets.py ├── lars.py ├── lr_decay.py ├── lr_sched.py ├── metric_iou.py ├── misc.py ├── pos_embed.py └── transforms.py /finetune/Makefile: -------------------------------------------------------------------------------- 1 | # Reference: https://dida.do/blog/managing-layered-requirements-with-pip-tools 2 | 3 | REQUIREMENTS_TXT := $(addsuffix .txt, $(basename $(wildcard requirements/*.in))) 4 | PIP_COMPILE := pip-compile --quiet --no-header --allow-unsafe --resolver=backtracking 5 | 6 | .DEFAULT_GOAL := help 7 | .PHONY: reqs clean-reqs help 8 | 9 | requirements/constraints.txt: requirements/*.in 10 | CONSTRAINTS=/dev/null $(PIP_COMPILE) --strip-extras --output-file $@ $^ --extra-index-url https://download.pytorch.org/whl/cpu 11 | 12 | requirements/%.txt: requirements/%.in requirements/constraints.txt 13 | CONSTRAINTS=constraints.txt $(PIP_COMPILE) --no-annotate --output-file $@ $< 14 | @# Remove --extra-index-url, blank lines, and torch dependency from non-core groups 15 | @[ $* = core ] || sed '/^--/d; /^$$/d; /^torch==/d' -i $@ 16 | 17 | reqs: $(REQUIREMENTS_TXT) ## Generate the requirements files 18 | 19 | torch-%: requirements/core.txt ## Set PyTorch platform to use, e.g. cpu, cu117, rocm5.2 20 | @echo Generating requirements/core.$*.txt 21 | @sed 's|cpu|$*|' $< >requirements/core.$*.txt 22 | 23 | clean-reqs: ## Delete the requirements files 24 | rm -f requirements/constraints.txt requirements/core.*.txt $(REQUIREMENTS_TXT) 25 | 26 | help: ## Display this help 27 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 28 | -------------------------------------------------------------------------------- /finetune/NOTICE: -------------------------------------------------------------------------------- 1 | Scene Text Recognition Model Hub 2 | Copyright 2022 Darwin Bautista 3 | 4 | The Initial Developer of strhub/models/abinet (sans system.py) is 5 | Fang et al. (https://github.com/FangShancheng/ABINet). 6 | Copyright 2021-2022 USTC 7 | 8 | The Initial Developer of strhub/models/crnn (sans system.py) is 9 | Jieru Mei (https://github.com/meijieru/crnn.pytorch). 10 | Copyright 2017-2022 Jieru Mei 11 | 12 | The Initial Developer of strhub/models/trba (sans system.py) is 13 | Jeonghun Baek (https://github.com/clovaai/deep-text-recognition-benchmark). 14 | Copyright 2019-2022 NAVER Corp. 15 | 16 | The Initial Developer of strhub/models/vitstr (sans system.py) is 17 | Rowel Atienza (https://github.com/roatienza/deep-text-recognition-benchmark). 
18 | Copyright 2021-2022 Rowel Atienza 19 | -------------------------------------------------------------------------------- /finetune/bench.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Scene Text Recognition Model Hub 3 | # Copyright 2022 Darwin Bautista 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | 19 | import torch 20 | from torch.utils import benchmark 21 | 22 | from fvcore.nn import FlopCountAnalysis, ActivationCountAnalysis, flop_count_table 23 | 24 | import hydra 25 | from omegaconf import DictConfig 26 | 27 | 28 | @torch.inference_mode() 29 | @hydra.main(config_path='configs', config_name='bench', version_base='1.2') 30 | def main(config: DictConfig): 31 | # For consistent behavior 32 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 33 | torch.backends.cudnn.benchmark = False 34 | torch.use_deterministic_algorithms(True) 35 | 36 | device = config.get('device', 'cuda') 37 | 38 | h, w = config.data.img_size 39 | x = torch.rand(1, 3, h, w, device=device) 40 | model = hydra.utils.instantiate(config.model).eval().to(device) 41 | 42 | if config.get('range', False): 43 | for i in range(1, 26, 4): 44 | timer = benchmark.Timer( 45 | stmt='model(x, len)', 46 | globals={'model': model, 'x': x, 'len': i}) 47 | print(timer.blocked_autorange(min_run_time=1)) 48 | else: 49 | timer = benchmark.Timer( 50 | stmt='model(x)', 51 | globals={'model': model, 'x': x}) 52 | flops = FlopCountAnalysis(model, x) 53 | acts = ActivationCountAnalysis(model, x) 54 | print(timer.blocked_autorange(min_run_time=1)) 55 | print(flop_count_table(flops, 1, acts, False)) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /finetune/config/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _convert_: all 3 | img_size: 4 | - 32 5 | - 128 6 | max_label_length: 25 7 | charset_train: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ 8 | charset_test: 0123456789abcdefghijklmnopqrstuvwxyz 9 | batch_size: 192 10 | weight_decay: 0.0 11 | warmup_pct: 0.075 12 | name: transocr 13 | _target_: strhub.models.transocr.system.TransOCR 14 | patch_size: 15 | - 4 16 | - 4 17 | enc_embed_dim: 384 18 | enc_num_heads: 6 19 | enc_mlp_ratio: 4 20 | enc_depth: 12 21 | dec_embed_dim: 384 22 | dec_num_heads: 12 23 | dec_mlp_ratio: 4 24 | dec_depth: 6 25 | query_type: learn 26 | standard: true 27 | lr: 6.0e-06 28 | perm_num: 6 29 | perm_forward: true 30 | perm_mirrored: true 31 | dropout: 0.1 32 | decode_ar: true 33 | refine_iters: 1 34 | parseq_pretrained: true 35 | parseq_checkpoint: /home/gaoz/outputs/transocr/finetune-mdr-icdar24-384-20e/2024-03-29_12-11-50/checkpoints/last.ckpt 36 | encoder_pretrained: false 37 | encoder_checkpoint: /home/gaoz/pretrained/trans-hybird-mulemb3-prob/checkpoint-19.pth 38 | data: 39 | 
_target_: strhub.data.module.SceneTextDataModule 40 | root_dir: /home/gaoz/datasets/ 41 | train_dir: icdar24 42 | batch_size: ${model.batch_size} 43 | img_size: ${model.img_size} 44 | charset_train: ${model.charset_train} 45 | charset_test: ${model.charset_test} 46 | max_label_length: ${model.max_label_length} 47 | remove_whitespace: true 48 | normalize_unicode: true 49 | augment: true 50 | num_workers: 8 51 | trainer: 52 | _target_: pytorch_lightning.Trainer 53 | _convert_: all 54 | val_check_interval: 1000 55 | max_epochs: 20 56 | gradient_clip_val: 20 57 | accelerator: gpu 58 | devices: 2 59 | ckpt_path: null 60 | pretrained: null 61 | exp_name: 6e6-finetune-mdr-icdar24-384-20e 62 | -------------------------------------------------------------------------------- /finetune/config/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: /home/gaoz/parseq-dev 4 | sweep: 5 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | subdir: ${hydra.job.override_dirname} 7 | launcher: 8 | _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher 9 | sweeper: 10 | _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper 11 | max_batch_size: null 12 | params: null 13 | help: 14 | app_name: ${hydra.job.name} 15 | header: '${hydra.help.app_name} is powered by Hydra. 16 | 17 | ' 18 | footer: 'Powered by Hydra (https://hydra.cc) 19 | 20 | Use --hydra-help to view Hydra specific help 21 | 22 | ' 23 | template: '${hydra.help.header} 24 | 25 | == Configuration groups == 26 | 27 | Compose your configuration from those groups (group=option) 28 | 29 | 30 | $APP_CONFIG_GROUPS 31 | 32 | 33 | == Config == 34 | 35 | Override anything in the config (foo.bar=value) 36 | 37 | 38 | $CONFIG 39 | 40 | 41 | ${hydra.help.footer} 42 | 43 | ' 44 | hydra_help: 45 | template: 'Hydra (${hydra.runtime.version}) 46 | 47 | See https://hydra.cc for more info. 48 | 49 | 50 | == Flags == 51 | 52 | $FLAGS_HELP 53 | 54 | 55 | == Configuration groups == 56 | 57 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled 58 | to command line) 59 | 60 | 61 | $HYDRA_CONFIG_GROUPS 62 | 63 | 64 | Use ''--cfg hydra'' to Show the Hydra config. 65 | 66 | ' 67 | hydra_help: ??? 
68 | hydra_logging: 69 | version: 1 70 | formatters: 71 | simple: 72 | format: '[%(asctime)s][HYDRA] %(message)s' 73 | handlers: 74 | console: 75 | class: logging.StreamHandler 76 | formatter: simple 77 | stream: ext://sys.stdout 78 | root: 79 | level: INFO 80 | handlers: 81 | - console 82 | loggers: 83 | logging_example: 84 | level: DEBUG 85 | disable_existing_loggers: false 86 | job_logging: 87 | version: 1 88 | formatters: 89 | simple: 90 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 91 | handlers: 92 | console: 93 | class: logging.StreamHandler 94 | formatter: simple 95 | stream: ext://sys.stdout 96 | file: 97 | class: logging.FileHandler 98 | formatter: simple 99 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 100 | root: 101 | level: INFO 102 | handlers: 103 | - console 104 | - file 105 | disable_existing_loggers: false 106 | env: {} 107 | mode: RUN 108 | searchpath: [] 109 | callbacks: {} 110 | output_subdir: config 111 | overrides: 112 | hydra: 113 | - hydra.run.dir="/home/gaoz/parseq-dev" 114 | - hydra.job.name=train_ddp_process_1 115 | - hydra.mode=RUN 116 | task: 117 | - +experiment=finetune-icdar-mdr 118 | job: 119 | name: train_ddp_process_1 120 | chdir: null 121 | override_dirname: +experiment=finetune-icdar-mdr 122 | id: ??? 123 | num: ??? 124 | config_name: main 125 | env_set: {} 126 | env_copy: [] 127 | config: 128 | override_dirname: 129 | kv_sep: '=' 130 | item_sep: ',' 131 | exclude_keys: [] 132 | runtime: 133 | version: 1.3.2 134 | version_base: '1.2' 135 | cwd: /home/gaoz/parseq-dev 136 | config_sources: 137 | - path: hydra.conf 138 | schema: pkg 139 | provider: hydra 140 | - path: /home/gaoz/parseq-dev/configs 141 | schema: file 142 | provider: main 143 | - path: '' 144 | schema: structured 145 | provider: schema 146 | output_dir: /home/gaoz/parseq-dev 147 | choices: 148 | experiment: finetune-icdar-mdr 149 | dataset: icdar24 150 | charset: 94_full 151 | model: transocr 152 | hydra/env: default 153 | hydra/callbacks: null 154 | hydra/job_logging: default 155 | hydra/hydra_logging: default 156 | hydra/hydra_help: default 157 | hydra/help: default 158 | hydra/sweeper: basic 159 | hydra/launcher: basic 160 | hydra/output: default 161 | verbose: false 162 | -------------------------------------------------------------------------------- /finetune/config/overrides.yaml: -------------------------------------------------------------------------------- 1 | - +experiment=finetune-icdar-mdr 2 | -------------------------------------------------------------------------------- /finetune/configs/bench.yaml: -------------------------------------------------------------------------------- 1 | # Disable any logging or output 2 | defaults: 3 | - main 4 | - _self_ 5 | - override hydra/job_logging: disabled 6 | 7 | hydra: 8 | output_subdir: null 9 | run: 10 | dir: . 
11 | -------------------------------------------------------------------------------- /finetune/configs/charset/36_lowercase.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | charset_train: "0123456789abcdefghijklmnopqrstuvwxyz" 4 | -------------------------------------------------------------------------------- /finetune/configs/charset/62_mixed-case.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | charset_train: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 4 | -------------------------------------------------------------------------------- /finetune/configs/charset/94_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | charset_train: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/Union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: Union 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/ard.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: ard 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/cvl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: cvl 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/doczh.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: document_train 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/iam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: iam 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/icdar24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: icdar24 4 | 5 | # CUDA_VISIBLE_DEVICES=4,6 ./train.py +experiment=finetune-icdar-mdr -------------------------------------------------------------------------------- /finetune/configs/dataset/real.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: real 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/scenezh.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: scene_train 4 | -------------------------------------------------------------------------------- /finetune/configs/dataset/synth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: synth 4 | num_workers: 3 5 | 6 | # trainer: 7 | # limit_train_batches: 0.20496 # to match the steps per epoch of `real` 8 | 
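The dataset group files in configs/dataset/ each carry `# @package _global_` and set little more than `data.train_dir` (plus an occasional `num_workers` override); at launch time Hydra merges the chosen dataset, charset, model, and experiment groups into `configs/main.yaml`. The Python sketch below is only an illustration of that composition, using Hydra's compose API with the same `+experiment=finetune-icdar-mdr` override recorded in `config/overrides.yaml`; it assumes it is run from the `finetune/` directory where `configs/` lives.

# Illustrative sketch (not a file in this repo): compose the finetune config the
# same way the Hydra-decorated entry point would, to see how a dataset group such
# as `icdar24` merges into the global config. Assumes execution from finetune/.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="configs", version_base="1.2"):
    cfg = compose(config_name="main", overrides=["+experiment=finetune-icdar-mdr"])

print(cfg.data.train_dir)             # "icdar24", set by configs/dataset/icdar24.yaml
print(len(cfg.model.charset_train))   # 94, the charset from configs/charset/94_full.yaml
print(OmegaConf.to_yaml(cfg.trainer))  # trainer settings from the experiment file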
-------------------------------------------------------------------------------- /finetune/configs/dataset/webzh.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: web_train 4 | -------------------------------------------------------------------------------- /finetune/configs/experiment/abinet-sv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: abinet 4 | 5 | model: 6 | name: abinet-sv 7 | v_num_layers: 2 8 | v_attention: attention 9 | -------------------------------------------------------------------------------- /finetune/configs/experiment/abinet-union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: abinet 4 | - override /charset: 36_lowercase 5 | - override /dataset: semi 6 | 7 | model: 8 | parseq_pretrained: False 9 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 10 | encoder_pretrained: True 11 | encoder_checkpoint: /home/sist/zuangao/pretrain_encoder/mdr-abinet45-2048b-10e/checkpoint-9.pth 12 | batch_size: 192 13 | 14 | 15 | data: 16 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 17 | num_workers: 8 18 | augment: True 19 | 20 | trainer: 21 | # val_check_interval: 5 22 | accelerator: gpu 23 | devices: 4 24 | 25 | 26 | exp_name: abinet-4A100-192-semi-aug-pre 27 | 28 | 29 | hydra: 30 | output_subdir: config 31 | run: 32 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 33 | sweep: 34 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 35 | subdir: ${hydra.job.override_dirname} 36 | 37 | -------------------------------------------------------------------------------- /finetune/configs/experiment/abinet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: abinet 4 | -------------------------------------------------------------------------------- /finetune/configs/experiment/cluster-debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: clusterseq 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | enc_embed_dim: 768 9 | enc_num_heads: 12 10 | enc_mlp_ratio: 4 11 | enc_depth: 12 12 | dec_embed_dim: 768 13 | dec_num_heads: 8 14 | dec_mlp_ratio: 4 15 | dec_depth: 4 16 | patch_size: [ 4, 4 ] 17 | parseq_pretrained: False 18 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 19 | encoder_pretrained: False 20 | encoder_checkpoint: /home/sist/zuangao/output/hybrid-mul-base-realrec/checkpoint-19.pth 21 | batch_size: 192 22 | 23 | data: 24 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 25 | num_workers: 8 26 | augment: False 27 | 28 | trainer: 29 | # val_check_interval: 5 30 | accelerator: gpu 31 | devices: 4 32 | 33 | 34 | exp_name: transocr-vitbase-4A100-192x4-Union-4x4-noaug-94full-prerealrec 35 | 36 | 37 | hydra: 38 | output_subdir: config 39 | run: 40 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 41 | sweep: 42 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 43 | subdir: ${hydra.job.override_dirname} 44 | 45 | 
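Each experiment file overrides a model group plus charset and dataset groups, then tweaks hyperparameters such as `patch_size`, `dec_depth`, `batch_size`, and the `encoder_checkpoint`/`parseq_checkpoint` warm-start paths. At runtime the resolved `model` and `data` nodes are turned into objects with `hydra.utils.instantiate`, exactly as bench.py does for the model. The sketch below shows that pattern; it assumes a composed `cfg` as in the previous sketch and that any checkpoint or dataset paths referenced by the chosen experiment actually exist on disk.

# Illustrative sketch: instantiate the configured objects from a composed cfg
# (see the previous sketch). Mirrors the hydra.utils.instantiate call in bench.py.
import hydra
import torch

model = hydra.utils.instantiate(cfg.model).eval()   # e.g. strhub.models.transocr.system.TransOCR
datamodule = hydra.utils.instantiate(cfg.data)      # strhub.data.module.SceneTextDataModule

h, w = cfg.data.img_size                             # [32, 128] via ${model.img_size}
with torch.inference_mode():
    logits = model(torch.rand(1, 3, h, w))           # dummy forward pass, as in bench.py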
-------------------------------------------------------------------------------- /finetune/configs/experiment/crnn-syn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: crnn 4 | - override /charset: 36_lowercase 5 | - override /dataset: synth 6 | 7 | model: 8 | ckpt_pretrained: False 9 | ckpt_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 10 | encoder_pretrained: False 11 | encoder_checkpoint: /home/sist/zuangao/output/crnn-debug/checkpoint-19.pth 12 | batch_size: 384 13 | 14 | 15 | data: 16 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 17 | num_workers: 8 18 | augment: True 19 | 20 | trainer: 21 | accelerator: gpu 22 | devices: 2 23 | 24 | 25 | exp_name: crnn-2A100-768-synth-aug-nopre 26 | 27 | 28 | hydra: 29 | output_subdir: config 30 | run: 31 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 32 | sweep: 33 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | subdir: ${hydra.job.override_dirname} 35 | 36 | -------------------------------------------------------------------------------- /finetune/configs/experiment/crnn-union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: crnn 4 | - override /charset: 36_lowercase 5 | - override /dataset: Union 6 | 7 | model: 8 | ckpt_pretrained: False 9 | ckpt_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 10 | encoder_pretrained: True 11 | encoder_checkpoint: /home/sist/zuangao/output/crnn-debug/checkpoint-19.pth 12 | batch_size: 384 13 | 14 | 15 | data: 16 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 17 | num_workers: 8 18 | augment: True 19 | 20 | trainer: 21 | accelerator: gpu 22 | devices: 2 23 | 24 | 25 | exp_name: crnn-2A100-768-Union-aug-pre 26 | 27 | 28 | hydra: 29 | output_subdir: config 30 | run: 31 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 32 | sweep: 33 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | subdir: ${hydra.job.override_dirname} 35 | 36 | -------------------------------------------------------------------------------- /finetune/configs/experiment/crnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: crnn 4 | 5 | data: 6 | num_workers: 5 7 | -------------------------------------------------------------------------------- /finetune/configs/experiment/finetune-icdar-mdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: icdar24 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: True 10 | parseq_checkpoint: /home/gaoz/outputs/transocr/finetune-mdr-icdar24-384-20e/2024-03-29_12-11-50/checkpoints/last.ckpt 11 | encoder_pretrained: False 12 | encoder_checkpoint: /home/gaoz/pretrained/trans-hybird-mulemb3-prob/checkpoint-19.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.000006 16 | standard: True 17 | 18 | 19 | data: 20 | root_dir: /home/gaoz/datasets/ 21 | num_workers: 8 22 | augment: True 23 | 24 | trainer: 25 | # val_check_interval: 5 26 | max_epochs: 20 27 | accelerator: gpu 28 | devices: 2 29 | 30 | 
31 | exp_name: 6e6-finetune-mdr-icdar24-384-20e 32 | 33 | 34 | hydra: 35 | output_subdir: config 36 | run: 37 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | sweep: 39 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 40 | subdir: ${hydra.job.override_dirname} 41 | 42 | -------------------------------------------------------------------------------- /finetune/configs/experiment/finetune-icdar.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: icdar24 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: True 10 | parseq_checkpoint: /home/gaoz/outputs/transocr/finetune-dig-icdar24-384-20e/2024-03-29_12-06-35/checkpoints/epoch=19-step=31609-val_accuracy=93.5918-val_NED=97.9359.ckpt 11 | encoder_pretrained: False 12 | encoder_checkpoint: /home/gaoz/pretrained/trans-hybird-mulemb3-prob/checkpoint-19.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.000006 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/ 20 | num_workers: 8 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 5 25 | max_epochs: 20 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: 6e6-finetune-dig-icdar24-384-20e 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-base-ard.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: ard 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/mdr-pool-base-OCRCC/checkpoint-19.pth 13 | batch_size: 192 14 | enc_embed_dim: 512 15 | enc_num_heads: 8 16 | dec_depth: 6 17 | dec_embed_dim: 512 18 | dec_num_heads: 8 19 | lr: 0.0003 20 | standard: True 21 | 22 | 23 | data: 24 | root_dir: /home/gaoz/datasets/data/ 25 | num_workers: 8 26 | augment: True 27 | 28 | trainer: 29 | # val_check_interval: 8000 30 | max_epochs: 30 31 | accelerator: gpu 32 | devices: 2 33 | 34 | 35 | exp_name: mdr-pool-base-ard-384-30e-aug 36 | 37 | 38 | 39 | hydra: 40 | output_subdir: config 41 | run: 42 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 43 | sweep: 44 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 45 | subdir: ${hydra.job.override_dirname} 46 | 47 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-base-synth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: synth 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | 
encoder_checkpoint: /home/gaoz/output/mdr-pool-base-OCRCC/checkpoint-19.pth 13 | batch_size: 192 14 | enc_embed_dim: 512 15 | enc_num_heads: 8 16 | dec_depth: 6 17 | # dec_embed_dim: 512 18 | # dec_num_heads: 8 19 | lr: 0.0003 20 | standard: True 21 | 22 | 23 | data: 24 | root_dir: /home/gaoz/datasets/data/ 25 | num_workers: 8 26 | augment: True 27 | 28 | trainer: 29 | # val_check_interval: 8000 30 | max_epochs: 10 31 | accelerator: gpu 32 | devices: 4 33 | 34 | 35 | exp_name: mdr-pool-base-synth-768-10e-aug 36 | 37 | 38 | 39 | hydra: 40 | output_subdir: config 41 | run: 42 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 43 | sweep: 44 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 45 | subdir: ${hydra.job.override_dirname} 46 | 47 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-dec6-union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: False 12 | encoder_checkpoint: $pretrained_path$ 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: $finetune_data_path$ 20 | num_workers: 8 21 | augment: False 22 | 23 | trainer: 24 | # val_check_interval: 5 25 | max_epochs: 10 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: trstr-dec6-union-384-10e 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-hybrid-large.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/hybrid-mul-base-pretext-4gpus256b/checkpoint-19.pth 13 | batch_size: 96 14 | lr: 0.0003 15 | standard: True 16 | enc_embed_dim: 768 17 | enc_num_heads: 12 18 | enc_mlp_ratio: 4 19 | enc_depth: 12 20 | dec_embed_dim: 768 21 | dec_depth: 6 22 | dec_num_heads: 8 23 | 24 | data: 25 | root_dir: /home/gaoz/datasets/data/ 26 | num_workers: 8 27 | augment: True 28 | 29 | trainer: 30 | # val_check_interval: 5 31 | max_epochs: 10 32 | accelerator: gpu 33 | devices: 6 34 | 35 | 36 | exp_name: mdr-hybrid-large-union-384-10e-aug 37 | 38 | 39 | hydra: 40 | output_subdir: config 41 | run: 42 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 43 | sweep: 44 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 45 | subdir: ${hydra.job.override_dirname} 46 | 47 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-pool-base.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/mdr-pool-base-prtext/checkpoint-19.pth 13 | batch_size: 192 14 | enc_embed_dim: 512 15 | enc_num_heads: 8 16 | dec_depth: 6 17 | dec_embed_dim: 512 18 | dec_num_heads: 8 19 | lr: 0.0003 20 | standard: True 21 | 22 | 23 | data: 24 | root_dir: /home/gaoz/datasets/data/ 25 | num_workers: 8 26 | augment: True 27 | 28 | trainer: 29 | # val_check_interval: 8000 30 | max_epochs: 10 31 | accelerator: gpu 32 | devices: 2 33 | 34 | 35 | exp_name: mdr-pool-base-union-384-10e-aug 36 | 37 | 38 | 39 | hydra: 40 | output_subdir: config 41 | run: 42 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 43 | sweep: 44 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 45 | subdir: ${hydra.job.override_dirname} 46 | 47 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-pool-baug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/mdr-union-b256-baug/checkpoint-19.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data/ 20 | num_workers: 8 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 8000 25 | max_epochs: 10 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: mdr-pool-baug-union-384-10e-aug 31 | 32 | 33 | 34 | hydra: 35 | output_subdir: config 36 | run: 37 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | sweep: 39 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 40 | subdir: ${hydra.job.override_dirname} 41 | 42 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-syn-dec4.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: synth 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/pretrained/trans-hybird-mulemb3-prob/checkpoint-19.pth 13 | batch_size: 192 14 | dec_depth: 4 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data/ 20 | num_workers: 8 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 5 25 | accelerator: gpu 26 | devices: 4 27 | 28 | 29 | exp_name: mdr-syn-dec4-768-3e4 30 | 31 | 32 | hydra: 33 | output_subdir: config 34 | run: 35 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | sweep: 
37 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | subdir: ${hydra.job.override_dirname} 39 | 40 | -------------------------------------------------------------------------------- /finetune/configs/experiment/mdr-syn-dec6.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: synth 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/pretrained/trans-hybird-mulemb3-prob/checkpoint-19.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data/ 20 | num_workers: 8 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 5 25 | accelerator: gpu 26 | devices: 4 27 | 28 | 29 | exp_name: mdr-syn-dec6-768-3e4 30 | 31 | 32 | hydra: 33 | output_subdir: config 34 | run: 35 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | sweep: 37 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | subdir: ${hydra.job.override_dirname} 39 | 40 | -------------------------------------------------------------------------------- /finetune/configs/experiment/parseq-patch16-224.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | 5 | model: 6 | img_size: [ 224, 224 ] # [ height, width ] 7 | patch_size: [ 16, 16 ] # [ height, width ] 8 | -------------------------------------------------------------------------------- /finetune/configs/experiment/parseq-real.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | - override /charset: 36_lowercase 5 | - override /dataset: real 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/pretrain_encoder/trans-hybird-mulemb3-prob/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | data: 16 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 17 | num_workers: 8 18 | augment: False 19 | 20 | trainer: 21 | # max_steps: 4000 22 | #limit_train_batches: 0.3 # to match the steps per epoch of `real` 23 | #accelerator=gpu trainer.devices=2 24 | accelerator: gpu 25 | devices: 4 26 | 27 | exp_name: parseq-4x192-real-4x4-noaug-pre 28 | 29 | hydra: 30 | output_subdir: config 31 | run: 32 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 33 | sweep: 34 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 35 | subdir: ${hydra.job.override_dirname} 36 | -------------------------------------------------------------------------------- /finetune/configs/experiment/parseq-syn-mdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | - override /charset: 36_lowercase 5 | - override /dataset: synth 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: 
/home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-union-4gpu256b-emb3/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | data: 16 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 17 | num_workers: 8 18 | 19 | trainer: 20 | # max_steps: 4000 21 | #limit_train_batches: 0.3 # to match the steps per epoch of `real` 22 | #accelerator=gpu trainer.devices=2 23 | accelerator: gpu 24 | devices: 4 25 | 26 | exp_name: parseq-4A100-768-syn-4x4-mdr 27 | 28 | hydra: 29 | output_subdir: config 30 | run: 31 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 32 | sweep: 33 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | subdir: ${hydra.job.override_dirname} 35 | -------------------------------------------------------------------------------- /finetune/configs/experiment/parseq-syn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | - override /charset: 36_lowercase 5 | - override /dataset: synth 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/output/pretrain_dig-union/checkpoint-9.pth 13 | batch_size: 192 14 | 15 | data: 16 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 17 | num_workers: 8 18 | 19 | trainer: 20 | # max_steps: 4000 21 | #limit_train_batches: 0.3 # to match the steps per epoch of `real` 22 | #accelerator=gpu trainer.devices=2 23 | accelerator: gpu 24 | devices: 4 25 | 26 | exp_name: parseq-4A100-768-syn-4x4-predig 27 | 28 | hydra: 29 | output_subdir: config 30 | run: 31 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 32 | sweep: 33 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | subdir: ${hydra.job.override_dirname} 35 | -------------------------------------------------------------------------------- /finetune/configs/experiment/parseq-union-pre.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | - override /charset: 36_lowercase 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-union-4gpu256b-emb3/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 18 | num_workers: 8 19 | augment: False 20 | 21 | trainer: 22 | # max_steps: 4000 23 | max_epochs: 20 24 | gpus: 4 25 | 26 | exp_name: parseq-Union-4x192-noaug-predig-4x4 27 | 28 | hydra: 29 | output_subdir: config 30 | run: 31 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 32 | sweep: 33 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | subdir: ${hydra.job.override_dirname} 35 | 36 | -------------------------------------------------------------------------------- /finetune/configs/experiment/parseq-union.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | - override /charset: 36_lowercase 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: False 12 | encoder_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20em3/checkpoint-19.pth 13 | batch_size: 384 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 18 | num_workers: 8 19 | 20 | trainer: 21 | # max_steps: 4000 22 | max_epochs: 20 23 | gpus: 2 24 | 25 | exp_name: parseq-Union-2A100-384-nopre-4x4 26 | 27 | hydra: 28 | output_subdir: config 29 | run: 30 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 31 | sweep: 32 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 33 | subdir: ${hydra.job.override_dirname} -------------------------------------------------------------------------------- /finetune/configs/experiment/parseq.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | -------------------------------------------------------------------------------- /finetune/configs/experiment/rand-sdr18-moco-dec6-union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/pretrain_sdr_moco_20e_randaug/checkpoint-18.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | standard: False 17 | 18 | 19 | data: 20 | root_dir: /home/gaoz/datasets/data/ 21 | num_workers: 8 22 | augment: True 23 | 24 | trainer: 25 | # val_check_interval: 8000 26 | max_epochs: 10 27 | accelerator: gpu 28 | devices: 2 29 | 30 | 31 | exp_name: rand-sdr18-moco-dec6-union-384-10e-aug 32 | 33 | 34 | hydra: 35 | output_subdir: config 36 | run: 37 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | sweep: 39 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 40 | subdir: ${hydra.job.override_dirname} 41 | 42 | -------------------------------------------------------------------------------- /finetune/configs/experiment/randsdr19-moco-dec6-union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/pretrain_sdr_moco_20e_randaug/checkpoint-19.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | standard: False 17 | 18 | 19 | data: 20 | root_dir: /home/gaoz/datasets/data/ 21 | num_workers: 8 22 | augment: True 23 | 24 | trainer: 25 | # val_check_interval: 8000 26 | max_epochs: 10 27 | accelerator: gpu 28 | 
devices: 2 29 | 30 | 31 | exp_name: rand-sdr19-moco-dec6-union-384-10e-aug 32 | 33 | 34 | hydra: 35 | output_subdir: config 36 | run: 37 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | sweep: 39 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 40 | subdir: ${hydra.job.override_dirname} 41 | 42 | -------------------------------------------------------------------------------- /finetune/configs/experiment/sdr-dino-nocycle-dec6-union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/pretrain_sdr_dino_nocycle_10eb1024/checkpoint-9.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data 20 | num_workers: 10 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 8000 25 | max_epochs: 10 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: sdr-dino-cross-dec6-union-384-10e-aug 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/sdr-moco-dec6-union.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/pretrain_sdr_cross_moco_10e/checkpoint-9.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data/ 20 | num_workers: 6 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 8000 25 | max_epochs: 10 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: sdr_cross_moco-dec6-union-384-10e-aug 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/sdr-moco-dec6-union20e.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/pretrain_sdr_mocov3_nocycle_10e/checkpoint-9.pth 13 | batch_size: 
192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data/ 20 | num_workers: 8 21 | augment: True 22 | 23 | trainer: 24 | val_check_interval: 8000 25 | max_epochs: 20 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: sdr-moco-dec6-union-384-20e-aug 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/sdr-only2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/pretrain_sdr_20e/checkpoint-19.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data/ 20 | num_workers: 8 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 8000 25 | max_epochs: 10 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: sdr-only-dec6-union-768-10e-aug 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/sdr-randaug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/gaoz/output/pretrain_sdr_radaug_new_20eb2048/checkpoint-18.pth 13 | batch_size: 192 14 | dec_depth: 6 15 | lr: 0.0003 16 | 17 | 18 | data: 19 | root_dir: /home/gaoz/datasets/data/ 20 | num_workers: 8 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 5000 25 | max_epochs: 10 26 | accelerator: gpu 27 | devices: 2 28 | 29 | 30 | exp_name: new-sdr18-randaug-dec6-union-384-10e-aug 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/gaoz/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/gaoz/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/trans-cvl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: cvl 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: 
/home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-union-4gpu256b/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/HandWritten/ 18 | num_workers: 8 19 | augment: True 20 | 21 | trainer: 22 | val_check_interval: 40 23 | accelerator: gpu 24 | devices: 2 25 | 26 | 27 | exp_name: transocr-2A100-384-cvl-4x4-aug-preunion 28 | 29 | 30 | hydra: 31 | output_subdir: config 32 | run: 33 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | sweep: 35 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | subdir: ${hydra.job.override_dirname} 37 | 38 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-base-pre.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | enc_embed_dim: 768 9 | enc_num_heads: 12 10 | enc_mlp_ratio: 4 11 | enc_depth: 12 12 | dec_embed_dim: 768 13 | dec_num_heads: 8 14 | dec_mlp_ratio: 4 15 | dec_depth: 4 16 | patch_size: [ 4, 4 ] 17 | parseq_pretrained: False 18 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 19 | encoder_pretrained: True 20 | encoder_checkpoint: /home/sist/zuangao/output/hybrid-mul-base-realrec/checkpoint-19.pth 21 | batch_size: 192 22 | 23 | data: 24 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 25 | num_workers: 8 26 | augment: False 27 | 28 | trainer: 29 | # val_check_interval: 5 30 | accelerator: gpu 31 | devices: 4 32 | 33 | 34 | exp_name: transocr-vitbase-4A100-192x4-Union-4x4-noaug-94full-prerealrec 35 | 36 | 37 | hydra: 38 | output_subdir: config 39 | run: 40 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 41 | sweep: 42 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 43 | subdir: ${hydra.job.override_dirname} 44 | 45 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-dev.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: real 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/pretrain_encoder/trans-hybird-mulemb3-prob/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 18 | num_workers: 8 19 | augment: False 20 | 21 | trainer: 22 | # val_check_interval: 5 23 | accelerator: gpu 24 | devices: 4 25 | 26 | 27 | exp_name: transocr-4x192-real-4x4-noaug-prereal 28 | 29 | 30 | hydra: 31 | output_subdir: config 32 | run: 33 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | sweep: 35 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | subdir: ${hydra.job.override_dirname} 37 | 38 | 
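Each experiment file in this directory is a "# @package _global_" overlay: it swaps in the model, charset and dataset groups named in its defaults list and then patches individual keys (patch size, checkpoints, batch size, trainer settings) on top of configs/main.yaml. A minimal sketch of how such an overlay resolves, assuming Hydra's compose API is invoked from the finetune/ directory and using the transocr-dev experiment shown above, is given below; it mirrors the +experiment=<name> override that would normally be passed on the command line.

from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose main.yaml plus one experiment overlay (transocr-dev from above).
with initialize(version_base=None, config_path='configs'):
    cfg = compose(config_name='main', overrides=['+experiment=transocr-dev'])
    # The overlay has replaced the model/charset/dataset defaults and patched
    # keys such as model.patch_size, data.root_dir and trainer.devices.
    print(OmegaConf.to_yaml(cfg.model, resolve=False))
    print(cfg.trainer.devices, cfg.data.root_dir)

Printing the composed config this way is a quick check of exactly which keys an experiment file overrides before launching a full run.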
-------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-dev2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-union-4gpu256b-emb3/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 18 | num_workers: 8 19 | augment: False 20 | 21 | trainer: 22 | # val_check_interval: 5 23 | accelerator: gpu 24 | devices: 4 25 | 26 | 27 | exp_name: transocr-4A100-768-Union-4x4-noaug-preunion-emb3-94full 28 | 29 | 30 | hydra: 31 | output_subdir: config 32 | run: 33 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | sweep: 35 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | subdir: ${hydra.job.override_dirname} 37 | 38 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-dig.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: Union 6 | 7 | model: 8 | dec_num_heads: 8 9 | dec_depth: 6 10 | patch_size: [ 4, 4 ] 11 | # parseq_pretrained: False 12 | # parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 13 | encoder_pretrained: True 14 | encoder_checkpoint: /home/sist/zuangao/output/pretrain_dig-union/checkpoint-9.pth 15 | batch_size: 96 16 | 17 | 18 | data: 19 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 20 | num_workers: 4 21 | augment: True 22 | 23 | trainer: 24 | # val_check_interval: 5 25 | max_epochs: 10 26 | accelerator: gpu 27 | devices: 4 28 | 29 | 30 | exp_name: trstr-6bdec-4x192-Union-4x4-aug-digunion 31 | 32 | 33 | hydra: 34 | output_subdir: config 35 | run: 36 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | sweep: 38 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | subdir: ${hydra.job.override_dirname} 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-iam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: iam 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-union-4gpu256b/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/HandWritten/ 18 | num_workers: 8 19 | augment: True 20 | 21 | trainer: 22 | val_check_interval: 40 23 | accelerator: gpu 24 | devices: 2 25 | 26 | 27 | exp_name: transocr-2A100-384-iam-4x4-aug-preunion 28 | 29 | 30 | hydra: 31 | output_subdir: config 32 | run: 33 | 
dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | sweep: 35 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | subdir: ${hydra.job.override_dirname} 37 | 38 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-jlw.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: Union 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/pretrain_encoder/FromUnion14M/vit_small_checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 18 | num_workers: 8 19 | augment: false 20 | 21 | trainer: 22 | # val_check_interval: 5 23 | accelerator: gpu 24 | devices: 4 25 | 26 | 27 | exp_name: transocr-jlw-4x192-noaug 28 | 29 | hydra: 30 | output_subdir: config 31 | run: 32 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 33 | sweep: 34 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 35 | subdir: ${hydra.job.override_dirname} 36 | 37 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-maerec.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: synth 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: True 12 | encoder_checkpoint: /home/sist/zuangao/pretrain_encoder/FromUnion14M/vit_small_checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 18 | num_workers: 4 19 | augment: True 20 | 21 | trainer: 22 | # val_check_interval: 5 23 | max_epochs: 20 24 | accelerator: gpu 25 | devices: 4 26 | 27 | 28 | exp_name: transocr-4A100-768-synth-4x4-aug-jlwunion-emb3 29 | 30 | 31 | hydra: 32 | output_subdir: config 33 | run: 34 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 35 | sweep: 36 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | subdir: ${hydra.job.override_dirname} 38 | 39 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-mim.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 36_lowercase 5 | - override /dataset: synth 6 | 7 | model: 8 | patch_size: [ 4, 4 ] 9 | parseq_pretrained: False 10 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 11 | encoder_pretrained: False 12 | encoder_checkpoint: /home/sist/zuangao/output/mae-1024b-real-rec/checkpoint-19.pth 13 | batch_size: 192 14 | 15 | 16 | data: 17 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 18 | num_workers: 4 19 | augment: True 20 | 21 | trainer: 22 | # 
val_check_interval: 5 23 | max_epochs: 20 24 | accelerator: gpu 25 | devices: 4 26 | 27 | 28 | exp_name: transocr-4A100-768-synth-4x4-nopre 29 | 30 | 31 | hydra: 32 | output_subdir: config 33 | run: 34 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 35 | sweep: 36 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 37 | subdir: ${hydra.job.override_dirname} 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-scenezh.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 6623_chinese 5 | - override /dataset: scenezh 6 | 7 | model: 8 | img_size: [ 32, 400 ] # [ height, width ] 9 | max_label_length: 40 10 | patch_size: [ 4, 4 ] 11 | parseq_pretrained: False 12 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 13 | encoder_pretrained: True 14 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-zh-8gpu96bv2/checkpoint-19.pth 15 | batch_size: 128 16 | 17 | 18 | data: 19 | _target_: strhub.data.module.ChineseSceneTextDataModule 20 | root_dir: /home/sist/zuangao/datasets/zh_data/benchmark_dataset/ 21 | num_workers: 8 22 | augment: True 23 | normalize_unicode: False 24 | 25 | trainer: 26 | val_check_interval: 100 27 | max_epochs: 100 28 | accelerator: gpu 29 | devices: 1 30 | 31 | 32 | exp_name: transocr-1A100-128-scenezh-4x4-aug-prezh-re71 33 | 34 | hydra: 35 | output_subdir: config 36 | run: 37 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | sweep: 39 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 40 | subdir: ${hydra.job.override_dirname} 41 | 42 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-tiny.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | name: transocr-tiny 9 | enc_embed_dim: 192 10 | enc_num_heads: 3 11 | dec_embed_dim: 192 12 | dec_num_heads: 6 13 | patch_size: [ 4, 4 ] 14 | parseq_pretrained: False 15 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 16 | encoder_pretrained: False 17 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-union-4gpu256b-emb3/checkpoint-19.pth 18 | batch_size: 192 19 | 20 | 21 | data: 22 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 23 | num_workers: 8 24 | augment: False 25 | 26 | trainer: 27 | # val_check_interval: 5 28 | accelerator: gpu 29 | devices: 4 30 | 31 | 32 | exp_name: transocr-tiny-4A100-768-Union-4x4-noaug-94full 33 | 34 | 35 | hydra: 36 | output_subdir: config 37 | run: 38 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 39 | sweep: 40 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 41 | subdir: ${hydra.job.override_dirname} 42 | 43 | -------------------------------------------------------------------------------- /finetune/configs/experiment/transocr-webzh.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - 
override /charset: 6623_chinese 5 | - override /dataset: webzh 6 | 7 | model: 8 | img_size: [ 32, 400 ] # [ height, width ] 9 | max_label_length: 40 10 | patch_size: [ 4, 4 ] 11 | parseq_pretrained: False 12 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 13 | encoder_pretrained: false 14 | encoder_checkpoint: /home/sist/zuangao/output/hybird-mul-zh-8gpu96bv2/checkpoint-19.pth 15 | batch_size: 128 16 | 17 | 18 | data: 19 | _target_: strhub.data.module.ChineseSceneTextDataModule 20 | root_dir: /home/sist/zuangao/datasets/zh_data/benchmark_dataset/ 21 | num_workers: 8 22 | augment: True 23 | normalize_unicode: False 24 | 25 | trainer: 26 | val_check_interval: 100 27 | max_epochs: 100 28 | accelerator: gpu 29 | devices: 1 30 | 31 | 32 | exp_name: transocr-1A100-96-webzh-4x4-aug-noprezh 33 | 34 | hydra: 35 | output_subdir: config 36 | run: 37 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | sweep: 39 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 40 | subdir: ${hydra.job.override_dirname} 41 | 42 | -------------------------------------------------------------------------------- /finetune/configs/experiment/trba.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: trba 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | ckpt_pretrained: False 9 | ckpt_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 10 | encoder_pretrained: True 11 | encoder_checkpoint: /home/sist/zuangao/pretrain_encoder/trba/checkpoint-19.pth 12 | batch_size: 192 13 | 14 | 15 | data: 16 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 17 | num_workers: 8 18 | augment: False 19 | 20 | trainer: 21 | # val_check_interval: 5 22 | max_epochs: 10 23 | accelerator: gpu 24 | devices: 4 25 | 26 | 27 | exp_name: trba-4A100-768-Union-4x4-noaug-10e-94full-pre 28 | 29 | 30 | hydra: 31 | output_subdir: config 32 | run: 33 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 34 | sweep: 35 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | subdir: ${hydra.job.override_dirname} 37 | 38 | -------------------------------------------------------------------------------- /finetune/configs/experiment/trbc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: trba 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | name: trbc 9 | _target_: strhub.models.trba.system.TRBC 10 | lr: 1e-4 11 | ckpt_pretrained: False 12 | ckpt_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 13 | encoder_pretrained: False 14 | encoder_checkpoint: /home/sist/zuangao/pretrain_encoder/trba/checkpoint-19.pth 15 | batch_size: 96 16 | 17 | data: 18 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 19 | num_workers: 8 20 | augment: False 21 | 22 | trainer: 23 | # val_check_interval: 5 24 | max_epochs: 10 25 | accelerator: gpu 26 | devices: 8 27 | 28 | 29 | exp_name: trbc-8A100-8x96-Union-4x4-noaug-10e-94full-nopre 30 | 31 | 32 | hydra: 33 | output_subdir: config 34 | run: 35 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 36 | sweep: 37 | dir: 
/home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 38 | subdir: ${hydra.job.override_dirname} 39 | -------------------------------------------------------------------------------- /finetune/configs/experiment/trstr-mdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: transocr 4 | - override /charset: 94_full 5 | - override /dataset: Union 6 | 7 | model: 8 | enc_embed_dim: 768 9 | enc_num_heads: 12 10 | dec_embed_dim: 768 11 | dec_num_heads: 8 12 | dec_depth: 6 13 | patch_size: [ 4, 4 ] 14 | parseq_pretrained: False 15 | parseq_checkpoint: /home/gaoz/output/flip/flipae_vit_base_patch16_dec256d4b-20e/checkpoint-10.pth 16 | encoder_pretrained: True 17 | encoder_checkpoint: /home/sist/zuangao/output/hybrid-mul-base-pretext-4gpus256b/checkpoint-19.pth 18 | batch_size: 256 19 | 20 | 21 | data: 22 | root_dir: /home/sist/zuangao/datasets/unidata/data/ 23 | num_workers: 8 24 | augment: True 25 | 26 | trainer: 27 | # val_check_interval: 5 28 | max_epochs: 10 29 | accelerator: gpu 30 | devices: 8 31 | 32 | 33 | exp_name: trstr-6db10e-8A100-768-4x4-noaug-mdrbase 34 | 35 | 36 | hydra: 37 | output_subdir: config 38 | run: 39 | dir: /home/sist/zuangao/outputs/${model.name}/${exp_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 40 | sweep: 41 | dir: /home/sist/zuangao/multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 42 | subdir: ${hydra.job.override_dirname} 43 | 44 | -------------------------------------------------------------------------------- /finetune/configs/experiment/tune_abinet-lm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: abinet 4 | 5 | model: 6 | name: abinet-lm 7 | lm_only: true 8 | 9 | data: 10 | augment: false 11 | num_workers: 3 12 | 13 | tune: 14 | gpus_per_trial: 0.5 15 | lr: 16 | min: 1e-5 17 | max: 1e-3 18 | -------------------------------------------------------------------------------- /finetune/configs/experiment/vitstr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: vitstr 4 | 5 | model: 6 | img_size: [ 32, 128 ] # [ height, width ] 7 | patch_size: [ 4, 8 ] # [ height, width ] 8 | -------------------------------------------------------------------------------- /finetune/configs/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: parseq 4 | - charset: 94_full 5 | - dataset: real 6 | 7 | model: 8 | _convert_: all 9 | img_size: [ 32, 128 ] # [ height, width ] 10 | max_label_length: 25 11 | # The ordering in charset_train matters. It determines the token IDs assigned to each character. 12 | charset_train: ??? 13 | # For charset_test, ordering doesn't matter. 14 | charset_test: "0123456789abcdefghijklmnopqrstuvwxyz" 15 | batch_size: 384 16 | weight_decay: 0.0 17 | warmup_pct: 0.075 # equivalent to 1.5 epochs of warm up 18 | # standard: False 19 | 20 | data: 21 | _target_: strhub.data.module.SceneTextDataModule 22 | root_dir: data 23 | train_dir: ??? 
24 | batch_size: ${model.batch_size} 25 | img_size: ${model.img_size} 26 | charset_train: ${model.charset_train} 27 | charset_test: ${model.charset_test} 28 | max_label_length: ${model.max_label_length} 29 | remove_whitespace: true 30 | normalize_unicode: true 31 | augment: true 32 | num_workers: 2 33 | 34 | trainer: 35 | _target_: pytorch_lightning.Trainer 36 | _convert_: all 37 | val_check_interval: 1000 38 | #max_steps: 169680 # 20 epochs x 8484 steps (for batch size = 384, real data) 39 | max_epochs: 20 40 | gradient_clip_val: 20 41 | accelerator: gpu 42 | devices: 2 43 | 44 | ckpt_path: null 45 | pretrained: null 46 | 47 | hydra: 48 | output_subdir: config 49 | run: 50 | dir: outputs/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 51 | sweep: 52 | dir: multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 53 | subdir: ${hydra.job.override_dirname} 54 | -------------------------------------------------------------------------------- /finetune/configs/model/abinet.yaml: -------------------------------------------------------------------------------- 1 | name: abinet 2 | _target_: strhub.models.abinet.system.ABINet 3 | 4 | # Shared Transformer configuration 5 | d_model: 512 6 | nhead: 8 7 | d_inner: 2048 8 | activation: relu 9 | dropout: 0.1 10 | 11 | # Architecture 12 | v_backbone: transformer 13 | v_num_layers: 3 14 | v_attention: position 15 | v_attention_mode: nearest 16 | l_num_layers: 4 17 | l_use_self_attn: false 18 | 19 | # Training 20 | lr: 3.4e-4 21 | l_lr: 3e-4 22 | iter_size: 3 23 | a_loss_weight: 1. 24 | v_loss_weight: 1. 25 | l_loss_weight: 1. 26 | l_detach: true 27 | -------------------------------------------------------------------------------- /finetune/configs/model/clusterseq.yaml: -------------------------------------------------------------------------------- 1 | name: transocr 2 | _target_: strhub.models.transocr.system.ClusterSeq 3 | 4 | # Data 5 | patch_size: [ 4, 8 ] # [ height, width ] 6 | 7 | # Architecture 8 | enc_embed_dim: 384 9 | enc_num_heads: 6 10 | enc_mlp_ratio: 4 11 | enc_depth: 12 12 | dec_embed_dim: 384 13 | dec_num_heads: 12 14 | dec_mlp_ratio: 4 15 | dec_depth: 4 16 | query_type: 'learn' 17 | 18 | # Training 19 | lr: 7e-4 20 | perm_num: 6 21 | perm_forward: true 22 | perm_mirrored: true 23 | dropout: 0.1 24 | 25 | # Decoding mode (test) 26 | decode_ar: true 27 | refine_iters: 1 28 | -------------------------------------------------------------------------------- /finetune/configs/model/crnn.yaml: -------------------------------------------------------------------------------- 1 | name: crnn 2 | _target_: strhub.models.crnn.system.CRNN 3 | 4 | # Architecture 5 | hidden_size: 256 6 | leaky_relu: false 7 | 8 | # Training 9 | lr: 5.1e-4 10 | -------------------------------------------------------------------------------- /finetune/configs/model/parseq.yaml: -------------------------------------------------------------------------------- 1 | name: parseq 2 | _target_: strhub.models.parseq.system.PARSeq 3 | 4 | # Data 5 | patch_size: [ 4, 8 ] # [ height, width ] 6 | 7 | # Architecture 8 | embed_dim: 384 9 | enc_num_heads: 6 10 | enc_mlp_ratio: 4 11 | enc_depth: 12 12 | dec_num_heads: 12 13 | dec_mlp_ratio: 4 14 | dec_depth: 1 15 | 16 | # Training 17 | lr: 7e-4 18 | perm_num: 6 19 | perm_forward: true 20 | perm_mirrored: true 21 | dropout: 0.1 22 | 23 | # Decoding mode (test) 24 | decode_ar: true 25 | refine_iters: 1 26 | -------------------------------------------------------------------------------- /finetune/configs/model/transocr.yaml: 
-------------------------------------------------------------------------------- 1 | name: transocr 2 | _target_: strhub.models.transocr.system.TransOCR 3 | 4 | # Data 5 | patch_size: [ 4, 8 ] # [ height, width ] 6 | 7 | # Architecture 8 | enc_embed_dim: 384 9 | enc_num_heads: 6 10 | enc_mlp_ratio: 4 11 | enc_depth: 12 12 | dec_embed_dim: 384 13 | dec_num_heads: 12 14 | dec_mlp_ratio: 4 15 | dec_depth: 4 16 | query_type: 'learn' 17 | standard: True 18 | 19 | # Training 20 | lr: 7e-4 21 | perm_num: 6 22 | perm_forward: true 23 | perm_mirrored: true 24 | dropout: 0.1 25 | 26 | # Decoding mode (test) 27 | decode_ar: true 28 | refine_iters: 1 29 | -------------------------------------------------------------------------------- /finetune/configs/model/trba.yaml: -------------------------------------------------------------------------------- 1 | name: trba 2 | _target_: strhub.models.trba.system.TRBA 3 | 4 | # Architecture 5 | num_fiducial: 20 6 | output_channel: 512 7 | hidden_size: 256 8 | 9 | # Training 10 | lr: 6.9e-4 11 | -------------------------------------------------------------------------------- /finetune/configs/model/vitstr.yaml: -------------------------------------------------------------------------------- 1 | name: vitstr 2 | _target_: strhub.models.vitstr.system.ViTSTR 3 | 4 | # Data 5 | img_size: [ 224, 224 ] # [ height, width ] 6 | patch_size: [ 16, 16 ] # [ height, width ] 7 | 8 | # Architecture 9 | embed_dim: 384 10 | num_heads: 6 11 | 12 | # Training 13 | lr: 8.9e-4 14 | -------------------------------------------------------------------------------- /finetune/configs/tune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - main 3 | - _self_ 4 | 5 | trainer: 6 | devices: 1 # tuning with DDP is not yet supported. 
7 | 8 | tune: 9 | num_samples: 10 10 | gpus_per_trial: 1 11 | lr: 12 | min: 1e-4 13 | max: 2e-3 14 | resume_dir: null 15 | 16 | hydra: 17 | run: 18 | dir: ray_results/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 19 | -------------------------------------------------------------------------------- /finetune/hubconf.py: -------------------------------------------------------------------------------- 1 | from strhub.models.utils import create_model 2 | 3 | 4 | dependencies = ['torch', 'pytorch_lightning', 'timm'] 5 | 6 | 7 | def parseq_tiny(pretrained: bool = False, decode_ar: bool = True, refine_iters: int = 1, **kwargs): 8 | """ 9 | PARSeq tiny model (img_size=128x32, patch_size=8x4, d_model=192) 10 | @param pretrained: (bool) Use pretrained weights 11 | @param decode_ar: (bool) use AR decoding 12 | @param refine_iters: (int) number of refinement iterations to use 13 | """ 14 | return create_model('parseq-tiny', pretrained, decode_ar=decode_ar, refine_iters=refine_iters, **kwargs) 15 | 16 | 17 | def parseq(pretrained: bool = False, decode_ar: bool = True, refine_iters: int = 1, **kwargs): 18 | """ 19 | PARSeq base model (img_size=128x32, patch_size=8x4, d_model=384) 20 | @param pretrained: (bool) Use pretrained weights 21 | @param decode_ar: (bool) use AR decoding 22 | @param refine_iters: (int) number of refinement iterations to use 23 | """ 24 | return create_model('parseq', pretrained, decode_ar=decode_ar, refine_iters=refine_iters, **kwargs) 25 | 26 | 27 | def abinet(pretrained: bool = False, iter_size: int = 3, **kwargs): 28 | """ 29 | ABINet model (img_size=128x32) 30 | @param pretrained: (bool) Use pretrained weights 31 | @param iter_size: (int) number of refinement iterations to use 32 | """ 33 | return create_model('abinet', pretrained, iter_size=iter_size, **kwargs) 34 | 35 | 36 | def trba(pretrained: bool = False, **kwargs): 37 | """ 38 | TRBA model (img_size=128x32) 39 | @param pretrained: (bool) Use pretrained weights 40 | """ 41 | return create_model('trba', pretrained, **kwargs) 42 | 43 | 44 | def vitstr(pretrained: bool = False, **kwargs): 45 | """ 46 | ViTSTR small model (img_size=128x32, patch_size=8x4, d_model=384) 47 | @param pretrained: (bool) Use pretrained weights 48 | """ 49 | return create_model('vitstr', pretrained, **kwargs) 50 | 51 | 52 | def crnn(pretrained: bool = False, **kwargs): 53 | """ 54 | CRNN model (img_size=128x32) 55 | @param pretrained: (bool) Use pretrained weights 56 | """ 57 | return create_model('crnn', pretrained, **kwargs) 58 | -------------------------------------------------------------------------------- /finetune/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "strhub" 7 | version = "1.1.0" 8 | description = "Scene Text Recognition Model Hub: A collection of deep learning models for Scene Text Recognition" 9 | authors = [ 10 | {name = "Darwin Bautista", email = "baudm@users.noreply.github.com"}, 11 | ] 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | dynamic = ["optional-dependencies"] 15 | 16 | [project.urls] 17 | Homepage = "https://github.com/baudm/parseq" 18 | 19 | [tool.setuptools] 20 | packages = ["strhub"] 21 | license-files = ["NOTICE", "LICENSE", "strhub/models/*/LICENSE"] 22 | 23 | [tool.setuptools.dynamic] 24 | optional-dependencies.train = { file = ["requirements/train.txt"] } 25 | optional-dependencies.test = { file = 
["requirements/test.txt"] } 26 | optional-dependencies.bench = { file = ["requirements/bench.txt"] } 27 | optional-dependencies.tune = { file = ["requirements/tune.txt"] } 28 | -------------------------------------------------------------------------------- /finetune/read.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Scene Text Recognition Model Hub 3 | # Copyright 2022 Darwin Bautista 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | 19 | import torch 20 | 21 | from PIL import Image 22 | 23 | from strhub.data.module import SceneTextDataModule 24 | from strhub.models.utils import load_from_checkpoint, parse_model_args 25 | 26 | 27 | @torch.inference_mode() 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('checkpoint', help="Model checkpoint (or 'pretrained=')") 31 | parser.add_argument('--images', nargs='+', help='Images to read') 32 | parser.add_argument('--device', default='cuda') 33 | args, unknown = parser.parse_known_args() 34 | kwargs = parse_model_args(unknown) 35 | print(f'Additional keyword arguments: {kwargs}') 36 | 37 | model = load_from_checkpoint(args.checkpoint, **kwargs).eval().to(args.device) 38 | img_transform = SceneTextDataModule.get_transform(model.hparams.img_size) 39 | 40 | for fname in args.images: 41 | # Load image and prepare for input 42 | image = Image.open(fname).convert('RGB') 43 | image = img_transform(image).unsqueeze(0).to(args.device) 44 | 45 | p = model(image).softmax(-1) 46 | pred, p = model.tokenizer.decode(p) 47 | print(f'{fname}: {pred[0]}') 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /finetune/requirements/bench.in: -------------------------------------------------------------------------------- 1 | -c ${CONSTRAINTS} 2 | 3 | hydra-core >=1.2.0 4 | fvcore >=0.1.5.post20220512 5 | -------------------------------------------------------------------------------- /finetune/requirements/bench.txt: -------------------------------------------------------------------------------- 1 | antlr4-python3-runtime==4.9.3 2 | fvcore==0.1.5.post20221221 3 | hydra-core==1.3.2 4 | importlib-resources==5.12.0 5 | iopath==0.1.10 6 | numpy==1.24.3 7 | omegaconf==2.3.0 8 | packaging==23.1 9 | pillow==9.5.0 10 | portalocker==2.7.0 11 | pyyaml==6.0 12 | tabulate==0.9.0 13 | termcolor==2.3.0 14 | tqdm==4.65.0 15 | typing-extensions==4.6.2 16 | yacs==0.1.8 17 | zipp==3.15.0 18 | -------------------------------------------------------------------------------- /finetune/requirements/core.in: -------------------------------------------------------------------------------- 1 | -c ${CONSTRAINTS} 2 | 3 | torch >=1.10.0, <2.0.0 4 | torchvision >=0.11.0, <0.15.0 5 | timm >=0.6.5 6 | pytorch-lightning >=1.7.0, <2.0.0 # TODO: refactor code to separate model from training code. 
7 | nltk >=3.7.0 # TODO: refactor/reorganize code. This is a train/test dependency. 8 | PyYAML >=6.0.0 # TODO: can we move this to train/test? 9 | -------------------------------------------------------------------------------- /finetune/requirements/core.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | async-timeout==4.0.2 6 | attrs==23.1.0 7 | certifi==2023.5.7 8 | charset-normalizer==3.1.0 9 | click==8.0.4 10 | filelock==3.12.0 11 | frozenlist==1.3.3 12 | fsspec[http]==2023.5.0 13 | huggingface-hub==0.15.1 14 | idna==3.4 15 | joblib==1.2.0 16 | lightning-utilities==0.8.0 17 | multidict==6.0.4 18 | nltk==3.8.1 19 | numpy==1.24.3 20 | packaging==23.1 21 | pillow==9.5.0 22 | pytorch-lightning==1.9.5 23 | pyyaml==6.0 24 | regex==2023.5.5 25 | requests==2.31.0 26 | safetensors==0.3.1 27 | timm==0.9.2 28 | torch==1.13.1+cpu 29 | torchmetrics==0.11.4 30 | torchvision==0.14.1+cpu 31 | tqdm==4.65.0 32 | typing-extensions==4.6.2 33 | urllib3==2.0.2 34 | yarl==1.9.2 35 | -------------------------------------------------------------------------------- /finetune/requirements/test.in: -------------------------------------------------------------------------------- 1 | -c ${CONSTRAINTS} 2 | 3 | lmdb >=1.3.0 4 | Pillow >=9.2.0 5 | tqdm >=4.64.0 6 | -------------------------------------------------------------------------------- /finetune/requirements/test.txt: -------------------------------------------------------------------------------- 1 | lmdb==1.4.1 2 | pillow==9.5.0 3 | tqdm==4.65.0 4 | -------------------------------------------------------------------------------- /finetune/requirements/train.in: -------------------------------------------------------------------------------- 1 | -c ${CONSTRAINTS} 2 | 3 | lmdb >=1.3.0 4 | Pillow >=9.2.0 5 | imgaug >=0.4.0 6 | hydra-core >=1.2.0 7 | -------------------------------------------------------------------------------- /finetune/requirements/train.txt: -------------------------------------------------------------------------------- 1 | antlr4-python3-runtime==4.9.3 2 | contourpy==1.0.7 3 | cycler==0.11.0 4 | fonttools==4.39.4 5 | hydra-core==1.3.2 6 | imageio==2.30.0 7 | imgaug==0.4.0 8 | importlib-resources==5.12.0 9 | kiwisolver==1.4.4 10 | lazy-loader==0.2 11 | lmdb==1.4.1 12 | matplotlib==3.7.1 13 | networkx==3.1 14 | numpy==1.24.3 15 | omegaconf==2.3.0 16 | opencv-python==4.7.0.72 17 | packaging==23.1 18 | pillow==9.5.0 19 | pyparsing==3.0.9 20 | python-dateutil==2.8.2 21 | pywavelets==1.4.1 22 | pyyaml==6.0 23 | scikit-image==0.20.0 24 | scipy==1.9.1 25 | shapely==2.0.1 26 | six==1.16.0 27 | tifffile==2023.4.12 28 | zipp==3.15.0 29 | -------------------------------------------------------------------------------- /finetune/requirements/tune.in: -------------------------------------------------------------------------------- 1 | -c ${CONSTRAINTS} 2 | 3 | lmdb >=1.3.0 4 | Pillow >=9.2.0 5 | imgaug >=0.4.0 6 | hydra-core >=1.2.0 7 | ray[tune] >=1.13.0, <2.0.0 8 | ax-platform >=0.2.5.1 9 | -------------------------------------------------------------------------------- /finetune/requirements/tune.txt: -------------------------------------------------------------------------------- 1 | aiosignal==1.3.1 2 | antlr4-python3-runtime==4.9.3 3 | asttokens==2.2.1 4 | attrs==23.1.0 5 | ax-platform==0.3.2 6 | backcall==0.2.0 7 | botorch==0.8.5 8 | certifi==2023.5.7 9 | charset-normalizer==3.1.0 10 | click==8.0.4 11 | 
comm==0.1.3 12 | contourpy==1.0.7 13 | cycler==0.11.0 14 | debugpy==1.6.7 15 | decorator==5.1.1 16 | distlib==0.3.6 17 | executing==1.2.0 18 | filelock==3.12.0 19 | fonttools==4.39.4 20 | frozenlist==1.3.3 21 | gpytorch==1.10 22 | grpcio==1.43.0 23 | hydra-core==1.3.2 24 | idna==3.4 25 | imageio==2.30.0 26 | imgaug==0.4.0 27 | importlib-metadata==6.6.0 28 | importlib-resources==5.12.0 29 | ipykernel==6.23.1 30 | ipython==8.12.2 31 | ipywidgets==8.0.6 32 | jedi==0.18.2 33 | jinja2==3.1.2 34 | joblib==1.2.0 35 | jsonschema==4.17.3 36 | jupyter-client==8.2.0 37 | jupyter-core==5.3.0 38 | jupyterlab-widgets==3.0.7 39 | kiwisolver==1.4.4 40 | lazy-loader==0.2 41 | linear-operator==0.4.0 42 | lmdb==1.4.1 43 | markupsafe==2.1.2 44 | matplotlib==3.7.1 45 | matplotlib-inline==0.1.6 46 | msgpack==1.0.5 47 | multipledispatch==0.6.0 48 | nest-asyncio==1.5.6 49 | networkx==3.1 50 | numpy==1.24.3 51 | omegaconf==2.3.0 52 | opencv-python==4.7.0.72 53 | opt-einsum==3.3.0 54 | packaging==23.1 55 | pandas==2.0.2 56 | parso==0.8.3 57 | pexpect==4.8.0 58 | pickleshare==0.7.5 59 | pillow==9.5.0 60 | pkgutil-resolve-name==1.3.10 61 | platformdirs==3.5.1 62 | plotly==5.14.1 63 | prompt-toolkit==3.0.38 64 | protobuf==3.20.3 65 | psutil==5.9.5 66 | ptyprocess==0.7.0 67 | pure-eval==0.2.2 68 | pygments==2.15.1 69 | pyparsing==3.0.9 70 | pyro-api==0.1.2 71 | pyro-ppl==1.8.4 72 | pyrsistent==0.19.3 73 | python-dateutil==2.8.2 74 | pytz==2023.3 75 | pywavelets==1.4.1 76 | pyyaml==6.0 77 | pyzmq==25.1.0 78 | ray[tune]==1.13.0 79 | requests==2.31.0 80 | scikit-image==0.20.0 81 | scikit-learn==1.2.2 82 | scipy==1.9.1 83 | shapely==2.0.1 84 | six==1.16.0 85 | stack-data==0.6.2 86 | tabulate==0.9.0 87 | tenacity==8.2.2 88 | tensorboardx==2.6 89 | threadpoolctl==3.1.0 90 | tifffile==2023.4.12 91 | tornado==6.3.2 92 | tqdm==4.65.0 93 | traitlets==5.9.0 94 | typeguard==2.13.3 95 | typing-extensions==4.6.2 96 | tzdata==2023.3 97 | urllib3==2.0.2 98 | virtualenv==20.23.0 99 | wcwidth==0.2.6 100 | widgetsnbextension==4.0.7 101 | zipp==3.15.0 102 | -------------------------------------------------------------------------------- /finetune/strhub/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/__init__.py -------------------------------------------------------------------------------- /finetune/strhub/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__init__.py -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/aa_overrides.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/aa_overrides.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/aa_overrides.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/aa_overrides.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/augment.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/augment.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/augment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/augment.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/dataset.cpython-38.pyc 
-------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/module.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/module.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/module.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/module.cpython-39.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/data/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/data/aa_overrides.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Extends default ops to accept optional parameters.""" 17 | from functools import partial 18 | 19 | from timm.data.auto_augment import _LEVEL_DENOM, _randomly_negate, LEVEL_TO_ARG, NAME_TO_OP, rotate 20 | 21 | 22 | def rotate_expand(img, degrees, **kwargs): 23 | """Rotate operation with expand=True to avoid cutting off the characters""" 24 | kwargs['expand'] = True 25 | return rotate(img, degrees, **kwargs) 26 | 27 | 28 | def _level_to_arg(level, hparams, key, default): 29 | magnitude = hparams.get(key, default) 30 | level = (level / _LEVEL_DENOM) * magnitude 31 | level = _randomly_negate(level) 32 | return level, 33 | 34 | 35 | def apply(): 36 | # Overrides 37 | NAME_TO_OP.update({ 38 | 'Rotate': rotate_expand 39 | }) 40 | LEVEL_TO_ARG.update({ 41 | 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.), 42 | 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3), 43 | 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3), 44 | 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45), 45 | 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45), 46 | }) 47 | -------------------------------------------------------------------------------- /finetune/strhub/data/augment.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from functools import partial 17 | 18 | import imgaug.augmenters as iaa 19 | import numpy as np 20 | from PIL import ImageFilter, Image 21 | from timm.data import auto_augment 22 | 23 | from strhub.data import aa_overrides 24 | 25 | aa_overrides.apply() 26 | 27 | _OP_CACHE = {} 28 | 29 | 30 | def _get_op(key, factory): 31 | try: 32 | op = _OP_CACHE[key] 33 | except KeyError: 34 | op = factory() 35 | _OP_CACHE[key] = op 36 | return op 37 | 38 | 39 | def _get_param(level, img, max_dim_factor, min_level=1): 40 | max_level = max(min_level, max_dim_factor * max(img.size)) 41 | return round(min(level, max_level)) 42 | 43 | 44 | def gaussian_blur(img, radius, **__): 45 | radius = _get_param(radius, img, 0.02) 46 | key = 'gaussian_blur_' + str(radius) 47 | op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius)) 48 | return img.filter(op) 49 | 50 | 51 | def motion_blur(img, k, **__): 52 | k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values 53 | key = 'motion_blur_' + str(k) 54 | op = _get_op(key, lambda: iaa.MotionBlur(k)) 55 | return Image.fromarray(op(image=np.asarray(img))) 56 | 57 | 58 | def gaussian_noise(img, scale, **_): 59 | scale = _get_param(scale, img, 0.25) | 1 # bin to odd values 60 | key = 'gaussian_noise_' + str(scale) 61 | op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale)) 62 | return Image.fromarray(op(image=np.asarray(img))) 63 | 64 | 65 | def poisson_noise(img, lam, **_): 66 | lam = _get_param(lam, img, 0.2) | 1 # bin to odd values 67 | key = 'poisson_noise_' + str(lam) 68 | op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam)) 69 | return Image.fromarray(op(image=np.asarray(img))) 70 | 71 | 72 | def _level_to_arg(level, _hparams, max): 73 | level = max * level / auto_augment._LEVEL_DENOM 74 | return level, 75 | 76 | 77 | _RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy() 78 | _RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops 79 | _RAND_TRANSFORMS.extend([ 80 | 'GaussianBlur', 81 | # 'MotionBlur', 82 | # 'GaussianNoise', 83 | 'PoissonNoise' 84 | ]) 85 | auto_augment.LEVEL_TO_ARG.update({ 86 | 'GaussianBlur': partial(_level_to_arg, max=4), 87 | 'MotionBlur': partial(_level_to_arg, max=20), 88 | 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255), 89 | 'PoissonNoise': partial(_level_to_arg, max=40) 90 | }) 91 | auto_augment.NAME_TO_OP.update({ 92 | 'GaussianBlur': gaussian_blur, 93 | 'MotionBlur': motion_blur, 94 | 'GaussianNoise': gaussian_noise, 95 | 'PoissonNoise': poisson_noise 96 | }) 97 | 98 | 99 | def rand_augment_transform(magnitude=5, num_layers=3): 100 | # These are tuned for magnitude=5, which means that effective magnitudes are half of these values. 101 | hparams = { 102 | 'rotate_deg': 30, 103 | 'shear_x_pct': 0.9, 104 | 'shear_y_pct': 0.2, 105 | 'translate_x_pct': 0.10, 106 | 'translate_y_pct': 0.30 107 | } 108 | ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS) 109 | # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice) 110 | choice_weights = [1. 
/ len(ra_ops) for _ in range(len(ra_ops))] 111 | return auto_augment.RandAugment(ra_ops, num_layers, choice_weights) 112 | -------------------------------------------------------------------------------- /finetune/strhub/data/count_word.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import os 3 | import json 4 | 5 | ## Filter the training sets once and record the word frequencies into a JSON file 6 | 7 | # settings 8 | path_list = ['/home/sist/zuangao/datasets/unidata/data/train/real/OpenVINO/train-2', \ 9 | '/home/sist/zuangao/datasets/unidata/data/train/real/OpenVINO/train_1', \ 10 | '/home/sist/zuangao/datasets/unidata/data/train/real/OpenVINO/train_5', \ 11 | '/home/sist/zuangao/datasets/unidata/data/train/real/OpenVINO/train_f', \ 12 | '/home/sist/zuangao/datasets/unidata/data/train/real/TextOCR/train', \ 13 | '/home/sist/zuangao/datasets/unidata/data/train/real/TextOCR/val', \ 14 | '/home/sist/zuangao/datasets/unidata/data/train/real/MLT19/train',\ 15 | '/home/sist/zuangao/datasets/unidata/data/train/real/MLT19/test',\ 16 | '/home/sist/zuangao/datasets/unidata/data/train/real/MLT19/val',\ 17 | '/home/sist/zuangao/datasets/unidata/data/train/real/RCTW17/train',\ 18 | '/home/sist/zuangao/datasets/unidata/data/train/real/RCTW17/test',\ 19 | '/home/sist/zuangao/datasets/unidata/data/train/real/RCTW17/val',\ 20 | '/home/sist/zuangao/datasets/unidata/data/train/real/COCOv2.0/train',\ 21 | '/home/sist/zuangao/datasets/unidata/data/train/real/COCOv2.0/val',\ 22 | '/home/sist/zuangao/datasets/unidata/data/train/real/ArT/train',\ 23 | '/home/sist/zuangao/datasets/unidata/data/train/real/ArT/val',\ 24 | '/home/sist/zuangao/datasets/unidata/data/train/real/Uber/train',\ 25 | '/home/sist/zuangao/datasets/unidata/data/train/real/Uber/val',\ 26 | '/home/sist/zuangao/datasets/unidata/data/train/real/LSVT/train',\ 27 | '/home/sist/zuangao/datasets/unidata/data/train/real/LSVT/test',\ 28 | '/home/sist/zuangao/datasets/unidata/data/train/real/LSVT/val',\ 29 | '/home/sist/zuangao/datasets/unidata/data/train/real/ReCTS/train',\ 30 | '/home/sist/zuangao/datasets/unidata/data/train/real/ReCTS/test',\ 31 | '/home/sist/zuangao/datasets/unidata/data/train/real/ReCTS/val'] 32 | 33 | data_root = '' 34 | save_root = '/home/sist/zuangao/datasets/unidata/data/test_custom/unseen_origin/' 35 | os.makedirs(save_root,exist_ok=True) 36 | map_size=30073741824 37 | max_length = 25 38 | 39 | ## json: count how frequently each word (label) occurs 40 | record_json = os.path.join(save_root, 'wordfre_json') 41 | f = open(record_json,'w',encoding='utf-8') 42 | 43 | ## Collect length statistics first 44 | 45 | word_fre = {} 46 | 47 | ## Go through all data samples 48 | cnt = 0 49 | for data_path in path_list: 50 | data_path = os.path.join(data_root,data_path) 51 | env = lmdb.open(data_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 52 | with env.begin(write=False) as txn: 53 | nSamples = int(txn.get('num-samples'.encode())) 54 | cache = {} 55 | 56 | for index in range(nSamples): 57 | index += 1 # lmdb starts with 1 58 | label_key = 'label-%09d'.encode() % index 59 | label = txn.get(label_key).decode('utf-8') 60 | length = len(label) 61 | 62 | if label in word_fre: 63 | word_fre[label]+=1 64 | else: 65 | word_fre[label] = 1 66 | 67 | if cnt % 5000==0: 68 | print(cnt,'cur:',label, length) 69 | cnt += 1 70 | 71 | json.dump(word_fre, f) 72 | 73 | 74 | -------------------------------------------------------------------------------- /finetune/strhub/data/read_from_lmdb.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import 
os 3 | import json 4 | from PIL import Image 5 | import re 6 | import six 7 | import numpy as np 8 | 9 | import re 10 | 11 | def get_trailing_number(s): 12 | # Use a regular expression to find the number at the end of the string 13 | match = re.search(r'(\d+)$', s) 14 | return int(match.group()) if match else 0 15 | 16 | strings = ["item12", "item2", "item112", "item11"] 17 | sorted_strings = sorted(strings, key=get_trailing_number) 18 | print(sorted_strings) 19 | # Output: ['item2', 'item11', 'item12', 'item112'] 20 | 21 | 22 | r = '/home/sist/zuangao/datasets/unidata/data/test_custom/union_benchmark_split/' 23 | ps = [ ('union_benchmark_split/' + i) for i in os.listdir(r) if i.startswith('len_') ] 24 | ps = sorted(ps, key=get_trailing_number) 25 | ps = tuple(ps) 26 | print(ps) 27 | 28 | # # settings 29 | # path_list = ['len_31'] 30 | # data_root = '/home/sist/zuangao/datasets/unidata/data/test_custom/union_benchmark_split/' 31 | # save_root = '/home/sist/zuangao/datasets/unidata/data/union_benchmark_split/test_split/' 32 | 33 | # os.makedirs(save_root,exist_ok=True) 34 | # map_size=30073741824 35 | # max_length = 25 36 | 37 | # data_path = os.path.join(data_root, path_list[0]) 38 | # # data_path = '/home/sist/zuangao/datasets/unidata/data/test_custom/filter_union_from_uniontrain/lmdb' 39 | # env = lmdb.open(data_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 40 | # print(data_path) 41 | # print(env) 42 | 43 | # ct = 0 44 | # with env.begin(write=False) as txn: 45 | # nSamples = int(txn.get('num-samples'.encode())) 46 | # print('get in') 47 | # print(nSamples) 48 | # cache = {} 49 | 50 | # for index in range(nSamples): 51 | # index += 1 # lmdb starts with 1 52 | # label_key = 'label-%09d'.encode() % index 53 | # label = txn.get(label_key).decode('utf-8') 54 | 55 | # print(ct,label,len(label)) 56 | # img_key = 'image-%09d'.encode() % index 57 | # imgbuf = txn.get(img_key) 58 | 59 | # buf = six.BytesIO() 60 | # buf.write(imgbuf) 61 | # buf.seek(0) 62 | # try: 63 | # img = Image.open(buf).convert('RGB') # for color image 64 | 65 | 66 | # except IOError: 67 | # print(f'Corrupted image for {index}') 68 | # # make dummy image and dummy label for corrupted image. 
69 | # img = Image.new('RGB', (32, 128)) 70 | 71 | # label = '[dummy_label]' 72 | 73 | 74 | # img.save(os.path.join(save_root, str(index)+'.jpg')) 75 | # ct += 1 76 | # if ct==100:break -------------------------------------------------------------------------------- /finetune/strhub/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__init__.py -------------------------------------------------------------------------------- /finetune/strhub/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/__pycache__/base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__pycache__/base.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/LICENSE: -------------------------------------------------------------------------------- 1 | ABINet for non-commercial purposes 2 | 3 | Copyright (c) 2021, USTC 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. 
Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Fang, Shancheng, Hongtao, Xie, Yuxin, Wang, Zhendong, Mao, and Yongdong, Zhang. 3 | "Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition." . 4 | In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 7098-7107).2021. 5 | 6 | https://arxiv.org/abs/2103.06495 7 | 8 | All source files, except `system.py`, are based on the implementation listed below, 9 | and hence are released under the license of the original. 
10 | 11 | Source: https://github.com/FangShancheng/ABINet 12 | License: 2-clause BSD License (see included LICENSE file) 13 | """ 14 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/backbone.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/model_abinet_iter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/model_abinet_iter.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/model_alignment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/model_alignment.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/model_language.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/model_language.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/model_vision.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/model_vision.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/system.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/system.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/abinet/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import TransformerEncoderLayer, TransformerEncoder 3 | 4 | from .resnet import resnet45 5 | from .transformer import PositionalEncoding 6 | 7 | 8 | class ResTranformer(nn.Module): 9 | def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2): 10 | super().__init__() 11 | self.resnet = resnet45() 12 | self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32) 13 | encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, 14 | dim_feedforward=d_inner, dropout=dropout, activation=activation) 15 | self.transformer = TransformerEncoder(encoder_layer, backbone_ln) 16 | 17 | def forward(self, images): 18 | feature = self.resnet(images) 19 | n, c, h, w = feature.shape 20 | feature = feature.view(n, c, -1).permute(2, 0, 1) 21 | feature = self.pos_encoder(feature) 22 | feature = self.transformer(feature) 23 | feature = feature.permute(1, 2, 0).view(n, c, h, w) 24 | return feature 25 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Model(nn.Module): 6 | 7 | def __init__(self, dataset_max_length: int, null_label: int): 8 | super().__init__() 9 | self.max_length = dataset_max_length + 1 # additional stop token 10 | self.null_label = null_label 11 | 12 | def _get_length(self, logit, dim=-1): 13 | """ Greed decoder to obtain length from logit""" 14 | out = (logit.argmax(dim=-1) == self.null_label) 15 | abn = out.any(dim) 16 | out = ((out.cumsum(dim) == 1) & out).max(dim)[1] 17 | out = out + 1 # additional end token 18 | out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device)) 19 | return out 20 | 21 | @staticmethod 22 | def _get_padding_mask(length, max_length): 23 | length = length.unsqueeze(-1) 24 | grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) 25 | return grid >= length 26 | 27 | @staticmethod 28 | def _get_location_mask(sz, device=None): 29 | mask = torch.eye(sz, device=device) 30 | mask = mask.float().masked_fill(mask == 1, float('-inf')) 31 | return mask 32 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/model_abinet_iter.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .model_alignment import BaseAlignment 5 | from .model_language import BCNLanguage 6 | from .model_vision import BaseVision 7 | 8 | 9 | class ABINetIterModel(nn.Module): 10 | def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1, 11 | d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', 12 | v_loss_weight=1., v_attention='position', v_attention_mode='nearest', 13 | v_backbone='transformer', v_num_layers=2, 14 | l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False, 15 | a_loss_weight=1.): 16 | super().__init__() 17 | self.iter_size = iter_size 18 | self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode, 19 | v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers) 20 | self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout, 21 | activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight) 22 | self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight) 23 | 24 | def forward(self, images): 25 | v_res = self.vision(images) 26 | a_res = v_res 27 | all_l_res, all_a_res = [], [] 28 | for _ in range(self.iter_size): 29 | tokens = torch.softmax(a_res['logits'], dim=-1) 30 | lengths = a_res['pt_lengths'] 31 | lengths.clamp_(2, self.language.max_length) # TODO:move to langauge model 32 | l_res = self.language(tokens, lengths) 33 | all_l_res.append(l_res) 34 | a_res = self.alignment(l_res['feature'], v_res['feature']) 35 | all_a_res.append(a_res) 36 | if self.training: 37 | return all_a_res, all_l_res, v_res 38 | else: 39 | return a_res, all_l_res[-1], v_res 40 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/model_alignment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .model import Model 5 | 6 | 7 | class BaseAlignment(Model): 8 | def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0): 9 | super().__init__(dataset_max_length, null_label) 10 | self.loss_weight = loss_weight 11 | self.w_att = nn.Linear(2 * d_model, d_model) 12 | self.cls = nn.Linear(d_model, num_classes) 13 | 14 | def forward(self, l_feature, v_feature): 15 | """ 16 | Args: 17 | l_feature: (N, T, E) where T is length, N is batch size and d is dim of model 18 | v_feature: (N, T, E) shape the same as l_feature 19 | """ 20 | f = torch.cat((l_feature, v_feature), dim=2) 21 | f_att = torch.sigmoid(self.w_att(f)) 22 | output = f_att * v_feature + (1 - f_att) * l_feature 23 | 24 | logits = self.cls(output) # (N, T, C) 25 | pt_lengths = self._get_length(logits) 26 | 27 | return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight, 28 | 'name': 'alignment'} 29 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/model_language.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import TransformerDecoder 3 | 4 | from .model import Model 5 | from .transformer import PositionalEncoding, TransformerDecoderLayer 6 | 7 | 8 | class BCNLanguage(Model): 9 | def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, 
d_inner=2048, dropout=0.1, 10 | activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0, 11 | global_debug=False): 12 | super().__init__(dataset_max_length, null_label) 13 | self.detach = detach 14 | self.loss_weight = loss_weight 15 | self.proj = nn.Linear(num_classes, d_model, False) 16 | self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length) 17 | self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) 18 | decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, 19 | activation, self_attn=use_self_attn, debug=global_debug) 20 | self.model = TransformerDecoder(decoder_layer, num_layers) 21 | self.cls = nn.Linear(d_model, num_classes) 22 | 23 | def forward(self, tokens, lengths): 24 | """ 25 | Args: 26 | tokens: (N, T, C) where T is length, N is batch size and C is classes number 27 | lengths: (N,) 28 | """ 29 | if self.detach: 30 | tokens = tokens.detach() 31 | embed = self.proj(tokens) # (N, T, E) 32 | embed = embed.permute(1, 0, 2) # (T, N, E) 33 | embed = self.token_encoder(embed) # (T, N, E) 34 | padding_mask = self._get_padding_mask(lengths, self.max_length) 35 | 36 | zeros = embed.new_zeros(*embed.shape) 37 | qeury = self.pos_encoder(zeros) 38 | location_mask = self._get_location_mask(self.max_length, tokens.device) 39 | output = self.model(qeury, embed, 40 | tgt_key_padding_mask=padding_mask, 41 | memory_mask=location_mask, 42 | memory_key_padding_mask=padding_mask) # (T, N, E) 43 | output = output.permute(1, 0, 2) # (N, T, E) 44 | 45 | logits = self.cls(output) # (N, T, C) 46 | pt_lengths = self._get_length(logits) 47 | 48 | res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, 49 | 'loss_weight': self.loss_weight, 'name': 'language'} 50 | return res 51 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/model_vision.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .attention import PositionAttention, Attention 4 | from .backbone import ResTranformer 5 | from .model import Model 6 | from .resnet import resnet45 7 | 8 | 9 | class BaseVision(Model): 10 | def __init__(self, dataset_max_length, null_label, num_classes, 11 | attention='position', attention_mode='nearest', loss_weight=1.0, 12 | d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', 13 | backbone='transformer', backbone_ln=2): 14 | super().__init__(dataset_max_length, null_label) 15 | self.loss_weight = loss_weight 16 | self.out_channels = d_model 17 | 18 | if backbone == 'transformer': 19 | self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln) 20 | else: 21 | self.backbone = resnet45() 22 | 23 | if attention == 'position': 24 | self.attention = PositionAttention( 25 | max_length=self.max_length, 26 | mode=attention_mode 27 | ) 28 | elif attention == 'attention': 29 | self.attention = Attention( 30 | max_length=self.max_length, 31 | n_feature=8 * 32, 32 | ) 33 | else: 34 | raise ValueError(f'invalid attention: {attention}') 35 | 36 | self.cls = nn.Linear(self.out_channels, num_classes) 37 | 38 | def forward(self, images): 39 | features = self.backbone(images) # (N, E, H, W) 40 | attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) 41 | logits = self.cls(attn_vecs) # (N, T, C) 42 | pt_lengths = self._get_length(logits) 43 | 44 | return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, 45 | 
'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'} 46 | -------------------------------------------------------------------------------- /finetune/strhub/models/abinet/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Callable 3 | 4 | import torch.nn as nn 5 | from torchvision.models import resnet 6 | 7 | 8 | class BasicBlock(resnet.BasicBlock): 9 | 10 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, 11 | groups: int = 1, base_width: int = 64, dilation: int = 1, 12 | norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: 13 | super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer) 14 | self.conv1 = resnet.conv1x1(inplanes, planes) 15 | self.conv2 = resnet.conv3x3(planes, planes, stride) 16 | 17 | 18 | class ResNet(nn.Module): 19 | 20 | def __init__(self, block, layers): 21 | super().__init__() 22 | self.inplanes = 32 23 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, 24 | bias=False) 25 | self.bn1 = nn.BatchNorm2d(32) 26 | self.relu = nn.ReLU(inplace=True) 27 | 28 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 29 | self.layer2 = self._make_layer(block, 64, layers[1], stride=1) 30 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 31 | self.layer4 = self._make_layer(block, 256, layers[3], stride=1) 32 | self.layer5 = self._make_layer(block, 512, layers[4], stride=1) 33 | 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | m.weight.data.normal_(0, math.sqrt(2. / n)) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | m.weight.data.fill_(1) 40 | m.bias.data.zero_() 41 | 42 | def _make_layer(self, block, planes, blocks, stride=1): 43 | downsample = None 44 | if stride != 1 or self.inplanes != planes * block.expansion: 45 | downsample = nn.Sequential( 46 | nn.Conv2d(self.inplanes, planes * block.expansion, 47 | kernel_size=1, stride=stride, bias=False), 48 | nn.BatchNorm2d(planes * block.expansion), 49 | ) 50 | 51 | layers = [] 52 | layers.append(block(self.inplanes, planes, stride, downsample)) 53 | self.inplanes = planes * block.expansion 54 | for i in range(1, blocks): 55 | layers.append(block(self.inplanes, planes)) 56 | 57 | return nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | x = self.conv1(x) 61 | x = self.bn1(x) 62 | x = self.relu(x) 63 | x = self.layer1(x) 64 | x = self.layer2(x) 65 | x = self.layer3(x) 66 | x = self.layer4(x) 67 | x = self.layer5(x) 68 | return x 69 | 70 | 71 | def resnet45(): 72 | return ResNet(BasicBlock, [3, 4, 6, 6, 3]) 73 | -------------------------------------------------------------------------------- /finetune/strhub/models/clusterseq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/clusterseq/__init__.py -------------------------------------------------------------------------------- /finetune/strhub/models/clusterseq/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/clusterseq/__pycache__/__init__.cpython-37.pyc 
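Editorial note on the ABINet backbone files above (finetune/strhub/models/abinet/resnet.py and backbone.py): with a stride-1 stem and layer strides of (2, 1, 2, 1, 1), resnet45() downsamples each spatial dimension by a factor of 4, so a standard 32x128 text crop becomes an 8x32 feature map with 512 channels, which is why backbone.py builds its PositionalEncoding with max_len=8 * 32. A minimal sanity-check sketch, assuming the finetune directory is on PYTHONPATH so that strhub is importable:

import torch
from strhub.models.abinet.resnet import resnet45

feat = resnet45()(torch.randn(1, 3, 32, 128))  # stem keeps resolution; layers 1 and 3 halve it
print(feat.shape)  # expected: torch.Size([1, 512, 8, 32])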
-------------------------------------------------------------------------------- /finetune/strhub/models/clusterseq/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/clusterseq/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/clusterseq/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/clusterseq/__pycache__/modules.cpython-37.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/clusterseq/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/clusterseq/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/clusterseq/__pycache__/system.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/clusterseq/__pycache__/system.cpython-37.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/clusterseq/__pycache__/system.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/clusterseq/__pycache__/system.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/crnn/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017 Jieru Mei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /finetune/strhub/models/crnn/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Shi, Baoguang, Xiang Bai, and Cong Yao. 3 | "An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition." 4 | IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304. 5 | 6 | https://arxiv.org/abs/1507.05717 7 | 8 | All source files, except `system.py`, are based on the implementation listed below, 9 | and hence are released under the license of the original. 10 | 11 | Source: https://github.com/meijieru/crnn.pytorch 12 | License: MIT License (see included LICENSE file) 13 | """ 14 | -------------------------------------------------------------------------------- /finetune/strhub/models/crnn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/crnn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/crnn/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/crnn/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/crnn/__pycache__/system.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/crnn/__pycache__/system.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/crnn/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from strhub.models.modules import BidirectionalLSTM 4 | 5 | 6 | class CRNN(nn.Module): 7 | 8 | def __init__(self, img_h, nc, nclass, nh, leaky_relu=False): 9 | super().__init__() 10 | assert img_h % 16 == 0, 'img_h has to be a multiple of 16' 11 | 12 | ks = [3, 3, 3, 3, 3, 3, 2] 13 | ps = [1, 1, 1, 1, 1, 1, 0] 14 | ss = [1, 1, 1, 1, 1, 1, 1] 15 | nm = [64, 128, 256, 256, 512, 512, 512] 16 | 17 | cnn = nn.Sequential() 18 | 19 | def convRelu(i, batchNormalization=False): 20 | nIn = nc if i == 0 else nm[i - 1] 21 | nOut = nm[i] 22 | cnn.add_module(f'conv{i}', 23 | nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization)) 24 | if batchNormalization: 25 | cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut)) 26 | if leaky_relu: 27 | cnn.add_module(f'relu{i}', 28 | nn.LeakyReLU(0.2, inplace=True)) 29 | else: 30 | cnn.add_module(f'relu{i}', nn.ReLU(True)) 31 | 32 | convRelu(0) 33 | cnn.add_module('pooling0', nn.MaxPool2d(2, 2)) # 64x16x64 34 | convRelu(1) 35 | cnn.add_module('pooling1', nn.MaxPool2d(2, 2)) # 128x8x32 36 | convRelu(2, True) 37 | convRelu(3) 38 | cnn.add_module('pooling2', 39 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 40 | convRelu(4, True) 41 | convRelu(5) 42 | cnn.add_module('pooling3', 43 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 44 | convRelu(6, True) # 512x1x16 45 | 46 | self.cnn = cnn 47 | self.rnn = 
nn.Sequential( 48 | BidirectionalLSTM(512, nh, nh), 49 | BidirectionalLSTM(nh, nh, nclass)) 50 | 51 | def forward(self, input): 52 | # conv features 53 | conv = self.cnn(input) 54 | b, c, h, w = conv.size() 55 | assert h == 1, 'the height of conv must be 1' 56 | conv = conv.squeeze(2) 57 | conv = conv.transpose(1, 2) # [b, w, c] 58 | 59 | # rnn features 60 | output = self.rnn(conv) 61 | 62 | return output 63 | -------------------------------------------------------------------------------- /finetune/strhub/models/crnn/system.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Sequence, Optional 17 | 18 | from pytorch_lightning.utilities.types import STEP_OUTPUT 19 | from torch import Tensor 20 | 21 | from strhub.models.base import CTCSystem 22 | from strhub.models.utils import init_weights 23 | from .model import CRNN as Model 24 | 25 | 26 | class CRNN(CTCSystem): 27 | 28 | def __init__(self, charset_train: str, charset_test: str, max_label_length: int, 29 | batch_size: int, lr: float, warmup_pct: float, weight_decay: float, 30 | img_size: Sequence[int], hidden_size: int, leaky_relu: bool, **kwargs) -> None: 31 | super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) 32 | self.save_hyperparameters() 33 | self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu) 34 | self.model.apply(init_weights) 35 | 36 | def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: 37 | return self.model.forward(images) 38 | 39 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 40 | images, labels = batch 41 | loss = self.forward_logits_loss(images, labels)[1] 42 | self.log('loss', loss) 43 | return loss 44 | -------------------------------------------------------------------------------- /finetune/strhub/models/modules.py: -------------------------------------------------------------------------------- 1 | r"""Shared modules used by CRNN and TRBA""" 2 | from torch import nn 3 | 4 | 5 | class BidirectionalLSTM(nn.Module): 6 | """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py""" 7 | 8 | def __init__(self, input_size, hidden_size, output_size): 9 | super().__init__() 10 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 11 | self.linear = nn.Linear(hidden_size * 2, output_size) 12 | 13 | def forward(self, input): 14 | """ 15 | input : visual feature [batch_size x T x input_size], T = num_steps. 
16 | output : contextual feature [batch_size x T x output_size] 17 | """ 18 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 19 | output = self.linear(recurrent) # batch_size x T x output_size 20 | return output 21 | -------------------------------------------------------------------------------- /finetune/strhub/models/parseq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/parseq/__init__.py -------------------------------------------------------------------------------- /finetune/strhub/models/parseq/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/parseq/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/parseq/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/parseq/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/parseq/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/parseq/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/parseq/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/parseq/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/parseq/__pycache__/system.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/parseq/__pycache__/system.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/parseq/__pycache__/system.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/parseq/__pycache__/system.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/transocr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/transocr/__init__.py -------------------------------------------------------------------------------- /finetune/strhub/models/transocr/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/transocr/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/transocr/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/transocr/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/transocr/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/transocr/__pycache__/modules.cpython-37.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/transocr/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/transocr/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/transocr/__pycache__/system.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/transocr/__pycache__/system.cpython-37.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/transocr/__pycache__/system.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/transocr/__pycache__/system.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/trba/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee. 3 | "What is wrong with scene text recognition model comparisons? dataset and model analysis." 4 | In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019. 5 | 6 | https://arxiv.org/abs/1904.01906 7 | 8 | All source files, except `system.py`, are based on the implementation listed below, 9 | and hence are released under the license of the original. 
10 | 11 | Source: https://github.com/clovaai/deep-text-recognition-benchmark 12 | License: Apache License 2.0 (see LICENSE file in project root) 13 | """ 14 | -------------------------------------------------------------------------------- /finetune/strhub/models/trba/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/trba/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/trba/__pycache__/feature_extraction.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/trba/__pycache__/feature_extraction.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/trba/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/trba/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/trba/__pycache__/prediction.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/trba/__pycache__/prediction.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/trba/__pycache__/system.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/trba/__pycache__/system.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/trba/__pycache__/transformation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/finetune/strhub/models/trba/__pycache__/transformation.cpython-38.pyc -------------------------------------------------------------------------------- /finetune/strhub/models/trba/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from strhub.models.modules import BidirectionalLSTM 4 | from .feature_extraction import ResNet_FeatureExtractor 5 | from .prediction import Attention 6 | from .transformation import TPS_SpatialTransformerNetwork 7 | 8 | 9 | class TRBA(nn.Module): 10 | 11 | def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256, 12 | use_ctc=False): 13 | super().__init__() 14 | """ Transformation """ 15 | self.Transformation = TPS_SpatialTransformerNetwork( 16 | F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w), 17 | I_channel_num=input_channel) 18 | 19 | """ FeatureExtraction """ 20 | self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel) 21 | self.FeatureExtraction_output = output_channel 22 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # 
Transform final (imgH/16-1) -> 1 23 | 24 | """ Sequence modeling""" 25 | self.SequenceModeling = nn.Sequential( 26 | BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size), 27 | BidirectionalLSTM(hidden_size, hidden_size, hidden_size)) 28 | self.SequenceModeling_output = hidden_size 29 | 30 | """ Prediction """ 31 | if use_ctc: 32 | self.Prediction = nn.Linear(self.SequenceModeling_output, num_class) 33 | else: 34 | self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class) 35 | 36 | def forward(self, image, max_label_length, text=None): 37 | """ Transformation stage """ 38 | image = self.Transformation(image) 39 | 40 | """ Feature extraction stage """ 41 | visual_feature = self.FeatureExtraction(image) 42 | visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h] 43 | visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1] 44 | visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] 45 | 46 | """ Sequence modeling stage """ 47 | contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size] 48 | 49 | """ Prediction stage """ 50 | if isinstance(self.Prediction, Attention): 51 | prediction = self.Prediction(contextual_feature.contiguous(), text, max_label_length) 52 | else: 53 | prediction = self.Prediction(contextual_feature.contiguous()) # CTC 54 | 55 | return prediction # [b, num_steps, num_class] 56 | -------------------------------------------------------------------------------- /finetune/strhub/models/trba/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | 8 | def __init__(self, input_size, hidden_size, num_class, num_char_embeddings=256): 9 | super().__init__() 10 | self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings) 11 | self.hidden_size = hidden_size 12 | self.num_class = num_class 13 | self.generator = nn.Linear(hidden_size, num_class) 14 | self.char_embeddings = nn.Embedding(num_class, num_char_embeddings) 15 | 16 | def forward(self, batch_H, text, max_label_length=25): 17 | """ 18 | input: 19 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class] 20 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS]. 21 | output: probability distribution at each step [batch_size x num_steps x num_class] 22 | """ 23 | batch_size = batch_H.size(0) 24 | num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence. 
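        # (Editorial note, not in the original file.) The two branches below differ only in
        # where the previous character comes from: during training, the ground-truth index
        # text[:, i] is embedded at every step (teacher forcing), while at inference the loop
        # starts from the [SOS] token in text[0] and feeds back the argmax of each step's
        # logits as the next input (greedy decoding).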
25 | 26 | output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float) 27 | hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float), 28 | batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float)) 29 | 30 | if self.training: 31 | for i in range(num_steps): 32 | char_embeddings = self.char_embeddings(text[:, i]) 33 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1}) 34 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 35 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 36 | probs = self.generator(output_hiddens) 37 | 38 | else: 39 | targets = text[0].expand(batch_size) # should be fill with [SOS] token 40 | probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float) 41 | 42 | for i in range(num_steps): 43 | char_embeddings = self.char_embeddings(targets) 44 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 45 | probs_step = self.generator(hidden[0]) 46 | probs[:, i, :] = probs_step 47 | _, next_input = probs_step.max(1) 48 | targets = next_input 49 | 50 | return probs # batch_size x num_steps x num_class 51 | 52 | 53 | class AttentionCell(nn.Module): 54 | 55 | def __init__(self, input_size, hidden_size, num_embeddings): 56 | super().__init__() 57 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 58 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 59 | self.score = nn.Linear(hidden_size, 1, bias=False) 60 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 61 | self.hidden_size = hidden_size 62 | 63 | def forward(self, prev_hidden, batch_H, char_embeddings): 64 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 65 | batch_H_proj = self.i2h(batch_H) 66 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 67 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 68 | 69 | alpha = F.softmax(e, dim=1) 70 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 71 | concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding) 72 | cur_hidden = self.rnn(concat_context, prev_hidden) 73 | return cur_hidden, alpha 74 | -------------------------------------------------------------------------------- /finetune/strhub/models/vitstr/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Atienza, Rowel. "Vision Transformer for Fast and Efficient Scene Text Recognition." 3 | In International Conference on Document Analysis and Recognition (ICDAR). 2021. 4 | 5 | https://arxiv.org/abs/2105.08582 6 | 7 | All source files, except `system.py`, are based on the implementation listed below, 8 | and hence are released under the license of the original. 9 | 10 | Source: https://github.com/roatienza/deep-text-recognition-benchmark 11 | License: Apache License 2.0 (see LICENSE file in project root) 12 | """ 13 | -------------------------------------------------------------------------------- /finetune/strhub/models/vitstr/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of ViTSTR based on timm VisionTransformer. 
3 | 4 | TODO: 5 | 1) distilled deit backbone 6 | 2) base deit backbone 7 | 8 | Copyright 2021 Rowel Atienza 9 | """ 10 | 11 | from timm.models.vision_transformer import VisionTransformer 12 | 13 | 14 | class ViTSTR(VisionTransformer): 15 | """ 16 | ViTSTR is basically a ViT that uses DeiT weights. 17 | Modified head to support a sequence of characters prediction for STR. 18 | """ 19 | 20 | def forward(self, x, seqlen: int = 25): 21 | x = self.forward_features(x) 22 | x = x[:, :seqlen] 23 | 24 | # batch, seqlen, embsize 25 | b, s, e = x.size() 26 | x = x.reshape(b * s, e) 27 | x = self.head(x).view(b, s, self.num_classes) 28 | return x 29 | -------------------------------------------------------------------------------- /finetune/strhub/models/vitstr/system.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Sequence, Any, Optional 17 | 18 | import torch 19 | from pytorch_lightning.utilities.types import STEP_OUTPUT 20 | from torch import Tensor 21 | 22 | from strhub.models.base import CrossEntropySystem 23 | from strhub.models.utils import init_weights 24 | from .model import ViTSTR as Model 25 | 26 | 27 | class ViTSTR(CrossEntropySystem): 28 | 29 | def __init__(self, charset_train: str, charset_test: str, max_label_length: int, 30 | batch_size: int, lr: float, warmup_pct: float, weight_decay: float, 31 | img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int, num_heads: int, 32 | **kwargs: Any) -> None: 33 | super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) 34 | self.save_hyperparameters() 35 | self.max_label_length = max_label_length 36 | # We don't predict nor 37 | self.model = Model(img_size=img_size, patch_size=patch_size, depth=12, mlp_ratio=4, qkv_bias=True, 38 | embed_dim=embed_dim, num_heads=num_heads, num_classes=len(self.tokenizer) - 2) 39 | # Non-zero weight init for the head 40 | self.model.head.apply(init_weights) 41 | 42 | @torch.jit.ignore 43 | def no_weight_decay(self): 44 | return {'model.' + n for n in self.model.no_weight_decay()} 45 | 46 | def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: 47 | max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) 48 | logits = self.model.forward(images, max_length + 2) # +2 tokens for [GO] and [s] 49 | # Truncate to conform to other models. [GO] in ViTSTR is actually used as the padding (therefore, ignored). 50 | # First position corresponds to the class token, which is unused and ignored in the original work. 
51 | logits = logits[:, 1:] 52 | return logits 53 | 54 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 55 | images, labels = batch 56 | loss = self.forward_logits_loss(images, labels)[1] 57 | self.log('loss', loss) 58 | return loss 59 | -------------------------------------------------------------------------------- /finetune/tools/art_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | 5 | with open('train_task2_labels.json', 'r', encoding='utf8') as f: 6 | d = json.load(f) 7 | 8 | with open('gt.txt', 'w', encoding='utf8') as f: 9 | for k, v in d.items(): 10 | if len(v) != 1: 11 | print('error', v) 12 | v = v[0] 13 | if v['language'].lower() != 'latin': 14 | # print('Skipping non-Latin:', v) 15 | continue 16 | if v['illegibility']: 17 | # print('Skipping unreadable:', v) 18 | continue 19 | label = v['transcription'].strip() 20 | if not label: 21 | # print('Skipping blank label') 22 | continue 23 | if '#' in label and label != 'LocaL#3': 24 | # print('Skipping corrupted label') 25 | continue 26 | f.write('\t'.join(['train_task2_images/' + k + '.jpg', label]) + '\n') 27 | -------------------------------------------------------------------------------- /finetune/tools/case_sensitive_str_datasets_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os.path 4 | import sys 5 | from pathlib import Path 6 | 7 | d = sys.argv[1] 8 | p = Path(d) 9 | 10 | gt = [] 11 | 12 | num_samples = len(list(p.glob('label/*.txt'))) 13 | ext = 'jpg' if p.joinpath('IMG', '1.jpg').is_file() else 'png' 14 | 15 | for i in range(1, num_samples + 1): 16 | img = p.joinpath('IMG', f'{i}.{ext}') 17 | name = os.path.splitext(img.name)[0] 18 | 19 | with open(p.joinpath('label', f'{i}.txt'), 'r') as f: 20 | label = f.readline() 21 | gt.append((os.path.join('IMG', img.name), label)) 22 | 23 | with open(d + '/lmdb.txt', 'w', encoding='utf-8') as f: 24 | for line in gt: 25 | fname, label = line 26 | fname = fname.strip() 27 | label = label.strip() 28 | f.write('\t'.join([fname, label]) + '\n') 29 | -------------------------------------------------------------------------------- /finetune/tools/coco_text_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | for s in ['train', 'val']: 4 | with open('{}_words_gt.txt'.format(s), 'r', encoding='utf8') as f: 5 | d = f.readlines() 6 | 7 | with open('{}_lmdb.txt'.format(s), 'w', encoding='utf8') as f: 8 | for line in d: 9 | try: 10 | fname, label = line.split(',', maxsplit=1) 11 | except ValueError: 12 | continue 13 | fname = '{}_words/{}.jpg'.format(s, fname.strip()) 14 | label = label.strip().strip('|') 15 | f.write('\t'.join([fname, label]) + '\n') 16 | -------------------------------------------------------------------------------- /finetune/tools/create_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ 3 | import io 4 | import os 5 | 6 | import fire 7 | import lmdb 8 | import numpy as np 9 | from PIL import Image 10 | 11 | 12 | def checkImageIsValid(imageBin): 13 | if imageBin is None: 14 | return False 15 | img = Image.open(io.BytesIO(imageBin)).convert('RGB') 16 | return np.prod(img.size) > 0 17 | 18 | 19 | def 
writeCache(env, cache): 20 | with env.begin(write=True) as txn: 21 | for k, v in cache.items(): 22 | txn.put(k, v) 23 | 24 | 25 | def createDataset(inputPath, gtFile, outputPath, checkValid=True): 26 | """ 27 | Create LMDB dataset for training and evaluation. 28 | ARGS: 29 | inputPath : input folder path where starts imagePath 30 | outputPath : LMDB output path 31 | gtFile : list of image path and label 32 | checkValid : if true, check the validity of every image 33 | """ 34 | os.makedirs(outputPath, exist_ok=True) 35 | env = lmdb.open(outputPath, map_size=1099511627776) 36 | 37 | cache = {} 38 | cnt = 1 39 | 40 | with open(gtFile, 'r', encoding='utf-8') as f: 41 | data = f.readlines() 42 | 43 | nSamples = len(data) 44 | for i, line in enumerate(data): 45 | imagePath, label = line.strip().split(maxsplit=1) 46 | imagePath = os.path.join(inputPath, imagePath) 47 | with open(imagePath, 'rb') as f: 48 | imageBin = f.read() 49 | if checkValid: 50 | try: 51 | img = Image.open(io.BytesIO(imageBin)).convert('RGB') 52 | except IOError as e: 53 | with open(outputPath + '/error_image_log.txt', 'a') as log: 54 | log.write('{}-th image data occured error: {}, {}\n'.format(i, imagePath, e)) 55 | continue 56 | if np.prod(img.size) == 0: 57 | print('%s is not a valid image' % imagePath) 58 | continue 59 | 60 | imageKey = 'image-%09d'.encode() % cnt 61 | labelKey = 'label-%09d'.encode() % cnt 62 | cache[imageKey] = imageBin 63 | cache[labelKey] = label.encode() 64 | 65 | if cnt % 1000 == 0: 66 | writeCache(env, cache) 67 | cache = {} 68 | print('Written %d / %d' % (cnt, nSamples)) 69 | cnt += 1 70 | nSamples = cnt - 1 71 | cache['num-samples'.encode()] = str(nSamples).encode() 72 | writeCache(env, cache) 73 | env.close() 74 | print('Created dataset with %d samples' % nSamples) 75 | 76 | 77 | if __name__ == '__main__': 78 | fire.Fire(createDataset) 79 | -------------------------------------------------------------------------------- /finetune/tools/filter_lmdb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import io 3 | import os 4 | from argparse import ArgumentParser 5 | 6 | import numpy as np 7 | import lmdb 8 | from PIL import Image 9 | 10 | 11 | def main(): 12 | parser = ArgumentParser() 13 | parser.add_argument('inputs', nargs='+', help='Path to input LMDBs') 14 | parser.add_argument('--output', help='Path to output LMDB') 15 | parser.add_argument('--min_image_dim', type=int, default=8) 16 | args = parser.parse_args() 17 | 18 | os.makedirs(args.output, exist_ok=True) 19 | with lmdb.open(args.output, map_size=1099511627776) as env_out: 20 | in_samples = 0 21 | out_samples = 0 22 | samples_per_chunk = 1000 23 | for lmdb_in in args.inputs: 24 | with lmdb.open(lmdb_in, readonly=True, max_readers=1, lock=False) as env_in: 25 | with env_in.begin() as txn: 26 | num_samples = int(txn.get('num-samples'.encode())) 27 | in_samples += num_samples 28 | chunks = np.array_split(range(num_samples), num_samples // samples_per_chunk) 29 | for chunk in chunks: 30 | cache = {} 31 | with env_in.begin() as txn: 32 | for index in chunk: 33 | index += 1 # lmdb starts at 1 34 | image_key = f'image-{index:09d}'.encode() 35 | image_bin = txn.get(image_key) 36 | img = Image.open(io.BytesIO(image_bin)) 37 | w, h = img.size 38 | if w < args.min_image_dim or h < args.min_image_dim: 39 | print(f'Skipping: {index}, w = {w}, h = {h}') 40 | continue 41 | out_samples += 1 # increment. 
start at 1 42 | label_key = f'label-{index:09d}'.encode() 43 | out_label_key = f'label-{out_samples:09d}'.encode() 44 | out_image_key = f'image-{out_samples:09d}'.encode() 45 | cache[out_label_key] = txn.get(label_key) 46 | cache[out_image_key] = image_bin 47 | with env_out.begin(write=True) as txn: 48 | for k, v in cache.items(): 49 | txn.put(k, v) 50 | print(f'Written samples from {chunk[0]} to {chunk[-1]}') 51 | with env_out.begin(write=True) as txn: 52 | txn.put('num-samples'.encode(), str(out_samples).encode()) 53 | print(f'Written {out_samples} samples to {args.output} out of {in_samples} input samples.') 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /finetune/tools/lsvt_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import re 6 | from functools import partial 7 | 8 | import mmcv 9 | import numpy as np 10 | from PIL import Image 11 | from mmocr.utils.fileio import list_to_file 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Generate training set of LSVT ' 17 | 'by cropping box image.') 18 | parser.add_argument('root_path', help='Root dir path of LSVT') 19 | parser.add_argument( 20 | 'n_proc', default=1, type=int, help='Number of processes to run') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def process_img(args, src_image_root, dst_image_root): 26 | # Dirty hack for multiprocessing 27 | img_idx, img_info, anns = args 28 | try: 29 | src_img = Image.open(osp.join(src_image_root, 'train_full_images_0/{}.jpg'.format(img_info))) 30 | except IOError: 31 | src_img = Image.open(osp.join(src_image_root, 'train_full_images_1/{}.jpg'.format(img_info))) 32 | blacklist = ['LOFTINESS*'] 33 | whitelist = ['#Find YOUR Fun#', 'Story #', '*0#'] 34 | labels = [] 35 | for ann_idx, ann in enumerate(anns): 36 | text_label = ann['transcription'] 37 | 38 | # Ignore illegible or words with non-Latin characters 39 | if ann['illegibility'] or re.findall(r'[\u4e00-\u9fff]+', text_label) or text_label in blacklist or \ 40 | ('#' in text_label and text_label not in whitelist): 41 | continue 42 | 43 | points = np.asarray(ann['points']) 44 | x1, y1 = points.min(axis=0) 45 | x2, y2 = points.max(axis=0) 46 | 47 | dst_img = src_img.crop((x1, y1, x2, y2)) 48 | dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' 49 | dst_img_path = osp.join(dst_image_root, dst_img_name) 50 | # Preserve JPEG quality 51 | dst_img.save(dst_img_path, qtables=src_img.quantization) 52 | labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' 53 | f' {text_label}') 54 | src_img.close() 55 | return labels 56 | 57 | 58 | def convert_lsvt(root_path, 59 | dst_image_path, 60 | dst_label_filename, 61 | annotation_filename, 62 | img_start_idx=0, 63 | nproc=1): 64 | annotation_path = osp.join(root_path, annotation_filename) 65 | if not osp.exists(annotation_path): 66 | raise Exception( 67 | f'{annotation_path} not exists, please check and try again.') 68 | src_image_root = root_path 69 | 70 | # outputs 71 | dst_label_file = osp.join(root_path, dst_label_filename) 72 | dst_image_root = osp.join(root_path, dst_image_path) 73 | os.makedirs(dst_image_root, exist_ok=True) 74 | 75 | annotation = mmcv.load(annotation_path) 76 | 77 | process_img_with_path = partial( 78 | process_img, 79 | src_image_root=src_image_root, 80 | dst_image_root=dst_image_root) 81 | tasks = 
[] 82 | for img_idx, (img_info, anns) in enumerate(annotation.items()): 83 | tasks.append((img_idx + img_start_idx, img_info, anns)) 84 | labels_list = mmcv.track_parallel_progress( 85 | process_img_with_path, tasks, keep_order=True, nproc=nproc) 86 | final_labels = [] 87 | for label_list in labels_list: 88 | final_labels += label_list 89 | list_to_file(dst_label_file, final_labels) 90 | return len(annotation) 91 | 92 | 93 | def main(): 94 | args = parse_args() 95 | root_path = args.root_path 96 | print('Processing training set...') 97 | convert_lsvt( 98 | root_path=root_path, 99 | dst_image_path='image_train', 100 | dst_label_filename='train_label.txt', 101 | annotation_filename='train_full_labels.json', 102 | nproc=args.n_proc) 103 | print('Finish') 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /finetune/tools/mlt19_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | 5 | root = sys.argv[1] 6 | 7 | with open(root + '/gt.txt', 'r') as f: 8 | d = f.readlines() 9 | 10 | with open(root + '/lmdb.txt', 'w') as f: 11 | for line in d: 12 | img, script, label = line.split(',', maxsplit=2) 13 | label = label.strip() 14 | if label and script in ['Latin', 'Symbols']: 15 | f.write('\t'.join([img, label]) + '\n') 16 | -------------------------------------------------------------------------------- /finetune/train.sh: -------------------------------------------------------------------------------- 1 | export NCCL_IB_DISABLE=1 2 | CUDA_VISIBLE_DEVICES=0,1,2,3 ./train.py +experiment=mdr-dec6-union 3 | -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/blur.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/blur.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/blur.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/blur.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/camera.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/camera.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/camera.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/camera.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/geometry.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/geometry.cpython-37.pyc 
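The converter scripts above all write the same plain-text ground truth: one line per sample, with a relative image path and its label separated by whitespace (most of them use a tab), which is the format create_lmdb_dataset.py expects in its gtFile argument. A minimal sketch of that hand-off, assuming a hypothetical data/art/ directory holding the converter output and that finetune/tools is on the import path:

from create_lmdb_dataset import createDataset

# gt.txt holds lines such as "train_task2_images/123.jpg<TAB>HELLO" (illustrative sample)
createDataset(
    inputPath='data/art',        # root directory the relative image paths start from
    gtFile='data/art/gt.txt',    # whitespace-separated path/label list written by a converter
    outputPath='data/art/lmdb',  # LMDB directory, created if it does not exist
    checkValid=True,             # decode every image and skip unreadable ones
)

Since the script exposes createDataset through fire.Fire, the same arguments can also be passed positionally on the command line.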
-------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/geometry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/geometry.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/noise.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/noise.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/noise.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/noise.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/ops.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/ops.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/ops.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/pattern.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/pattern.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/pattern.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/pattern.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/process.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/process.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/process.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/process.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/warp.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/warp.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/warp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/warp.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/weather.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/weather.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/__pycache__/weather.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/__pycache__/weather.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/augmentation/camera.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import numpy as np 4 | import skimage as sk 5 | from PIL import Image, ImageOps 6 | from io import BytesIO 7 | 8 | from skimage import color 9 | ''' 10 | PIL resize (W,H) 11 | cv2 image is BGR 12 | PIL image is RGB 13 | ''' 14 | class Contrast: 15 | def __init__(self): 16 | pass 17 | 18 | def __call__(self, img, mag=-1, prob=1.): 19 | if np.random.uniform(0,1) > prob: 20 | return img 21 | 22 | #c = [0.4, .3, .2, .1, .05] 23 | c = [0.4, .3, .2] 24 | if mag<0 or mag>=len(c): 25 | index = np.random.randint(0, len(c)) 26 | else: 27 | index = mag 28 | c = c[index] 29 | img = np.array(img) / 255. 30 | means = np.mean(img, axis=(0, 1), keepdims=True) 31 | img = np.clip((img - means) * c + means, 0, 1) * 255 32 | 33 | return Image.fromarray(img.astype(np.uint8)) 34 | 35 | 36 | class Brightness: 37 | def __init__(self): 38 | pass 39 | 40 | def __call__(self, img, mag=-1, prob=1.): 41 | if np.random.uniform(0,1) > prob: 42 | return img 43 | 44 | #W, H = img.size 45 | #c = [.1, .2, .3, .4, .5] 46 | c = [.1, .2, .3] 47 | if mag<0 or mag>=len(c): 48 | index = np.random.randint(0, len(c)) 49 | else: 50 | index = mag 51 | c = c[index] 52 | 53 | n_channels = len(img.getbands()) 54 | isgray = n_channels == 1 55 | 56 | img = np.array(img) / 255. 
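# Note: single-channel (grayscale) inputs are replicated to three channels below so that
# skimage's rgb2hsv conversion can be applied; the brightness offset is added on the V
# channel and grayscale inputs are converted back at the end of the call.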
57 | if isgray: 58 | img = np.expand_dims(img, axis=2) 59 | img = np.repeat(img, 3, axis=2) 60 | 61 | img = sk.color.rgb2hsv(img) 62 | img[:, :, 2] = np.clip(img[:, :, 2] + c, 0, 1) 63 | img = sk.color.hsv2rgb(img) 64 | 65 | #if isgray: 66 | # img = img[:,:,0] 67 | # img = np.squeeze(img) 68 | 69 | img = np.clip(img, 0, 1) * 255 70 | img = Image.fromarray(img.astype(np.uint8)) 71 | if isgray: 72 | img = ImageOps.grayscale(img) 73 | 74 | return img 75 | #if isgray: 76 | #if isgray: 77 | # img = color.rgb2gray(img) 78 | 79 | #return Image.fromarray(img.astype(np.uint8)) 80 | 81 | 82 | class JpegCompression: 83 | def __init__(self): 84 | pass 85 | 86 | def __call__(self, img, mag=-1, prob=1.): 87 | if np.random.uniform(0,1) > prob: 88 | return img 89 | 90 | #c = [25, 18, 15, 10, 7] 91 | c = [25, 18, 15] 92 | if mag<0 or mag>=len(c): 93 | index = np.random.randint(0, len(c)) 94 | else: 95 | index = mag 96 | c = c[index] 97 | output = BytesIO() 98 | img.save(output, 'JPEG', quality=c) 99 | return Image.open(output) 100 | 101 | 102 | class Pixelate: 103 | def __init__(self): 104 | pass 105 | 106 | def __call__(self, img, mag=-1, prob=1.): 107 | if np.random.uniform(0,1) > prob: 108 | return img 109 | 110 | W, H = img.size 111 | #c = [0.6, 0.5, 0.4, 0.3, 0.25] 112 | c = [0.6, 0.5, 0.4] 113 | if mag<0 or mag>=len(c): 114 | index = np.random.randint(0, len(c)) 115 | else: 116 | index = mag 117 | c = c[index] 118 | img = img.resize((int(W* c), int(H * c)), Image.BOX) 119 | return img.resize((W, H), Image.BOX) 120 | 121 | -------------------------------------------------------------------------------- /pretrain/augmentation/frost/frost1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/frost/frost1.png -------------------------------------------------------------------------------- /pretrain/augmentation/frost/frost2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/frost/frost2.png -------------------------------------------------------------------------------- /pretrain/augmentation/frost/frost3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/frost/frost3.png -------------------------------------------------------------------------------- /pretrain/augmentation/frost/frost4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/frost/frost4.jpg -------------------------------------------------------------------------------- /pretrain/augmentation/frost/frost5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/frost/frost5.jpg -------------------------------------------------------------------------------- /pretrain/augmentation/frost/frost6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/frost/frost6.jpg 
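The classes in camera.py (like those in the sibling noise.py, pattern.py and process.py modules) all follow the same callable protocol: construct with no arguments, then call with a PIL image plus optional mag (severity index, randomized when negative) and prob (probability of applying the op at all), so they compose by simple chaining. A short sketch, assuming word_crop.png is a hypothetical text-line crop and that pretrain/augmentation is on the import path:

from PIL import Image

from camera import Brightness, Contrast, JpegCompression, Pixelate

img = Image.open('word_crop.png').convert('RGB').resize((100, 32))

aug = Contrast()(img, mag=1, prob=0.5)
aug = Brightness()(aug, mag=0, prob=0.5)
aug = JpegCompression()(aug, mag=2, prob=0.5)
aug = Pixelate()(aug, mag=-1, prob=0.5)  # mag < 0 picks a random severity level
aug.save('word_crop_aug.png')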
-------------------------------------------------------------------------------- /pretrain/augmentation/images/delivery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/images/delivery.png -------------------------------------------------------------------------------- /pretrain/augmentation/images/education.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/images/education.png -------------------------------------------------------------------------------- /pretrain/augmentation/images/manila.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/images/manila.png -------------------------------------------------------------------------------- /pretrain/augmentation/images/nokia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/images/nokia.png -------------------------------------------------------------------------------- /pretrain/augmentation/images/telekom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/augmentation/images/telekom.png -------------------------------------------------------------------------------- /pretrain/augmentation/noise.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import skimage as sk 4 | from PIL import Image 5 | 6 | ''' 7 | PIL resize (W,H) 8 | ''' 9 | class GaussianNoise: 10 | def __init__(self): 11 | pass 12 | 13 | def __call__(self, img, mag=-1, prob=1.): 14 | if np.random.uniform(0,1) > prob: 15 | return img 16 | 17 | W, H = img.size 18 | #c = np.random.uniform(.08, .38) 19 | b = [.08, 0.1, 0.12] 20 | if mag<0 or mag>=len(b): 21 | index = 0 22 | else: 23 | index = mag 24 | a = b[index] 25 | c = np.random.uniform(a, a+0.03) 26 | img = np.array(img) / 255. 27 | img = np.clip(img + np.random.normal(size=img.shape, scale=c), 0, 1) * 255 28 | return Image.fromarray(img.astype(np.uint8)) 29 | 30 | 31 | class ShotNoise: 32 | def __init__(self): 33 | pass 34 | 35 | def __call__(self, img, mag=-1, prob=1.): 36 | if np.random.uniform(0,1) > prob: 37 | return img 38 | 39 | W, H = img.size 40 | #c = np.random.uniform(3, 60) 41 | b = [13, 8, 3] 42 | if mag<0 or mag>=len(b): 43 | index = 2 44 | else: 45 | index = mag 46 | a = b[index] 47 | c = np.random.uniform(a, a+7) 48 | img = np.array(img) / 255. 
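# Note: the Poisson draw below uses the scaled intensities (img * c) as per-pixel event
# rates, which models photon/shot noise; dividing by c brings the sample back to the
# original [0, 1] scale before clipping and rescaling to 0-255. Smaller c means noisier output.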
49 | img = np.clip(np.random.poisson(img * c) / float(c), 0, 1) * 255 50 | return Image.fromarray(img.astype(np.uint8)) 51 | 52 | 53 | class ImpulseNoise: 54 | def __init__(self): 55 | pass 56 | 57 | def __call__(self, img, mag=-1, prob=1.): 58 | if np.random.uniform(0,1) > prob: 59 | return img 60 | 61 | W, H = img.size 62 | #c = np.random.uniform(.03, .27) 63 | b = [.03, .07, .11] 64 | if mag<0 or mag>=len(b): 65 | index = 0 66 | else: 67 | index = mag 68 | a = b[index] 69 | c = np.random.uniform(a, a+.04) 70 | img = sk.util.random_noise(np.array(img) / 255., mode='s&p', amount=c) * 255 71 | return Image.fromarray(img.astype(np.uint8)) 72 | 73 | 74 | class SpeckleNoise: 75 | def __init__(self): 76 | pass 77 | 78 | def __call__(self, img, mag=-1, prob=1.): 79 | if np.random.uniform(0,1) > prob: 80 | return img 81 | 82 | W, H = img.size 83 | # c = np.random.uniform(.15, .6) 84 | b = [.15, .2, .25] 85 | if mag<0 or mag>=len(b): 86 | index = 0 87 | else: 88 | index = mag 89 | a = b[index] 90 | c = np.random.uniform(a, a+.05) 91 | img = np.array(img) / 255. 92 | img = np.clip(img + img * np.random.normal(size=img.shape, scale=c), 0, 1) * 255 93 | return Image.fromarray(img.astype(np.uint8)) 94 | 95 | -------------------------------------------------------------------------------- /pretrain/augmentation/ops.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import numpy as np 4 | # from wand.image import Image as WandImage 5 | from scipy.ndimage import zoom as scizoom 6 | # from wand.api import library as wandlibrary 7 | 8 | # class MotionImage(WandImage): 9 | # def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0): 10 | # wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle) 11 | 12 | def clipped_zoom(img, zoom_factor): 13 | h = img.shape[1] 14 | # ceil crop height(= crop width) 15 | ch = int(np.ceil(h / float(zoom_factor))) 16 | 17 | top = (h - ch) // 2 18 | img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1) 19 | # trim off any extra pixels 20 | trim_top = (img.shape[0] - h) // 2 21 | 22 | return img[trim_top:trim_top + h, trim_top:trim_top + h] 23 | 24 | def disk(radius, alias_blur=0.1, dtype=np.float32): 25 | if radius <= 8: 26 | L = np.arange(-8, 8 + 1) 27 | ksize = (3, 3) 28 | else: 29 | L = np.arange(-radius, radius + 1) 30 | ksize = (5, 5) 31 | X, Y = np.meshgrid(L, L) 32 | aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype) 33 | aliased_disk /= np.sum(aliased_disk) 34 | 35 | # supersample disk to antialias 36 | return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) 37 | 38 | # modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py 39 | def plasma_fractal(mapsize=256, wibbledecay=3): 40 | """ 41 | Generate a heightmap using diamond-square algorithm. 42 | Return square 2d array, side length 'mapsize', of floats in range 0-255. 43 | 'mapsize' must be a power of two. 
44 | """ 45 | assert (mapsize & (mapsize - 1) == 0) 46 | maparray = np.empty((mapsize, mapsize), dtype=np.float_) 47 | maparray[0, 0] = 0 48 | stepsize = mapsize 49 | wibble = 100 50 | 51 | def wibbledmean(array): 52 | return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape) 53 | 54 | def fillsquares(): 55 | """For each square of points stepsize apart, 56 | calculate middle value as mean of points + wibble""" 57 | cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize] 58 | squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0) 59 | squareaccum += np.roll(squareaccum, shift=-1, axis=1) 60 | maparray[stepsize // 2:mapsize:stepsize, 61 | stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum) 62 | 63 | def filldiamonds(): 64 | """For each diamond of points stepsize apart, 65 | calculate middle value as mean of points + wibble""" 66 | mapsize = maparray.shape[0] 67 | drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize] 68 | ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize] 69 | ldrsum = drgrid + np.roll(drgrid, 1, axis=0) 70 | lulsum = ulgrid + np.roll(ulgrid, -1, axis=1) 71 | ltsum = ldrsum + lulsum 72 | maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum) 73 | tdrsum = drgrid + np.roll(drgrid, 1, axis=1) 74 | tulsum = ulgrid + np.roll(ulgrid, -1, axis=0) 75 | ttsum = tdrsum + tulsum 76 | maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum) 77 | 78 | while stepsize >= 2: 79 | fillsquares() 80 | filldiamonds() 81 | stepsize //= 2 82 | wibble /= wibbledecay 83 | 84 | maparray -= maparray.min() 85 | return maparray / maparray.max() 86 | 87 | 88 | -------------------------------------------------------------------------------- /pretrain/augmentation/pattern.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import numpy as np 4 | from PIL import Image, ImageOps, ImageDraw 5 | 6 | ''' 7 | PIL resize (W,H) 8 | Torch resize is (H,W) 9 | ''' 10 | class VGrid: 11 | def __init__(self): 12 | pass 13 | 14 | def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.): 15 | if np.random.uniform(0,1) > prob: 16 | return img 17 | 18 | if copy: 19 | img = img.copy() 20 | W, H = img.size 21 | 22 | if mag<0 or mag>max_width: 23 | line_width = np.random.randint(1, max_width) 24 | image_stripe = np.random.randint(1, max_width) 25 | else: 26 | line_width = 1 27 | image_stripe = 3 - mag 28 | 29 | n_lines = W // (line_width + image_stripe) + 1 30 | draw = ImageDraw.Draw(img) 31 | for i in range(1, n_lines): 32 | x = image_stripe*i + line_width*(i-1) 33 | draw.line([(x,0), (x,H)], width=line_width, fill='black') 34 | 35 | return img 36 | 37 | class HGrid: 38 | def __init__(self): 39 | pass 40 | 41 | def __call__(self, img, copy=True, max_width=4, mag=-1, prob=1.): 42 | if np.random.uniform(0,1) > prob: 43 | return img 44 | 45 | if copy: 46 | img = img.copy() 47 | W, H = img.size 48 | if mag<0 or mag>max_width: 49 | line_width = np.random.randint(1, max_width) 50 | image_stripe = np.random.randint(1, max_width) 51 | else: 52 | line_width = 1 53 | image_stripe = 3 - mag 54 | 55 | n_lines = H // (line_width + image_stripe) + 1 56 | draw = ImageDraw.Draw(img) 57 | for i in range(1, n_lines): 58 | y = image_stripe*i + line_width*(i-1) 59 | draw.line([(0,y), (W, y)], width=line_width, fill='black') 60 | 61 | return img 62 | 63 | class Grid: 64 | def __init__(self): 65 | pass 66 | 67 | def __call__(self, img, mag=-1, prob=1.): 68 | if 
np.random.uniform(0,1) > prob: 69 | return img 70 | 71 | img = VGrid()(img, copy=True, mag=mag) 72 | img = HGrid()(img, copy=False, mag=mag) 73 | return img 74 | 75 | class RectGrid: 76 | def __init__(self): 77 | pass 78 | 79 | def __call__(self, img, isellipse=False, mag=-1, prob=1.): 80 | if np.random.uniform(0,1) > prob: 81 | return img 82 | 83 | img = img.copy() 84 | W, H = img.size 85 | line_width = 1 86 | image_stripe = 3 - mag #np.random.randint(2, 6) 87 | offset = 4 if isellipse else 1 88 | n_lines = ((H//2) // (line_width + image_stripe)) + offset 89 | draw = ImageDraw.Draw(img) 90 | x_center = W // 2 91 | y_center = H // 2 92 | for i in range(1, n_lines): 93 | dx = image_stripe*i + line_width*(i-1) 94 | dy = image_stripe*i + line_width*(i-1) 95 | x1 = x_center - (dx * W//H) 96 | y1 = y_center - dy 97 | x2 = x_center + (dx * W/H) 98 | y2 = y_center + dy 99 | if isellipse: 100 | draw.ellipse([(x1,y1), (x2, y2)], width=line_width, outline='black') 101 | else: 102 | draw.rectangle([(x1,y1), (x2, y2)], width=line_width, outline='black') 103 | 104 | return img 105 | 106 | class EllipseGrid: 107 | def __init__(self): 108 | pass 109 | 110 | def __call__(self, img, mag=-1, prob=1.): 111 | if np.random.uniform(0,1) > prob: 112 | return img 113 | 114 | img = RectGrid()(img, isellipse=True, mag=mag, prob=prob) 115 | return img 116 | -------------------------------------------------------------------------------- /pretrain/augmentation/process.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | import PIL.ImageOps, PIL.ImageEnhance 4 | import numpy as np 5 | 6 | class Posterize: 7 | def __init__(self): 8 | pass 9 | 10 | def __call__(self, img, mag=-1, prob=1.): 11 | if np.random.uniform(0,1) > prob: 12 | return img 13 | 14 | c = [1, 3, 6] 15 | if mag<0 or mag>=len(c): 16 | index = np.random.randint(0, len(c)) 17 | else: 18 | index = mag 19 | c = c[index] 20 | bit = np.random.randint(c, c+2) 21 | img = PIL.ImageOps.posterize(img, bit) 22 | 23 | return img 24 | 25 | 26 | class Solarize: 27 | def __init__(self): 28 | pass 29 | 30 | def __call__(self, img, mag=-1, prob=1.): 31 | if np.random.uniform(0,1) > prob: 32 | return img 33 | 34 | c = [64, 128, 192] 35 | if mag<0 or mag>=len(c): 36 | index = np.random.randint(0, len(c)) 37 | else: 38 | index = mag 39 | c = c[index] 40 | thresh = np.random.randint(c, c+64) 41 | img = PIL.ImageOps.solarize(img, thresh) 42 | 43 | return img 44 | 45 | class Invert: 46 | def __init__(self): 47 | pass 48 | 49 | def __call__(self, img, mag=-1, prob=1.): 50 | if np.random.uniform(0,1) > prob: 51 | return img 52 | 53 | img = PIL.ImageOps.invert(img) 54 | 55 | return img 56 | 57 | 58 | class Equalize: 59 | def __init__(self): 60 | pass 61 | 62 | def __call__(self, img, mag=-1, prob=1.): 63 | if np.random.uniform(0,1) > prob: 64 | return img 65 | 66 | mg = PIL.ImageOps.equalize(img) 67 | 68 | return img 69 | 70 | 71 | class AutoContrast: 72 | def __init__(self): 73 | pass 74 | 75 | def __call__(self, img, mag=-1, prob=1.): 76 | if np.random.uniform(0,1) > prob: 77 | return img 78 | 79 | mg = PIL.ImageOps.autocontrast(img) 80 | 81 | return img 82 | 83 | 84 | class Sharpness: 85 | def __init__(self): 86 | pass 87 | 88 | def __call__(self, img, mag=-1, prob=1.): 89 | if np.random.uniform(0,1) > prob: 90 | return img 91 | 92 | c = [.1, .7, 1.3] 93 | if mag<0 or mag>=len(c): 94 | index = np.random.randint(0, len(c)) 95 | else: 96 | index = mag 97 | c = c[index] 98 | magnitude = np.random.uniform(c, c+.6) 
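# Note: for PIL's ImageEnhance.Sharpness a factor of 1.0 returns the image unchanged,
# values below 1.0 blur and values above 1.0 sharpen, so the magnitudes sampled here
# (roughly 0.1 to 1.9 across the three levels) cover both softening and sharpening.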
99 | img = PIL.ImageEnhance.Sharpness(img).enhance(magnitude) 100 | 101 | return img 102 | 103 | 104 | class Color: 105 | def __init__(self): 106 | pass 107 | 108 | def __call__(self, img, mag=-1, prob=1.): 109 | if np.random.uniform(0,1) > prob: 110 | return img 111 | 112 | c = [.1, .7, 1.3] 113 | if mag<0 or mag>=len(c): 114 | index = np.random.randint(0, len(c)) 115 | else: 116 | index = mag 117 | c = c[index] 118 | magnitude = np.random.uniform(c, c+.6) 119 | img = PIL.ImageEnhance.Color(img).enhance(magnitude) 120 | 121 | return img 122 | 123 | 124 | -------------------------------------------------------------------------------- /pretrain/augmentation/test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | from warp import Curve, Distort, Stretch 5 | from geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY 6 | from pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid 7 | from noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise 8 | from blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur 9 | from camera import Contrast, Brightness, JpegCompression, Pixelate 10 | from weather import Fog, Snow, Frost, Rain, Shadow 11 | from process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color 12 | 13 | from PIL import Image 14 | import PIL.ImageOps 15 | import numpy as np 16 | import argparse 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--image', default="images/delivery.png", help='Load image file') 22 | parser.add_argument('--results', default="results", help='Load image file') 23 | parser.add_argument('--gray', action='store_true', help='Convert to grayscale 1st') 24 | opt = parser.parse_args() 25 | os.makedirs(opt.results, exist_ok=True) 26 | 27 | img = Image.open(opt.image) 28 | img = img.resize( (100,32) ) 29 | ops = [Curve(), Rotate(), Perspective(), Distort(), Stretch(), Shrink(), TranslateX(), TranslateY(), VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()] 30 | ops.extend([GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()]) 31 | ops.extend([GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()]) 32 | ops.extend([Contrast(), Brightness(), JpegCompression(), Pixelate()]) 33 | ops.extend([Fog(), Snow(), Frost(), Rain(), Shadow()]) 34 | ops.extend([Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()]) 35 | for op in ops: 36 | for mag in range(3): 37 | filename = type(op).__name__ + "-" + str(mag) + ".png" 38 | out_img = op(img, mag=mag) 39 | if opt.gray: 40 | out_img = PIL.ImageOps.grayscale(out_img) 41 | out_img.save(os.path.join(opt.results, filename)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /pretrain/load_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # /home/gaoz/output/dig-cc/ 3 | path = '/home/gaoz/output/dig-cc/checkpoint-9.pth' 4 | checkpoint = torch.load(path, map_location='cpu') 5 | pretrained_dict = checkpoint['model'] 6 | # print(pretrained_dict.keys()) 7 | 8 | print('-' * 100) 9 | 10 | path = '/home/gaoz/output_pixel/SR-100e-noaug_norm/checkpoint-167.pth' 11 | model_checkpoint = torch.load(path, map_location='cpu') 12 | model_dict = model_checkpoint['model'] 13 | # print(model_dict.keys()) 14 | 15 | 16 | pretrained_dict_new = {k[8:]: v for k, v in pretrained_dict.items() \ 17 | if 
k.startswith('encoder') and (k[8:] in model_dict)}#filter out unnecessary keys and (k[8:] != '.pos_embed') 18 | 19 | for k, v in pretrained_dict.items(): 20 | if k.startswith('encoder'): 21 | print(k[8:]) 22 | 23 | 24 | 25 | print(len(pretrained_dict_new.keys())) -------------------------------------------------------------------------------- /pretrain/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | from timm.models.vision_transformer import VisionTransformer, PatchEmbed 19 | 20 | class Encoder(VisionTransformer): 21 | 22 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 23 | qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed): 24 | super().__init__(img_size, patch_size, in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads, 25 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 26 | drop_path_rate=drop_path_rate, embed_layer=embed_layer, 27 | num_classes=0, global_pool='', class_token=False) # these disable the classifier head 28 | 29 | def forward(self, x): 30 | # Return all tokens 31 | return self.forward_features(x) 32 | 33 | 34 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 35 | """ Vision Transformer with support for global average pooling 36 | """ 37 | def __init__(self, global_pool=False, **kwargs): 38 | super(VisionTransformer, self).__init__(**kwargs) 39 | 40 | self.global_pool = global_pool 41 | if self.global_pool: 42 | norm_layer = kwargs['norm_layer'] 43 | embed_dim = kwargs['embed_dim'] 44 | self.fc_norm = norm_layer(embed_dim) 45 | 46 | del self.norm # remove the original norm 47 | 48 | def forward_features(self, x): 49 | B = x.shape[0] 50 | x = self.patch_embed(x) 51 | 52 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 53 | x = torch.cat((cls_tokens, x), dim=1) 54 | x = x + self.pos_embed 55 | x = self.pos_drop(x) 56 | 57 | for blk in self.blocks: 58 | x = blk(x) 59 | 60 | if self.global_pool: 61 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 62 | outcome = self.fc_norm(x) 63 | else: 64 | x = self.norm(x) 65 | outcome = x[:, 0] 66 | 67 | return outcome 68 | 69 | 70 | def vit_base_patch16(**kwargs): 71 | model = VisionTransformer( 72 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return model 75 | 76 | 77 | def vit_large_patch16(**kwargs): 78 | model = VisionTransformer( 79 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 80 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 81 | return model 82 | 83 | 84 | def vit_huge_patch14(**kwargs): 85 | model = VisionTransformer( 86 | patch_size=14, embed_dim=1280, depth=32, 
num_heads=16, mlp_ratio=4, qkv_bias=True, 87 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 88 | return model -------------------------------------------------------------------------------- /pretrain/requirements.txt: -------------------------------------------------------------------------------- 1 | validators==0.20.0 2 | timm==0.4.12 3 | lmdb==1.2.1 4 | pillow==8.1.0 5 | nltk==3.6.2 6 | natsort==7.1.1 7 | opencv-python==4.5.1.48 8 | opencv-contrib-python==4.5.1.48 9 | wand==0.6.7 10 | transformers==4.2.1 11 | strsimpy==0.2.1 -------------------------------------------------------------------------------- /pretrain/scripts/.ipynb_checkpoints/encoder-pre4-checkpoint.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 --master_port 29055 flip_pretrain.py \ 2 | --lr 1e-3 \ 3 | --batch_size 512 \ 4 | --mode single \ 5 | --model flipae_vit_small_str \ 6 | --epochs 20 \ 7 | --warmup_epochs 2 \ 8 | --output_dir /home/gaoz/output/flip/base-flip-sb2048-debug \ 9 | --data_path /home/gaoz/datasets/real-rec/ 10 | 11 | 12 | CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 --master_port 29055 flip_pretrain.py \ 13 | --lr 1e-3 \ 14 | --batch_size 1024 \ 15 | --mode single \ 16 | --model flipae_vit_small_str \ 17 | --epochs 20 \ 18 | --warmup_epochs 2 \ 19 | --output_dir /root/autodl-tmp/output/base-flip/base-hflip-2048 \ 20 | --data_path /root/autodl-tmp/unidata/real-rec/ 21 | --direction HF --num_workers 10 22 | 23 | 24 | CUDA_VISIBLE_DEVICES=0,1 python3 -m tor 25 | ch.distributed.launch --nproc_per_node=2 --nnodes=1 --master_port 29055 flip_pretrain.py \ 26 | --lr 1e-3 \ 27 | --batch_size 1024 \ 28 | --mode single \ 29 | --model flipae_vit_small_str \ 30 | --epochs 20 \ 31 | --warmup_epochs 2 \ 32 | --output_dir /root/autodl-tmp/output/base-flip/base-hflip-2048 \ 33 | --data_path /root/autodl-tmp/unidata/real-rec/ \ 34 | --direction HF --num_workers 10 35 | 36 | 37 | CUDA_VISIBLE_DEVICES=2,3 python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --master_port 29050 flip_pretrain.py \ 38 | --lr 1e-3 \ 39 | --batch_size 1024 \ 40 | --mode single \ 41 | --model flipae_vit_small_str \ 42 | --epochs 20 \ 43 | --warmup_epochs 2 \ 44 | --output_dir /root/autodl-tmp/output/base-flip/base-vflip-2048 \ 45 | --data_path /root/autodl-tmp/unidata/real-rec/ \ 46 | --direction VF --num_workers 10 47 | 48 | 49 | CUDA_VISIBLE_DEVICES=2,3 python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --master_port 29055 flip_pretrain.py \ 50 | --lr 1e-3 \ 51 | --batch_size 512 \ 52 | --mode single \ 53 | --model flipae_vit_small_str \ 54 | --epochs 20 \ 55 | --warmup_epochs 1 \ 56 | --output_dir /root/autodl-tmp/output/base-flip/trans-judgue-1024 \ 57 | --data_path /root/autodl-tmp/unidata/real-rec/ \ 58 | --direction Hybrid --num_workers 10 59 | 60 | 61 | CUDA_VISIBLE_DEVICES=2,3 python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --master_port 29055 flip_pretrain.py \ 62 | --lr 1e-3 \ 63 | --batch_size 384s \ 64 | --mode single \ 65 | --model flipae_vit_small_str \ 66 | --epochs 20 \ 67 | --warmup_epochs 1 \ 68 | --output_dir /root/autodl-tmp/output/base-flip/trans-hybird-mul-prob \ 69 | --data_path /root/autodl-tmp/unidata/real-rec/ \ 70 | --direction Hybrid --num_workers 10 71 | 72 | 73 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 --master_port 29055 
flip_pretrain.py \ 74 | --lr 1e-3 \ 75 | --batch_size 256 \ 76 | --mode single \ 77 | --model flipae_vit_small_str \ 78 | --epochs 20 \ 79 | --warmup_epochs 1 \ 80 | --output_dir /root/autodl-tmp/output/base-flip/trans-hybird-mulemb3-prob \ 81 | --data_path /root/autodl-tmp/unidata/real-rec/ \ 82 | --direction Hybrid --num_workers 10 83 | 84 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 --master_port 29055 flip_pretrain.py \ 85 | --lr 1e-3 \ 86 | --batch_size 256 \ 87 | --mode single \ 88 | --model flipae_vit_small_str \ 89 | --epochs 20 \ 90 | --warmup_epochs 1 \ 91 | --output_dir /root/autodl-tmp/output/base-flip/trans-hybird-mulemb3-prob-p4x8 \ 92 | --data_path /root/autodl-tmp/unidata/real-rec/ \ 93 | --direction Hybrid --num_workers 10 94 | 95 | 96 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 --master_port 29055 flip_pretrain.py \ 97 | --lr 1e-3 \ 98 | --batch_size 256 \ 99 | --mode single \ 100 | --model flipae_vit_small_strdim384 \ 101 | --epochs 20 \ 102 | --warmup_epochs 1 \ 103 | --output_dir /root/autodl-tmp/output/base-flip/emb3-prob4x4-dim384 \ 104 | --data_path /root/autodl-tmp/unidata/real-rec/ \ 105 | --direction Hybrid --num_workers 10 -------------------------------------------------------------------------------- /pretrain/scripts/encoder-pretrain.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --master_port 29060 sim_pretrain.py \ 2 | --lr 5e-4 \ 3 | --batch_size 144 \ 4 | --mode single \ 5 | --model flipae_sim_vit_small_str \ 6 | --epochs 20 \ 7 | --warmup_epochs 1 \ 8 | --mm 0.995 \ 9 | --mmschedule 'cosine' \ 10 | --output_dir $output_path$ \ 11 | --data_path $data_path$ \ 12 | --direction aug_pool --num_workers 10 -------------------------------------------------------------------------------- /pretrain/scripts/eval-SR-compare.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29060 main_pixel.py \ 2 | --batch_size 10 \ 3 | --model pixel_vit_small \ 4 | --pixel_type SR \ 5 | --demo_dir /home/gaoz/output_pixel/SR-100e-noaug_norm/hard_best_vis-compare/ \ 6 | --eval_data_path /home/gaoz/datasets/textzoom/test/hard \ 7 | --model_path /home/gaoz/output_pixel/SR-100e-noaug_norm/best/psnr-22.1962-ssim-0.7610-checkpoint-72.pth \ 8 | --scratch_path /home/gaoz/output_pixel/SR-100e-noaug_norm/best/psnr-22.1962-ssim-0.7610-checkpoint-72.pth \ 9 | --dig_path /home/gaoz/output_pixel/SR-100e-noaug_norm/best/psnr-22.1962-ssim-0.7610-checkpoint-72.pth \ 10 | --ccd_path /home/gaoz/output_pixel/SR-100e-noaug_norm/best/psnr-22.1962-ssim-0.7610-checkpoint-72.pth \ 11 | --num_workers 10 --eval_only --vis --compare -------------------------------------------------------------------------------- /pretrain/scripts/eval-SR.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29060 main_pixel.py \ 2 | --batch_size 10 \ 3 | --model pixel_vit_small \ 4 | --pixel_type SR \ 5 | --demo_dir /home/gaoz/output_pixel/SR-100e-noaug_norm/hard_best_vis/ \ 6 | --eval_data_path /home/gaoz/datasets/textzoom/test/hard \ 7 | --model_path 
/home/gaoz/output_pixel/SR-100e-noaug_norm/best/psnr-22.1962-ssim-0.7610-checkpoint-72.pth \ 8 | --num_workers 10 --eval_only --vis -------------------------------------------------------------------------------- /pretrain/scripts/eval-Seg-compare.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29060 main_pixel.py \ 2 | --batch_size 10 \ 3 | --model pixel_vit_small \ 4 | --pixel_type Seg \ 5 | --demo_dir /home/gaoz/output_pixel/Seg-100e-norm/best_vis_compare/ \ 6 | --eval_data_path /home/gaoz/datasets/TextSeg/TextSeg/image_slice/test/ \ 7 | --model_path /home/gaoz/output_pixel/Seg-100e-norm/best/IoU-0.8288-checkpoint-39.pth \ 8 | --scratch_path /home/gaoz/output_pixel/Seg-100e-norm/best/IoU-0.8288-checkpoint-39.pth \ 9 | --dig_path /home/gaoz/output_pixel/Seg-100e-norm/best/IoU-0.8288-checkpoint-39.pth \ 10 | --ccd_path /home/gaoz/output_pixel/Seg-100e-norm/best/IoU-0.8288-checkpoint-39.pth \ 11 | --num_workers 10 --eval_only --vis --compare 12 | -------------------------------------------------------------------------------- /pretrain/scripts/eval-Seg.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29060 main_pixel.py \ 2 | --batch_size 1 \ 3 | --model pixel_vit_small \ 4 | --pixel_type Seg \ 5 | --demo_dir /home/gaoz/output_pixel/Seg-100e-norm/best_vis/ \ 6 | --eval_data_path /home/gaoz/datasets/TextSeg/TextSeg/image_slice/test/ \ 7 | --model_path /home/gaoz/output_pixel/Seg-100e-norm/best/IoU-0.8288-checkpoint-39.pth \ 8 | --num_workers 10 --eval_only --vis 9 | -------------------------------------------------------------------------------- /pretrain/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29058 demo.py \ 2 | --batch_size 10 \ 3 | --model flipae_vit_small_str \ 4 | --model_path /home/gaoz/pretrained/trans-hybird-mulemb3-prob/checkpoint-19.pth \ 5 | --data_path /home/gaoz/datasets/data/test/data_shuffle/ \ 6 | --demo_dir /home/gaoz/output/HF-shuffle-nomask/ \ 7 | --direction HF --num_workers 10 \ 8 | --debug 9 | 10 | 11 | CUDA_VISIBLE_DEVICES=6 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29058 demo.py \ 12 | --batch_size 10 \ 13 | --model flipae_vit_small_str \ 14 | --model_path /home/gaoz/output/mdr-HF-pair-mix/checkpoint-9.pth \ 15 | --data_path /home/gaoz/datasets/data/test/IIIT5k/ \ 16 | --demo_dir /home/gaoz/output/pair-hf-mix-9/IIIT5k/ \ 17 | --direction HF --num_workers 10 --pair_mix 18 | 19 | 20 | SVTP SVT salient artistic IC13_1095 multi_words IC15_2077 21 | 22 | CUDA_VISIBLE_DEVICES=4 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29058 demo.py \ 23 | --batch_size 10 \ 24 | --model flipae_vit_small_str \ 25 | --model_path /home/gaoz/output/mdr-union-pool-b256/checkpoint-19.pth \ 26 | --data_path /home/gaoz/datasets/test/IIIT5k/ \ 27 | --demo_dir /home/gaoz/output/vis/IIIT5k/RO \ 28 | --direction RO --num_workers 10 \ 29 | --debug 30 | -------------------------------------------------------------------------------- /pretrain/scripts/eval_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 定义第一个数组 4 | 
first_array=(IIIT5k SVTP SVT salient artistic IC13_1095 multi_words IC15_2077) 5 | 6 | # 定义第二个数组 7 | second_array=(HF VF RO) 8 | 9 | # 第一层循环遍历第一个数组 10 | for i in "${first_array[@]}"; do 11 | 12 | # 第二层循环遍历第二个数组 13 | for j in "${second_array[@]}"; do 14 | CUDA_VISIBLE_DEVICES=4 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29058 demo.py \ 15 | --batch_size 10 \ 16 | --model flipae_vit_small_str \ 17 | --model_path /home/gaoz/output/mdr-union-pool-b256/checkpoint-19.pth \ 18 | --data_path /home/gaoz/datasets/test/${i}/ \ 19 | --demo_dir /home/gaoz/output/vis/${i}/${j} \ 20 | --direction ${j} --num_workers 10 21 | done 22 | 23 | done 24 | -------------------------------------------------------------------------------- /pretrain/scripts/finetune-SR-CCD.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29060 main_pixel.py \ 2 | --lr 1e-4 \ 3 | --batch_size 96 \ 4 | --mode single \ 5 | --model pixel_vit_small \ 6 | --epochs 100 \ 7 | --warmup_epochs 5 \ 8 | --pixel_type SR \ 9 | --output_dir /home/gaoz/output_pixel/SR-dd/ \ 10 | --best_dir /home/gaoz/output_pixel/SR-dd/best/ \ 11 | --data_path /home/gaoz/datasets/textzoom/train/ \ 12 | --eval_data_path /home/gaoz/datasets/textzoom/test/ \ 13 | --pretrained /home/gaoz/output/ccd-cc/CCD_ViT_Small_pretrain.pth \ 14 | --num_workers 10 --train_val -------------------------------------------------------------------------------- /pretrain/scripts/finetune-SR-DiG.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29060 main_pixel.py \ 2 | --lr 1e-4 \ 3 | --batch_size 96 \ 4 | --mode single \ 5 | --model pixel_vit_small \ 6 | --epochs 100 \ 7 | --warmup_epochs 5 \ 8 | --pixel_type SR \ 9 | --output_dir /home/gaoz/output_pixel/SR-dd/ \ 10 | --best_dir /home/gaoz/output_pixel/SR-dd/best/ \ 11 | --data_path /home/gaoz/datasets/textzoom/train/ \ 12 | --eval_data_path /home/gaoz/datasets/textzoom/test/ \ 13 | --pretrained /home/gaoz/output/dig-cc/checkpoint-9.pth \ 14 | --num_workers 10 --train_val -------------------------------------------------------------------------------- /pretrain/scripts/finetune-Seg.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --master_port 29062 main_pixel.py \ 2 | --lr 1e-4 \ 3 | --batch_size 96 \ 4 | --mode single \ 5 | --model pixel_vit_small \ 6 | --epochs 100 \ 7 | --warmup_epochs 10 \ 8 | --pixel_type Seg \ 9 | --output_dir /home/gaoz/output_pixel/Seg-100e-abiaug/ \ 10 | --best_dir /home/gaoz/output_pixel/Seg-100e-abiaug/best/ \ 11 | --data_path /home/gaoz/datasets/TextSeg/TextSeg/image_slice/train_all/ \ 12 | --eval_data_path /home/gaoz/datasets/TextSeg/TextSeg/image_slice/test/ \ 13 | --pretrained /home/gaoz/output/mdr-pool-small-OCRCC/checkpoint-19.pth \ 14 | --num_workers 10 --train_val -------------------------------------------------------------------------------- /pretrain/util/__pycache__/lr_sched.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/lr_sched.cpython-37.pyc 
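The eval_all.sh script above is just two nested loops, one over the benchmark splits and one over the flip/rotation directions, launching demo.py once per combination with the same checkpoint. The same sweep can be driven from Python; a sketch that keeps the hypothetical paths used in the shell scripts:

import os
import subprocess

datasets = ['IIIT5k', 'SVTP', 'SVT', 'salient', 'artistic', 'IC13_1095', 'multi_words', 'IC15_2077']
directions = ['HF', 'VF', 'RO']
env = dict(os.environ, CUDA_VISIBLE_DEVICES='4')

for name in datasets:
    for direction in directions:
        cmd = [
            'python3', '-m', 'torch.distributed.launch',
            '--nproc_per_node=1', '--nnodes=1', '--master_port', '29058', 'demo.py',
            '--batch_size', '10',
            '--model', 'flipae_vit_small_str',
            '--model_path', '/home/gaoz/output/mdr-union-pool-b256/checkpoint-19.pth',
            '--data_path', f'/home/gaoz/datasets/test/{name}/',
            '--demo_dir', f'/home/gaoz/output/vis/{name}/{direction}',
            '--direction', direction,
            '--num_workers', '10',
        ]
        subprocess.run(cmd, check=True, env=env)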
-------------------------------------------------------------------------------- /pretrain/util/__pycache__/lr_sched.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/lr_sched.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/util/__pycache__/metric_iou.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/metric_iou.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/util/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/util/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/util/__pycache__/pos_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/pos_embed.cpython-37.pyc -------------------------------------------------------------------------------- /pretrain/util/__pycache__/pos_embed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/pos_embed.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/util/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FaltingsA/SSM/9ce7646f49e76fa0ec42815f02913fbc69150eb6/pretrain/util/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /pretrain/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 
19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /pretrain/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /pretrain/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
--------------------------------------------------------------------------------
/pretrain/util/lars.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | # --------------------------------------------------------
 7 | # LARS optimizer, implementation from MoCo v3:
 8 | # https://github.com/facebookresearch/moco-v3
 9 | # --------------------------------------------------------
10 | 
11 | import torch
12 | 
13 | 
14 | class LARS(torch.optim.Optimizer):
15 |     """
16 |     LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17 |     """
18 |     def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
19 |         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
20 |         super().__init__(params, defaults)
21 | 
22 |     @torch.no_grad()
23 |     def step(self):
24 |         for g in self.param_groups:
25 |             for p in g['params']:
26 |                 dp = p.grad
27 | 
28 |                 if dp is None:
29 |                     continue
30 | 
31 |                 if p.ndim > 1: # if not normalization gamma/beta or bias
32 |                     dp = dp.add(p, alpha=g['weight_decay'])
33 |                     param_norm = torch.norm(p)
34 |                     update_norm = torch.norm(dp)
35 |                     one = torch.ones_like(param_norm)
36 |                     q = torch.where(param_norm > 0.,
37 |                                     torch.where(update_norm > 0,
38 |                                                 (g['trust_coefficient'] * param_norm / update_norm), one),
39 |                                     one)
40 |                     dp = dp.mul(q)
41 | 
42 |                 param_state = self.state[p]
43 |                 if 'mu' not in param_state:
44 |                     param_state['mu'] = torch.zeros_like(p)
45 |                 mu = param_state['mu']
46 |                 mu.mul_(g['momentum']).add_(dp)
47 |                 p.add_(mu, alpha=-g['lr'])
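The trust-ratio scaling in `step()` only touches parameters with `ndim > 1`; biases and norm parameters fall back to plain SGD with momentum. A minimal sketch with a toy model and placeholder hyperparameters (not the repo's pretraining settings):

```python
import torch
import torch.nn.functional as F

from util.lars import LARS  # assumed import path when running from pretrain/

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-6, momentum=0.9)

x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))
loss = F.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # note: this step() takes no closure argument
```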
--------------------------------------------------------------------------------
/pretrain/util/lr_decay.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | # --------------------------------------------------------
 7 | # References:
 8 | # ELECTRA https://github.com/google-research/electra
 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import json
13 | 
14 | 
15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16 |     """
17 |     Parameter groups for layer-wise lr decay
18 |     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19 |     """
20 |     param_group_names = {}
21 |     param_groups = {}
22 | 
23 |     num_layers = len(model.blocks) + 1
24 | 
25 |     layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 | 
27 |     for n, p in model.named_parameters():
28 |         if not p.requires_grad:
29 |             continue
30 | 
31 |         # no decay: all 1D parameters and model specific ones
32 |         if p.ndim == 1 or n in no_weight_decay_list:
33 |             g_decay = "no_decay"
34 |             this_decay = 0.
35 |         else:
36 |             g_decay = "decay"
37 |             this_decay = weight_decay
38 | 
39 |         layer_id = get_layer_id_for_vit(n, num_layers)
40 |         group_name = "layer_%d_%s" % (layer_id, g_decay)
41 | 
42 |         if group_name not in param_group_names:
43 |             this_scale = layer_scales[layer_id]
44 | 
45 |             param_group_names[group_name] = {
46 |                 "lr_scale": this_scale,
47 |                 "weight_decay": this_decay,
48 |                 "params": [],
49 |             }
50 |             param_groups[group_name] = {
51 |                 "lr_scale": this_scale,
52 |                 "weight_decay": this_decay,
53 |                 "params": [],
54 |             }
55 | 
56 |         param_group_names[group_name]["params"].append(n)
57 |         param_groups[group_name]["params"].append(p)
58 | 
59 |     # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60 | 
61 |     return list(param_groups.values())
62 | 
63 | 
64 | def get_layer_id_for_vit(name, num_layers):
65 |     """
66 |     Assign a parameter with its layer id
67 |     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68 |     """
69 |     if name in ['cls_token', 'pos_embed']:
70 |         return 0
71 |     elif name.startswith('patch_embed'):
72 |         return 0
73 |     elif name.startswith('blocks'):
74 |         return int(name.split('.')[1]) + 1
75 |     else:
76 |         return num_layers
--------------------------------------------------------------------------------
/pretrain/util/lr_sched.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | import math
 8 | 
 9 | def adjust_learning_rate(optimizer, epoch, args):
10 |     """Decay the learning rate with half-cycle cosine after warmup"""
11 |     if epoch < args.warmup_epochs:
12 |         lr = args.lr * epoch / args.warmup_epochs
13 |     else:
14 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 |             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 |     for param_group in optimizer.param_groups:
17 |         if "lr_scale" in param_group:
18 |             param_group["lr"] = lr * param_group["lr_scale"]
19 |         else:
20 |             param_group["lr"] = lr
21 |     return lr
22 | 
--------------------------------------------------------------------------------
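The two utilities compose: `param_groups_lrd` attaches an `lr_scale` to each parameter group, and `adjust_learning_rate` multiplies the shared cosine schedule value by that scale on every call. A minimal sketch of how a fine-tuning loop might wire them together; the timm ViT, the `args` fields, and the loop structure are assumptions for illustration, not the repo's training code:

```python
import argparse

import timm
import torch

from util.lr_decay import param_groups_lrd      # assumed import paths when running from pretrain/
from util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=100)

model = timm.create_model('vit_base_patch16_224', num_classes=10)  # any ViT exposing .blocks works
param_groups = param_groups_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),  # e.g. {'pos_embed', 'cls_token'}
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)  # extra 'lr_scale' keys are kept in each group

for epoch in range(args.epochs):
    # MAE-style loops usually pass a fractional epoch (epoch + step / len(loader))
    # so the cosine decay also moves within an epoch.
    lr = adjust_learning_rate(optimizer, epoch, args)
    # ... forward / backward / optimizer.step() per batch ...
```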