├── README.md ├── configs ├── dataset │ ├── icd13.yaml │ ├── locr.yaml │ ├── st.yaml │ └── tsg.yaml ├── demo.yaml ├── test.yaml ├── test │ └── textdesign_sd_2.yaml ├── train.yaml └── train │ └── textdesign_sd_2.yaml ├── dataset ├── __init__.py ├── dataloader.py └── utils │ ├── arial.ttf │ └── words.txt ├── demo ├── examples │ ├── Adode.jpg │ ├── CARE.jpg │ ├── DREAMTEXT.jpg │ ├── Heart.jpg │ ├── Hello.jpg │ ├── now.jpg │ └── third.jpg ├── gradio.png └── teaser.png ├── docs ├── index.html └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ ├── images │ ├── architecture.png │ ├── comparison.png │ ├── comparison_2.png │ ├── comparison_3.png │ ├── display.png │ ├── more_visual_result.png │ ├── more_visual_result_2.png │ ├── problem_visualization.png │ └── problems.png │ └── js │ ├── MathJax.js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js ├── figures ├── compare_results_1.png ├── compare_results_2.png ├── model.png └── teaser.png ├── requirements.txt ├── run_gradio.py ├── sgm ├── __init__.py ├── lr_scheduler.py ├── models │ ├── __init__.py │ ├── autoencoder.py │ └── diffusion.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── autoencoding │ │ ├── __init__.py │ │ ├── losses │ │ │ └── __init__.py │ │ └── regularizers │ │ │ └── __init__.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── denoiser.py │ │ ├── denoiser_scaling.py │ │ ├── denoiser_weighting.py │ │ ├── discretizer.py │ │ ├── guiders.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── sampling.py │ │ ├── sampling_utils.py │ │ ├── sigma_sampling.py │ │ ├── util.py │ │ └── wrappers.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ └── predictors │ │ └── model.py └── util.py ├── src └── parseq │ ├── Datasets.md │ ├── LICENSE │ ├── NOTICE │ ├── bench.py │ ├── configs │ ├── bench.yaml │ ├── charset │ │ ├── 36_lowercase.yaml │ │ ├── 62_mixed-case.yaml │ │ └── 94_full.yaml │ ├── dataset │ │ ├── real.yaml │ │ └── synth.yaml │ ├── experiment │ │ ├── abinet-sv.yaml │ │ ├── abinet.yaml │ │ ├── crnn.yaml │ │ ├── parseq-patch16-224.yaml │ │ ├── parseq-tiny.yaml │ │ ├── parseq.yaml │ │ ├── trba.yaml │ │ ├── trbc.yaml │ │ ├── tune_abinet-lm.yaml │ │ └── vitstr.yaml │ ├── main.yaml │ ├── model │ │ ├── abinet.yaml │ │ ├── crnn.yaml │ │ ├── parseq.yaml │ │ ├── trba.yaml │ │ └── vitstr.yaml │ └── tune.yaml │ ├── demo_images │ ├── art-01107.jpg │ ├── coco-1166773.jpg │ ├── cute-184.jpg │ ├── ic13_word_256.png │ ├── ic15_word_26.png │ └── uber-27491.jpg │ ├── hubconf.py │ ├── read.py │ ├── requirements.txt │ ├── setup.cfg │ ├── setup.py │ ├── strhub │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── aa_overrides.py │ │ ├── augment.py │ │ ├── dataset.py │ │ ├── module.py │ │ └── utils.py │ └── models │ │ ├── __init__.py │ │ ├── abinet │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── backbone.py │ │ ├── model.py │ │ ├── model_abinet_iter.py │ │ ├── model_alignment.py │ │ ├── model_language.py │ │ ├── model_vision.py │ │ ├── resnet.py │ │ ├── system.py │ │ └── transformer.py │ │ ├── base.py │ │ ├── crnn │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── model.py │ │ └── system.py │ │ ├── modules.py │ │ ├── parseq │ │ ├── __init__.py │ │ ├── modules.py │ │ └── system.py │ │ ├── trba │ │ ├── __init__.py │ │ ├── 
feature_extraction.py │ │ ├── model.py │ │ ├── prediction.py │ │ ├── system.py │ │ └── transformation.py │ │ ├── utils.py │ │ └── vitstr │ │ ├── __init__.py │ │ ├── model.py │ │ └── system.py │ ├── test.py │ ├── tools │ ├── art_converter.py │ ├── case_sensitive_str_datasets_converter.py │ ├── coco_2_converter.py │ ├── coco_text_converter.py │ ├── create_lmdb_dataset.py │ ├── filter_lmdb.py │ ├── lsvt_converter.py │ ├── mlt19_converter.py │ ├── openvino_converter.py │ ├── test_abinet_lm_acc.py │ └── textocr_converter.py │ ├── train.py │ └── tune.py ├── test.py ├── train.py └── util.py /configs/dataset/icd13.yaml: -------------------------------------------------------------------------------- 1 | target: ICDAR13Dataset 2 | params: 3 | 4 | data_root: '/remote-home/share/' 5 | 6 | H: 512 7 | W: 512 8 | word_len: [1, 8] 9 | seq_len: 12 10 | mask_min_ratio: 0.01 11 | aug_text_enabled: True 12 | aug_text_ratio: 1.0 13 | bbr_size: 224 14 | -------------------------------------------------------------------------------- /configs/dataset/locr.yaml: -------------------------------------------------------------------------------- 1 | target: LAIONOCRDataset 2 | params: 3 | 4 | data_root: '/remote-home/share/' 5 | 6 | H: 512 7 | W: 512 8 | word_len: [1, 12] 9 | seq_len: 12 10 | mask_min_ratio: 0.01 11 | seg_min_ratio: 0.001 12 | aug_text_enabled: False 13 | aug_text_ratio: 1.0 14 | 15 | use_cached: False 16 | length: 100000 -------------------------------------------------------------------------------- /configs/dataset/st.yaml: -------------------------------------------------------------------------------- 1 | target: SynthTextDataset 2 | params: 3 | 4 | data_root: '/remote-home/share/' 5 | 6 | H: 512 7 | W: 512 8 | word_len: [1, 12] 9 | mask_min_ratio: 0.01 10 | seg_min_ratio: 0.001 11 | 12 | length: 100000 13 | use_cached: True 14 | bbr_size: 224 -------------------------------------------------------------------------------- /configs/dataset/tsg.yaml: -------------------------------------------------------------------------------- 1 | target: TextSegDataset 2 | params: 3 | 4 | data_root: '/remote-home/share/' 5 | 6 | H: 512 7 | W: 512 8 | word_len: [1, 12] 9 | seq_len: 12 10 | mask_min_ratio: 0.01 11 | seg_min_ratio: 0.005 12 | aug_text_enabled: False 13 | aug_text_ratio: 1.0 -------------------------------------------------------------------------------- /configs/demo.yaml: -------------------------------------------------------------------------------- 1 | type: "demo" 2 | 3 | # path 4 | load_ckpt_path: "./checkpoints/pre_trained.ckpt" 5 | 6 | model_cfg_path: "./configs/test/textdesign_sd_2.yaml" 7 | 8 | # param 9 | H: 512 10 | W: 512 11 | seq_len: 12 12 | batch_size: 1 13 | 14 | channel: 4 # AE latent channel 15 | factor: 8 # AE downsample factor 16 | scale: [4.0, 0.0] # cfg scale, None 17 | force_uc_zero_embeddings: ["ref", "label"] 18 | detailed: True 19 | 20 | # runtime 21 | steps: 50 22 | init_step: 0 23 | num_workers: 0 24 | gpu: 3 -------------------------------------------------------------------------------- /configs/test.yaml: -------------------------------------------------------------------------------- 1 | type: "test" 2 | 3 | # path 4 | load_ckpt_path: "./checkpoints/pre_trained.ckpt" 5 | model_cfg_path: "./configs/test/textdesign_sd_2.yaml" 6 | dataset_cfg_path: "./configs/dataset/icd13.yaml" 7 | output_dir: "./outputs" 8 | temp_dir: "./temp" 9 | 10 | # param 11 | channel: 4 # AE latent channel 12 | factor: 8 # AE downsample factor 13 | scale: [5.0, 0.0] # cfg scale, None 14 | 
force_uc_zero_embeddings: ["label"] # condition label 15 | detailed: True # save visualization results 16 | 17 | # runtime 18 | steps: 50 # sampling steps 19 | init_step: 0 20 | batch_size: 1 21 | num_workers: 0 22 | gpu: 3 # index of your gpu device 23 | shuffle: False 24 | quan_test: False # quantitative test 25 | 26 | # ocr 27 | ocr_enabled: True 28 | predictor_config: 29 | target: sgm.modules.predictors.model.ParseqPredictor 30 | params: 31 | ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt" -------------------------------------------------------------------------------- /configs/test/textdesign_sd_2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | opt_keys: 5 | - t_attn 6 | input_key: image 7 | scale_factor: 0.18215 8 | disable_first_stage_autocast: True 9 | 10 | denoiser_config: 11 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 12 | params: 13 | num_idx: 1000 14 | 15 | weighting_config: 16 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 17 | scaling_config: 18 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 19 | discretization_config: 20 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 21 | 22 | network_config: 23 | target: sgm.modules.diffusionmodules.openaimodel.UnifiedUNetModel 24 | params: 25 | in_channels: 9 26 | out_channels: 4 27 | ctrl_channels: 0 28 | model_channels: 320 29 | attention_resolutions: [4, 2, 1] 30 | save_attn_type: [t_attn] 31 | save_attn_layers: [output_blocks.6.1] 32 | num_res_blocks: 2 33 | channel_mult: [1, 2, 4, 4] 34 | num_head_channels: 64 35 | use_linear_in_transformer: True 36 | transformer_depth: 1 37 | t_context_dim: 2048 38 | 39 | conditioner_config: 40 | target: sgm.modules.GeneralConditioner 41 | params: 42 | emb_models: 43 | # textual crossattn cond 44 | - is_trainable: False 45 | emb_key: t_crossattn 46 | ucg_rate: 0.1 47 | input_key: label 48 | target: sgm.modules.encoders.modules.LabelEncoder 49 | params: 50 | max_len: 12 51 | emb_dim: 2048 52 | n_heads: 8 53 | n_trans_layers: 12 54 | ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt 55 | # concat cond 56 | - is_trainable: False 57 | input_key: mask 58 | target: sgm.modules.encoders.modules.SpatialRescaler 59 | params: 60 | in_channels: 1 61 | multiplier: 0.125 62 | - is_trainable: False 63 | input_key: masked 64 | target: sgm.modules.encoders.modules.LatentEncoder 65 | params: 66 | scale_factor: 0.18215 67 | config: 68 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper 69 | params: 70 | ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors 71 | embed_dim: 4 72 | monitor: val/rec_loss 73 | ddconfig: 74 | attn_type: vanilla-xformers 75 | double_z: true 76 | z_channels: 4 77 | resolution: 256 78 | in_channels: 3 79 | out_ch: 3 80 | ch: 128 81 | ch_mult: [1, 2, 4, 4] 82 | num_res_blocks: 2 83 | attn_resolutions: [] 84 | dropout: 0.0 85 | lossconfig: 86 | target: torch.nn.Identity 87 | 88 | first_stage_config: 89 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper 90 | params: 91 | ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors 92 | embed_dim: 4 93 | monitor: val/rec_loss 94 | ddconfig: 95 | attn_type: vanilla-xformers 96 | double_z: true 97 | z_channels: 4 98 | resolution: 256 99 | in_channels: 3 100 | out_ch: 3 101 | ch: 128 102 | ch_mult: [1, 2, 4, 4] 103 | num_res_blocks: 2 104 | attn_resolutions: [] 105 | dropout: 0.0 106 | 
lossconfig: 107 | target: torch.nn.Identity -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | type: "train" 2 | 3 | # path 4 | save_ckpt_dir: ./checkpoints 5 | load_ckpt_path: ./checkpoints/pretrained/512-inpainting-ema.ckpt 6 | 7 | model_cfg_path: ./configs/train/textdesign_sd_2.yaml 8 | dataset_cfg_path: ./configs/dataset/st.yaml 9 | 10 | # 11 | resume: False 12 | 13 | 14 | # param 15 | save_ckpt_freq: 1 16 | num_workers: 0 17 | batch_size: 8 18 | base_learning_rate: 5.0e-5 19 | shuffle: False 20 | 21 | # runtime 22 | lightning: 23 | max_epochs: 100 24 | accelerator: gpu 25 | strategy: ddp_find_unused_parameters_true 26 | accumulate_grad_batches: 4 27 | devices: [0,1,2,3] 28 | default_root_dir: ./logs 29 | profiler: simple -------------------------------------------------------------------------------- /configs/train/textdesign_sd_2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | opt_keys: 5 | - t_attn 6 | - t_norm 7 | input_key: image 8 | scale_factor: 0.18215 9 | disable_first_stage_autocast: True 10 | use_rendered: True 11 | supervise_mask: True 12 | 13 | denoiser_config: 14 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 15 | params: 16 | num_idx: 1000 17 | 18 | weighting_config: 19 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 20 | scaling_config: 21 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 22 | discretization_config: 23 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 24 | 25 | network_config: 26 | target: sgm.modules.diffusionmodules.openaimodel.UnifiedUNetModel 27 | params: 28 | in_channels: 9 29 | out_channels: 4 30 | ctrl_channels: 0 31 | model_channels: 320 32 | attention_resolutions: [4, 2, 1] 33 | save_attn_type: [t_attn] 34 | save_attn_layers: [output_blocks.6.1] 35 | num_res_blocks: 2 36 | channel_mult: [1, 2, 4, 4] 37 | num_head_channels: 64 38 | use_linear_in_transformer: True 39 | transformer_depth: 1 40 | t_context_dim: 2048 41 | 42 | conditioner_config: 43 | target: sgm.modules.GeneralConditioner 44 | params: 45 | emb_models: 46 | # textual crossattn cond 47 | - is_trainable: True 48 | emb_key: t_crossattn 49 | ucg_rate: 0.1 50 | input_key: label 51 | target: sgm.modules.encoders.modules.LabelEncoder 52 | params: 53 | max_len: 12 54 | emb_dim: 2048 55 | n_heads: 8 56 | n_trans_layers: 12 57 | lambda_cls: 0.1 58 | lambda_pos: 0.1 59 | ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820-002.ckpt 60 | 61 | visual_config: 62 | target: sgm.modules.encoders.modules.ViTSTREncoder 63 | params: 64 | freeze: True 65 | ckpt_path: "./checkpoints/encoders/ViTSTR/vitstr_base_patch16_224.pth" 66 | size: 224 67 | patch_size: 16 68 | embed_dim: 768 69 | depth: 12 70 | num_heads: 12 71 | mlp_ratio: 4 72 | qkv_bias: True 73 | in_chans: 1 74 | # concat cond 75 | - is_trainable: False 76 | input_key: mask 77 | target: sgm.modules.encoders.modules.SpatialRescaler 78 | params: 79 | in_channels: 1 80 | multiplier: 0.125 81 | - is_trainable: False 82 | input_key: masked 83 | target: sgm.modules.encoders.modules.LatentEncoder 84 | params: 85 | scale_factor: 0.18215 86 | config: 87 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper 88 | params: 89 | ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors 90 | embed_dim: 4 91 | 
monitor: val/rec_loss 92 | ddconfig: 93 | attn_type: vanilla-xformers 94 | double_z: true 95 | z_channels: 4 96 | resolution: 256 97 | in_channels: 3 98 | out_ch: 3 99 | ch: 128 100 | ch_mult: [1, 2, 4, 4] 101 | num_res_blocks: 2 102 | attn_resolutions: [] 103 | dropout: 0.0 104 | lossconfig: 105 | target: torch.nn.Identity 106 | 107 | first_stage_config: 108 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper 109 | params: 110 | ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors 111 | embed_dim: 4 112 | monitor: val/rec_loss 113 | ddconfig: 114 | attn_type: vanilla-xformers 115 | double_z: true 116 | z_channels: 4 117 | resolution: 256 118 | in_channels: 3 119 | out_ch: 3 120 | ch: 128 121 | ch_mult: [1, 2, 4, 4] 122 | num_res_blocks: 2 123 | attn_resolutions: [] 124 | dropout: 0.0 125 | lossconfig: 126 | target: torch.nn.Identity 127 | 128 | loss_fn_config: 129 | target: sgm.modules.diffusionmodules.loss.FullLoss # StandardDiffusionLoss 130 | params: 131 | seq_len: 12 132 | kernel_size: 3 133 | gaussian_sigma: 1.0 134 | min_attn_size: 16 135 | lambda_cross_loss: 0.01 136 | lambda_clip_loss: 0.001 137 | lambda_masked_loss: 0.01 138 | lambda_seg_loss: 0.001 139 | 140 | predictor_config: 141 | target: sgm.modules.predictors.model.ParseqPredictor 142 | params: 143 | ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt" 144 | 145 | sigma_sampler_config: 146 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 147 | params: 148 | num_idx: 1000 149 | 150 | discretization_config: 151 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 152 | 153 | sampler_config: 154 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 155 | params: 156 | num_steps: 50 157 | 158 | discretization_config: 159 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 160 | 161 | guider_config: 162 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 163 | params: 164 | scale: 5.0 -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/utils/arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/dataset/utils/arial.ttf -------------------------------------------------------------------------------- /demo/examples/Adode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/examples/Adode.jpg -------------------------------------------------------------------------------- /demo/examples/CARE.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/examples/CARE.jpg -------------------------------------------------------------------------------- /demo/examples/DREAMTEXT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/examples/DREAMTEXT.jpg 
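
The YAML configs above all follow the same `target`/`params` convention: `target` names a class by its import path and `params` holds its constructor keyword arguments, resolved at runtime by `sgm.util.instantiate_from_config` (imported throughout the `sgm` package). The snippet below is a minimal, hypothetical sketch of that flow — it is not a file in this repository, and it assumes the checkpoints referenced inside the YAML (the inpainting AE, LabelEncoder, ViTSTR and parseq weights) are already present under `./checkpoints/`.

```python
# Hypothetical usage sketch (not a repository file). Assumes the checkpoint
# paths referenced inside the YAML already exist locally.
from omegaconf import OmegaConf

from sgm.util import instantiate_from_config

# Any of the configs above can be loaded the same way.
cfg = OmegaConf.load("./configs/train/textdesign_sd_2.yaml")

# `target` (sgm.models.diffusion.DiffusionEngine) is imported and called
# with the keyword arguments listed under `params`.
model = instantiate_from_config(cfg.model)
print(type(model).__name__)  # -> DiffusionEngine
```
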
-------------------------------------------------------------------------------- /demo/examples/Heart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/examples/Heart.jpg -------------------------------------------------------------------------------- /demo/examples/Hello.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/examples/Hello.jpg -------------------------------------------------------------------------------- /demo/examples/now.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/examples/now.jpg -------------------------------------------------------------------------------- /demo/examples/third.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/examples/third.jpg -------------------------------------------------------------------------------- /demo/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/gradio.png -------------------------------------------------------------------------------- /demo/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/demo/teaser.png -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container 
object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | 
.publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | border: 1px solid #bbb; 121 | border-radius: 10px; 122 | padding: 0; 123 | font-size: 0; 124 | } 125 | 126 | .results-carousel video { 127 | margin: 0; 128 | } 129 | 130 | 131 | .interpolation-panel { 132 | background: #f5f5f5; 133 | border-radius: 10px; 134 | } 135 | 136 | .interpolation-panel .interpolation-image { 137 | width: 100%; 138 | border-radius: 5px; 139 | } 140 | 141 | .interpolation-video-column { 142 | } 143 | 144 | .interpolation-panel .slider { 145 | margin: 0 !important; 146 | } 147 | 148 | .interpolation-panel .slider { 149 | margin: 0 !important; 150 | } 151 | 152 | #interpolation-image-wrapper { 153 | width: 100%; 154 | } 155 | #interpolation-image-wrapper img { 156 | border-radius: 5px; 157 | } 158 | -------------------------------------------------------------------------------- /docs/static/images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/architecture.png -------------------------------------------------------------------------------- /docs/static/images/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/comparison.png -------------------------------------------------------------------------------- /docs/static/images/comparison_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/comparison_2.png -------------------------------------------------------------------------------- /docs/static/images/comparison_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/comparison_3.png -------------------------------------------------------------------------------- /docs/static/images/display.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/display.png -------------------------------------------------------------------------------- /docs/static/images/more_visual_result.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/more_visual_result.png -------------------------------------------------------------------------------- /docs/static/images/more_visual_result_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/more_visual_result_2.png -------------------------------------------------------------------------------- /docs/static/images/problem_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/problem_visualization.png -------------------------------------------------------------------------------- /docs/static/images/problems.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/docs/static/images/problems.png -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | 5 | $(document).ready(function() { 6 | // Check for click events on the navbar burger icon 7 | $(".navbar-burger").click(function() { 8 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 9 | $(".navbar-burger").toggleClass("is-active"); 10 | $(".navbar-menu").toggleClass("is-active"); 11 | 12 | }); 13 | 14 | var options = { 15 | slidesToScroll: 1, 16 | slidesToShow: 3, 17 | loop: true, 18 | infinite: true, 19 | autoplay: false, 20 | autoplaySpeed: 3000, 21 | } 22 | 23 | // Initialize all div with carousel class 24 | var carousels = bulmaCarousel.attach('.carousel', options); 25 | 26 | // Loop on each carousel initialized 27 | for(var i = 0; i < carousels.length; i++) { 28 | // Add listener to event 29 | carousels[i].on('before:show', state => { 30 | console.log(state); 31 | }); 32 | } 33 
| 34 | // Access to bulmaCarousel instance of an element 35 | var element = document.querySelector('#my-element'); 36 | if (element && element.bulmaCarousel) { 37 | // bulmaCarousel instance is available as element.bulmaCarousel 38 | element.bulmaCarousel.on('before-show', function(state) { 39 | console.log(state); 40 | }); 41 | } 42 | 43 | bulmaSlider.attach(); 44 | 45 | }) 46 | -------------------------------------------------------------------------------- /figures/compare_results_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/figures/compare_results_1.png -------------------------------------------------------------------------------- /figures/compare_results_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/figures/compare_results_2.png -------------------------------------------------------------------------------- /figures/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/figures/model.png -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/figures/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colorlover==0.3.0 2 | einops==0.6.1 3 | gradio==3.41.0 4 | imageio==2.31.2 5 | img2dataset==1.42.0 6 | kornia==0.6.9 7 | lpips==0.1.4 8 | matplotlib==3.7.2 9 | numpy==1.25.1 10 | omegaconf==2.3.0 11 | open-clip-torch==2.20.0 12 | opencv-python==4.6.0.66 13 | Pillow==9.5.0 14 | pytorch-fid==0.3.0 15 | pytorch-lightning==2.0.1 16 | safetensors==0.3.1 17 | scikit-learn==1.3.0 18 | scipy==1.11.1 19 | seaborn==0.12.2 20 | socksio==1.0.0 21 | tensorboard==2.14.0 22 | timm==0.9.2 23 | tokenizers==0.13.3 24 | torch==2.1.0 25 | torchvision==0.16.0 26 | tqdm==4.65.0 27 | transformers==4.30.2 28 | xformers==0.0.22.post7 29 | 30 | -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import instantiate_from_config 3 | -------------------------------------------------------------------------------- /sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | 
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- /sgm/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import DiagonalGaussianDistribution 9 | 10 | 11 | class AbstractRegularizer(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 16 | raise NotImplementedError() 17 | 18 | @abstractmethod 19 | def get_trainable_parameters(self) -> Any: 20 | raise NotImplementedError() 21 | 22 | 23 | class DiagonalGaussianRegularizer(AbstractRegularizer): 24 | def __init__(self, sample: bool = True): 25 | super().__init__() 26 | self.sample = sample 27 | 28 | def get_trainable_parameters(self) -> Any: 29 | yield from () 30 | 31 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 32 | log = dict() 33 | posterior = DiagonalGaussianDistribution(z) 34 | if self.sample: 35 | z = posterior.sample() 36 | else: 37 | z = posterior.mode() 38 | kl_loss = posterior.kl() 39 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 40 | log["kl_loss"] = kl_loss 41 | return z, log 42 | 43 | 44 | def measure_perplexity(predicted_indices, num_centroids): 45 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 46 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 47 | encodings = ( 48 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 49 | ) 50 | avg_probs = encodings.mean(0) 51 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 52 | cluster_use = torch.sum(avg_probs > 0) 53 | return perplexity, cluster_use 54 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser 2 | from .discretizer import Discretization 3 | from .model import Model, Encoder, Decoder 4 | from .openaimodel import UnifiedUNetModel 5 | from .sampling import BaseDiffusionSampler 6 | from .wrappers import OpenAIWrapper 7 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ...util import append_dims, instantiate_from_config 4 | 5 | 6 | class Denoiser(nn.Module): 7 | def __init__(self, weighting_config, scaling_config): 8 | super().__init__() 9 | 10 | self.weighting = instantiate_from_config(weighting_config) 11 | self.scaling = instantiate_from_config(scaling_config) 12 | 13 | def possibly_quantize_sigma(self, sigma): 14 | return sigma 15 | 16 | def possibly_quantize_c_noise(self, c_noise): 17 | return c_noise 18 | 19 | def w(self, sigma): 20 | return self.weighting(sigma) 21 | 22 | def __call__(self, network, input, sigma, cond): 23 | sigma = self.possibly_quantize_sigma(sigma) 24 | sigma_shape = sigma.shape 25 | sigma = append_dims(sigma, input.ndim) 26 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 27 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 28 | return network(input * c_in, c_noise, cond) * c_out + input * c_skip 29 | 30 | 31 | class DiscreteDenoiser(Denoiser): 32 | def __init__( 33 | self, 34 | weighting_config, 35 | scaling_config, 36 | num_idx, 37 | discretization_config, 38 | do_append_zero=False, 39 | quantize_c_noise=True, 40 | flip=True, 41 | ): 42 | super().__init__(weighting_config, scaling_config) 43 | sigmas = instantiate_from_config(discretization_config)( 44 | num_idx, do_append_zero=do_append_zero, flip=flip 45 | ) 46 | self.register_buffer("sigmas", sigmas) 47 | self.quantize_c_noise = quantize_c_noise 48 | 49 | def sigma_to_idx(self, sigma): 50 | dists = sigma - self.sigmas[:, None] 51 | return dists.abs().argmin(dim=0).view(sigma.shape) 52 | 53 | def idx_to_sigma(self, idx): 54 | return self.sigmas[idx] 55 | 56 | def possibly_quantize_sigma(self, sigma): 57 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 58 | 59 | def possibly_quantize_c_noise(self, c_noise): 60 | if self.quantize_c_noise: 61 | return self.sigma_to_idx(c_noise) 62 | else: 63 | return c_noise 64 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EDMScaling: 5 | def __init__(self, sigma_data=0.5): 6 | self.sigma_data = sigma_data 7 | 8 | def __call__(self, sigma): 9 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 10 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 11 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 12 
| c_noise = 0.25 * sigma.log() 13 | return c_skip, c_out, c_in, c_noise 14 | 15 | 16 | class EpsScaling: 17 | def __call__(self, sigma): 18 | c_skip = torch.ones_like(sigma, device=sigma.device) 19 | c_out = -sigma 20 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 21 | c_noise = sigma.clone() 22 | return c_skip, c_out, c_in, c_noise 23 | 24 | 25 | class VScaling: 26 | def __call__(self, sigma): 27 | c_skip = 1.0 / (sigma**2 + 1.0) 28 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 29 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 30 | c_noise = sigma.clone() 31 | return c_skip, c_out, c_in, c_noise 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from functools import partial 4 | from abc import abstractmethod 5 | 6 | from ...util import append_zero 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | 9 | 10 | def generate_roughly_equally_spaced_steps( 11 | num_substeps: int, max_step: int 12 | ) -> np.ndarray: 13 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 14 | 15 | 16 | class Discretization: 17 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 18 | sigmas = self.get_sigmas(n, device=device) 19 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 20 | return sigmas if not flip else torch.flip(sigmas, (0,)) 21 | 22 | @abstractmethod 23 | def get_sigmas(self, n, device): 24 | pass 25 | 26 | 27 | class EDMDiscretization(Discretization): 28 | def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0): 29 | self.sigma_min = sigma_min 30 | self.sigma_max = sigma_max 31 | self.rho = rho 32 | 33 | def get_sigmas(self, n, device="cpu"): 34 | ramp = torch.linspace(0, 1, n, device=device) 35 | min_inv_rho = self.sigma_min ** (1 / self.rho) 36 | max_inv_rho = self.sigma_max ** (1 / self.rho) 37 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 38 | return sigmas 39 | 40 | 41 | class LegacyDDPMDiscretization(Discretization): 42 | def __init__( 43 | self, 44 | linear_start=0.00085, 45 | linear_end=0.0120, 46 | num_timesteps=1000, 47 | ): 48 | super().__init__() 49 | self.num_timesteps = num_timesteps 50 | betas = make_beta_schedule( 51 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 52 | ) 53 | alphas = 1.0 - betas 54 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 55 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 56 | 57 | def get_sigmas(self, n, device="cpu"): 58 | if n < self.num_timesteps: 59 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 60 | alphas_cumprod = 
self.alphas_cumprod[timesteps] 61 | elif n == self.num_timesteps: 62 | alphas_cumprod = self.alphas_cumprod 63 | else: 64 | raise ValueError 65 | 66 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 67 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 68 | return torch.flip(sigmas, (0,)) 69 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | from ...util import default, instantiate_from_config 6 | 7 | 8 | class VanillaCFG: 9 | """ 10 | implements parallelized CFG 11 | """ 12 | 13 | def __init__(self, scale, dyn_thresh_config=None): 14 | scale_schedule = lambda scale, sigma: scale # independent of step 15 | self.scale_schedule = partial(scale_schedule, scale) 16 | self.dyn_thresh = instantiate_from_config( 17 | default( 18 | dyn_thresh_config, 19 | { 20 | "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" 21 | }, 22 | ) 23 | ) 24 | 25 | def __call__(self, x, sigma): 26 | x_u, x_c = x.chunk(2) 27 | scale_value = self.scale_schedule(sigma) 28 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 29 | return x_pred 30 | 31 | def prepare_inputs(self, x, s, c, uc): 32 | c_out = dict() 33 | 34 | for k in c: 35 | if k in ["vector", "t_crossattn", "v_crossattn", "concat"]: 36 | c_out[k] = torch.cat((uc[k], c[k]), 0) 37 | else: 38 | assert c[k] == uc[k] 39 | c_out[k] = c[k] 40 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 41 | 42 | 43 | class IdentityGuider: 44 | def __call__(self, x, sigma): 45 | return x 46 | 47 | def prepare_inputs(self, x, s, c, uc): 48 | c_out = dict() 49 | 50 | for k in c: 51 | c_out[k] = c[k] 52 | 53 | return x, s, c_out 54 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | class NoDynamicThresholding: 8 | def __call__(self, uncond, cond, scale): 9 | return uncond + scale * (cond - uncond) 10 | 11 | 12 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 13 | if order - 1 > i: 14 | raise ValueError(f"Order {order} too high for step {i}") 15 | 16 | def fn(tau): 17 | prod = 1.0 18 | for k in range(order): 19 | if j == k: 20 | continue 21 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 22 | return prod 23 | 24 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 25 | 26 | 27 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 28 | if not eta: 29 | return sigma_to, 0.0 30 | sigma_up = torch.minimum( 31 | sigma_to, 32 | eta 33 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 34 | ) 35 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 36 | return sigma_down, sigma_up 37 | 38 | 39 | def to_d(x, sigma, denoised): 40 | return (x - denoised) / append_dims(sigma, x.ndim) 41 | 42 | 43 | def to_neg_log_sigma(sigma): 44 | return sigma.log().neg() 45 | 46 | 47 | def to_sigma(neg_log_sigma): 48 | return neg_log_sigma.neg().exp() 49 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...util import default, 
instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | 23 | def idx_to_sigma(self, idx): 24 | return self.sigmas[idx] 25 | 26 | def __call__(self, n_samples, rand=None): 27 | idx = default( 28 | rand, 29 | torch.randint(0, self.num_idx, (n_samples,)), 30 | ) 31 | return self.idx_to_sigma(idx) 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | return self.diffusion_model( 29 | x, 30 | timesteps=t, 31 | t_context=c.get("t_crossattn", None), 32 | v_context=c.get("v_crossattn", None), 33 | y=c.get("vector", None), 34 | **kwargs 35 | ) 36 | -------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | 
device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if 
m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /sgm/modules/predictors/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | from torchvision.utils import save_image 5 | 6 | 7 | class ParseqPredictor(nn.Module): 8 | 9 | def __init__(self, ckpt_path=None, freeze=True, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | self.parseq = torch.hub.load('./src/parseq', 'parseq', source='local').eval() 13 | self.parseq.load_state_dict(torch.load(ckpt_path, map_location="cpu")) 14 | self.parseq_transform = transforms.Compose([ 15 | transforms.Resize(self.parseq.hparams.img_size, transforms.InterpolationMode.BICUBIC, antialias=True), 16 | transforms.Normalize(0.5, 0.5) 17 | ]) 18 | 19 | if freeze: 20 | self.freeze() 21 | 22 | def freeze(self): 23 | for param in self.parseq.parameters(): 24 | param.requires_grad_(False) 25 | 26 | def forward(self, x): 27 | 28 | x = torch.cat([self.parseq_transform(t[None]) for t in x]) 29 | logits = self.parseq(x.to(next(self.parameters()).device)) 30 | 31 | return logits 32 | 33 | def img2txt(self, x): 34 | 35 | pred = self(x) 36 | label, confidence = self.parseq.tokenizer.decode(pred) 37 | return label 38 | 39 | 40 | def calc_loss(self, x, label): 41 | 42 | preds = self(x) # (B, l, C) l=26, C=95 43 | gt_ids = self.parseq.tokenizer.encode(label).to(preds.device) # (B, l_trun) 44 | 45 | losses = [] 46 | for pred, gt_id in zip(preds, gt_ids): 47 | 48 | eos_id = (gt_id == 0).nonzero().item() 49 | gt_id = gt_id[1: eos_id] 50 | pred = pred[:eos_id-1, :] 51 | 52 | ce_loss = 
nn.functional.cross_entropy(pred.permute(1, 0)[None], gt_id[None]) 53 | ce_loss = torch.clamp(ce_loss, max = 1.0) 54 | losses.append(ce_loss[None]) 55 | 56 | loss = torch.cat(losses) 57 | 58 | return loss -------------------------------------------------------------------------------- /src/parseq/Datasets.md: -------------------------------------------------------------------------------- 1 | We use various synthetic and real datasets. More info is in Appendix F of the supplementary material. Some preprocessing scripts are included in [`tools/`](tools). 2 | 3 | | Dataset | Type | Remarks | 4 | |:-------:|:-----:|:--------| 5 | | [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) | synthetic | Case-sensitive annotations were extracted from the image filenames | 6 | | [SynthText](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | synthetic | Processed with [`crop_by_word_bb_syn90k.py`](https://github.com/FangShancheng/ABINet/blob/main/tools/crop_by_word_bb_syn90k.py) | 7 | | [IC13](https://rrc.cvc.uab.es/?ch=2) | real | Three archives: 857, 1015, 1095 (full) | 8 | | [IC15](https://rrc.cvc.uab.es/?ch=4) | real | Two archives: 1811, 2077 (full) | 9 | | [CUTE80](http://cs-chan.com/downloads_cute80_dataset.html) | real | \[1\] | 10 | | [IIIT5k](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) | real | \[1\] | 11 | | [SVT](http://vision.ucsd.edu/~kai/svt/) | real | \[1\] | 12 | | [SVTP](https://openaccess.thecvf.com/content_iccv_2013/html/Phan_Recognizing_Text_with_2013_ICCV_paper.html) | real | \[1\] | 13 | | [ArT](https://rrc.cvc.uab.es/?ch=14) | real | \[2\] | 14 | | [LSVT](https://rrc.cvc.uab.es/?ch=16) | real | \[2\] | 15 | | [MLT19](https://rrc.cvc.uab.es/?ch=15) | real | \[2\] | 16 | | [RCTW17](https://rctw.vlrlab.net/dataset.html) | real | \[2\] | 17 | | [ReCTS](https://rrc.cvc.uab.es/?ch=12) | real | \[2\] | 18 | | [Uber-Text](https://s3-us-west-2.amazonaws.com/uber-common-public/ubertext/index.html) | real | \[2\] | 19 | | [COCO-Text v1.4](https://rrc.cvc.uab.es/?ch=5) | real | Processed with [`coco_text_converter.py`](tools/coco_text_converter.py) | 20 | | [COCO-Text v2.0](https://bgshih.github.io/cocotext/) | real | Processed with [`coco_2_converter.py`](tools/coco_2_converter.py) | 21 | | [OpenVINO](https://proceedings.mlr.press/v157/krylov21a.html) | real | [Annotations](https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/datasets/open_images_v5_text/) for a subset of [Open Images](https://github.com/cvdfoundation/open-images-dataset). Processed with [`openvino_converter.py`](tools/openvino_converter.py). | 22 | | [TextOCR](https://textvqa.org/textocr/) | real | Annotations for a subset of Open Images. Processed with [`textocr_converter.py`](tools/textocr_converter.py). A _horizontal_ version can be generated by passing `--rectify_pose`. | 23 | 24 | \[1\] Case-sensitive annotations from [Long and Yao](https://github.com/Jyouhou/Case-Sensitive-Scene-Text-Recognition-Datasets) + [our corrections](https://github.com/baudm/Case-Sensitive-Scene-Text-Recognition-Datasets). Processed with [case_sensitive_str_datasets_converter.py](tools/case_sensitive_str_datasets_converter.py)
25 | \[2\] Archives used as-is from [Baek et al.](https://github.com/ku21fan/STR-Fewer-Labels/blob/main/data.md) They are included in the dataset release for convenience. Please refer to their work for more info about the datasets. 26 | 27 | The preprocessed archives are available here: [val + test + most of train](https://drive.google.com/drive/folders/1NYuoi7dfJVgo-zUJogh8UQZgIMpLviOE), [TextOCR + OpenVINO](https://drive.google.com/drive/folders/1D9z_YJVa6f-O0juni-yG5jcwnhvYw-qC) 28 | 29 | The expected filesystem structure is as follows: 30 | ``` 31 | data 32 | ├── test 33 | │ ├── ArT 34 | │ ├── COCOv1.4 35 | │ ├── CUTE80 36 | │ ├── IC13_1015 37 | │ ├── IC13_1095 # Full IC13 test set. Typically not used for benchmarking but provided here for convenience. 38 | │ ├── IC13_857 39 | │ ├── IC15_1811 40 | │ ├── IC15_2077 41 | │ ├── IIIT5k 42 | │ ├── SVT 43 | │ ├── SVTP 44 | │ └── Uber 45 | ├── train 46 | │ ├── real 47 | │ │ ├── ArT 48 | │ │ │ ├── train 49 | │ │ │ └── val 50 | │ │ ├── COCOv2.0 51 | │ │ │ ├── train 52 | │ │ │ └── val 53 | │ │ ├── LSVT 54 | │ │ │ ├── test 55 | │ │ │ ├── train 56 | │ │ │ └── val 57 | │ │ ├── MLT19 58 | │ │ │ ├── test 59 | │ │ │ ├── train 60 | │ │ │ └── val 61 | │ │ ├── OpenVINO 62 | │ │ │ ├── train_1 63 | │ │ │ ├── train_2 64 | │ │ │ ├── train_5 65 | │ │ │ ├── train_f 66 | │ │ │ └── validation 67 | │ │ ├── RCTW17 68 | │ │ │ ├── test 69 | │ │ │ ├── train 70 | │ │ │ └── val 71 | │ │ ├── ReCTS 72 | │ │ │ ├── test 73 | │ │ │ ├── train 74 | │ │ │ └── val 75 | │ │ ├── TextOCR 76 | │ │ │ ├── train 77 | │ │ │ └── val 78 | │ │ └── Uber 79 | │ │ ├── train 80 | │ │ └── val 81 | │ └── synth 82 | │ ├── MJ 83 | │ │ ├── test 84 | │ │ ├── train 85 | │ │ └── val 86 | │ └── ST 87 | └── val 88 | ├── IC13 89 | ├── IC15 90 | ├── IIIT5k 91 | └── SVT 92 | ``` 93 | -------------------------------------------------------------------------------- /src/parseq/NOTICE: -------------------------------------------------------------------------------- 1 | Scene Text Recognition Model Hub 2 | Copyright 2022 Darwin Bautista 3 | 4 | The Initial Developer of strhub/models/abinet (sans system.py) is 5 | Fang et al. (https://github.com/FangShancheng/ABINet). 6 | Copyright 2021-2022 USTC 7 | 8 | The Initial Developer of strhub/models/crnn (sans system.py) is 9 | Jieru Mei (https://github.com/meijieru/crnn.pytorch). 10 | Copyright 2017-2022 Jieru Mei 11 | 12 | The Initial Developer of strhub/models/trba (sans system.py) is 13 | Jeonghun Baek (https://github.com/clovaai/deep-text-recognition-benchmark). 14 | Copyright 2019-2022 NAVER Corp. 15 | 16 | The Initial Developer of strhub/models/vitstr (sans system.py) is 17 | Rowel Atienza (https://github.com/roatienza/deep-text-recognition-benchmark). 18 | Copyright 2021-2022 Rowel Atienza 19 | -------------------------------------------------------------------------------- /src/parseq/bench.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Scene Text Recognition Model Hub 3 | # Copyright 2022 Darwin Bautista 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | 19 | import torch 20 | from torch.utils import benchmark 21 | 22 | from fvcore.nn import FlopCountAnalysis, ActivationCountAnalysis, flop_count_table 23 | 24 | import hydra 25 | from omegaconf import DictConfig 26 | 27 | 28 | @torch.inference_mode() 29 | @hydra.main(config_path='configs', config_name='bench', version_base='1.2') 30 | def main(config: DictConfig): 31 | # For consistent behavior 32 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 33 | torch.backends.cudnn.benchmark = False 34 | torch.use_deterministic_algorithms(True) 35 | 36 | device = config.get('device', 'cuda') 37 | 38 | h, w = config.data.img_size 39 | x = torch.rand(1, 3, h, w, device=device) 40 | model = hydra.utils.instantiate(config.model).eval().to(device) 41 | 42 | if config.get('range', False): 43 | for i in range(1, 26, 4): 44 | timer = benchmark.Timer( 45 | stmt='model(x, len)', 46 | globals={'model': model, 'x': x, 'len': i}) 47 | print(timer.blocked_autorange(min_run_time=1)) 48 | else: 49 | timer = benchmark.Timer( 50 | stmt='model(x)', 51 | globals={'model': model, 'x': x}) 52 | flops = FlopCountAnalysis(model, x) 53 | acts = ActivationCountAnalysis(model, x) 54 | print(timer.blocked_autorange(min_run_time=1)) 55 | print(flop_count_table(flops, 1, acts, False)) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /src/parseq/configs/bench.yaml: -------------------------------------------------------------------------------- 1 | # Disable any logging or output 2 | defaults: 3 | - main 4 | - _self_ 5 | - override hydra/job_logging: disabled 6 | 7 | hydra: 8 | output_subdir: null 9 | run: 10 | dir: . 
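
The following is a minimal, self-contained sketch of the timing and FLOP-counting pattern that `bench.py` above wraps in Hydra. The toy model and input shape are placeholders chosen only for illustration, not the repository's configuration; only the `torch.utils.benchmark` and `fvcore` calls mirror what the script does.

```python
import torch
from torch import nn
from torch.utils import benchmark
from fvcore.nn import FlopCountAnalysis, flop_count_table

# Placeholder model and input; bench.py instantiates the real model from its Hydra config.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
).eval()
x = torch.rand(1, 3, 32, 128)  # N, C, H, W — same layout bench.py uses

with torch.inference_mode():
    # Wall-clock timing, as in the non-`range` branch of bench.py
    timer = benchmark.Timer(stmt='model(x)', globals={'model': model, 'x': x})
    print(timer.blocked_autorange(min_run_time=1))
    # Per-module FLOP table via fvcore
    print(flop_count_table(FlopCountAnalysis(model, x), max_depth=1))
```
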
11 | -------------------------------------------------------------------------------- /src/parseq/configs/charset/36_lowercase.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | charset_train: "0123456789abcdefghijklmnopqrstuvwxyz" 4 | -------------------------------------------------------------------------------- /src/parseq/configs/charset/62_mixed-case.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | charset_train: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 4 | -------------------------------------------------------------------------------- /src/parseq/configs/charset/94_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | charset_train: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" 4 | -------------------------------------------------------------------------------- /src/parseq/configs/dataset/real.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: real 4 | -------------------------------------------------------------------------------- /src/parseq/configs/dataset/synth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data: 3 | train_dir: synth 4 | num_workers: 3 5 | 6 | trainer: 7 | limit_train_batches: 0.20496 # to match the steps per epoch of `real` 8 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/abinet-sv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: abinet 4 | 5 | model: 6 | name: abinet-sv 7 | v_num_layers: 2 8 | v_attention: attention 9 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/abinet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: abinet 4 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/crnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: crnn 4 | 5 | data: 6 | num_workers: 5 7 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/parseq-patch16-224.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | 5 | model: 6 | img_size: [ 224, 224 ] # [ height, width ] 7 | patch_size: [ 16, 16 ] # [ height, width ] 8 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/parseq-tiny.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | 5 | model: 6 | name: parseq-tiny 7 | embed_dim: 192 8 | enc_num_heads: 3 9 | dec_num_heads: 6 10 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/parseq.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: parseq 4 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/trba.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: trba 4 | 5 | data: 6 | num_workers: 3 7 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/trbc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: trba 4 | 5 | model: 6 | name: trbc 7 | _target_: strhub.models.trba.system.TRBC 8 | lr: 1e-4 9 | 10 | data: 11 | num_workers: 3 12 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/tune_abinet-lm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: abinet 4 | 5 | model: 6 | name: abinet-lm 7 | lm_only: true 8 | 9 | data: 10 | augment: false 11 | num_workers: 3 12 | 13 | tune: 14 | gpus_per_trial: 0.5 15 | lr: 16 | min: 1e-5 17 | max: 1e-3 18 | -------------------------------------------------------------------------------- /src/parseq/configs/experiment/vitstr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: vitstr 4 | 5 | model: 6 | img_size: [ 32, 128 ] # [ height, width ] 7 | patch_size: [ 4, 8 ] # [ height, width ] 8 | -------------------------------------------------------------------------------- /src/parseq/configs/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: parseq 4 | - charset: 94_full 5 | - dataset: real 6 | 7 | model: 8 | _convert_: all 9 | img_size: [ 32, 128 ] # [ height, width ] 10 | max_label_length: 25 11 | # The ordering in charset_train matters. It determines the token IDs assigned to each character. 12 | charset_train: ??? 13 | # For charset_test, ordering doesn't matter. 14 | charset_test: "0123456789abcdefghijklmnopqrstuvwxyz" 15 | batch_size: 384 16 | weight_decay: 0.0 17 | warmup_pct: 0.075 # equivalent to 1.5 epochs of warm up 18 | 19 | data: 20 | _target_: strhub.data.module.SceneTextDataModule 21 | root_dir: data 22 | train_dir: ??? 
23 | batch_size: ${model.batch_size} 24 | img_size: ${model.img_size} 25 | charset_train: ${model.charset_train} 26 | charset_test: ${model.charset_test} 27 | max_label_length: ${model.max_label_length} 28 | remove_whitespace: true 29 | normalize_unicode: true 30 | augment: true 31 | num_workers: 2 32 | 33 | trainer: 34 | _target_: pytorch_lightning.Trainer 35 | _convert_: all 36 | val_check_interval: 1000 37 | #max_steps: 169680 # 20 epochs x 8484 steps (for batch size = 384, real data) 38 | max_epochs: 20 39 | gradient_clip_val: 20 40 | gpus: 2 41 | 42 | ckpt_path: null 43 | pretrained: null 44 | 45 | hydra: 46 | output_subdir: config 47 | run: 48 | dir: outputs/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 49 | sweep: 50 | dir: multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 51 | subdir: ${hydra.job.override_dirname} 52 | -------------------------------------------------------------------------------- /src/parseq/configs/model/abinet.yaml: -------------------------------------------------------------------------------- 1 | name: abinet 2 | _target_: strhub.models.abinet.system.ABINet 3 | 4 | # Shared Transformer configuration 5 | d_model: 512 6 | nhead: 8 7 | d_inner: 2048 8 | activation: relu 9 | dropout: 0.1 10 | 11 | # Architecture 12 | v_backbone: transformer 13 | v_num_layers: 3 14 | v_attention: position 15 | v_attention_mode: nearest 16 | l_num_layers: 4 17 | l_use_self_attn: false 18 | 19 | # Training 20 | lr: 3.4e-4 21 | l_lr: 3e-4 22 | iter_size: 3 23 | a_loss_weight: 1. 24 | v_loss_weight: 1. 25 | l_loss_weight: 1. 26 | l_detach: true 27 | -------------------------------------------------------------------------------- /src/parseq/configs/model/crnn.yaml: -------------------------------------------------------------------------------- 1 | name: crnn 2 | _target_: strhub.models.crnn.system.CRNN 3 | 4 | # Architecture 5 | hidden_size: 256 6 | leaky_relu: false 7 | 8 | # Training 9 | lr: 5.1e-4 10 | -------------------------------------------------------------------------------- /src/parseq/configs/model/parseq.yaml: -------------------------------------------------------------------------------- 1 | name: parseq 2 | _target_: strhub.models.parseq.system.PARSeq 3 | 4 | # Data 5 | patch_size: [ 4, 8 ] # [ height, width ] 6 | 7 | # Architecture 8 | embed_dim: 384 9 | enc_num_heads: 6 10 | enc_mlp_ratio: 4 11 | enc_depth: 12 12 | dec_num_heads: 12 13 | dec_mlp_ratio: 4 14 | dec_depth: 1 15 | 16 | # Training 17 | lr: 7e-4 18 | perm_num: 6 19 | perm_forward: true 20 | perm_mirrored: true 21 | dropout: 0.1 22 | 23 | # Decoding mode (test) 24 | decode_ar: true 25 | refine_iters: 1 26 | -------------------------------------------------------------------------------- /src/parseq/configs/model/trba.yaml: -------------------------------------------------------------------------------- 1 | name: trba 2 | _target_: strhub.models.trba.system.TRBA 3 | 4 | # Architecture 5 | num_fiducial: 20 6 | output_channel: 512 7 | hidden_size: 256 8 | 9 | # Training 10 | lr: 6.9e-4 11 | -------------------------------------------------------------------------------- /src/parseq/configs/model/vitstr.yaml: -------------------------------------------------------------------------------- 1 | name: vitstr 2 | _target_: strhub.models.vitstr.system.ViTSTR 3 | 4 | # Data 5 | img_size: [ 224, 224 ] # [ height, width ] 6 | patch_size: [ 16, 16 ] # [ height, width ] 7 | 8 | # Architecture 9 | embed_dim: 384 10 | num_heads: 6 11 | 12 | # Training 13 | lr: 8.9e-4 14 | 
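
The Hydra groups above (`charset/`, `dataset/`, `experiment/`, `model/`) are composed onto `main.yaml` at launch time. As an illustration only, the sketch below shows roughly how that composition can be reproduced programmatically; it assumes hydra-core ~=1.2 is installed and the snippet is run from the repository root so that `src/parseq/configs` resolves, and the override values are arbitrary examples rather than recommended settings.

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose main.yaml with explicit group choices
# (the defaults would be model=parseq, charset=94_full, dataset=real).
with initialize(version_base='1.2', config_path='src/parseq/configs'):
    cfg = compose(config_name='main',
                  overrides=['model=parseq', 'charset=62_mixed-case', 'dataset=synth'])

print(OmegaConf.to_yaml(cfg.model))  # charset_train ("???" in main.yaml) is filled in by the charset group
print(cfg.data.batch_size)           # 384, interpolated from ${model.batch_size}
```
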
-------------------------------------------------------------------------------- /src/parseq/configs/tune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - main 3 | - _self_ 4 | 5 | trainer: 6 | gpus: 1 # tuning with DDP is not yet supported. 7 | 8 | tune: 9 | num_samples: 10 10 | gpus_per_trial: 1 11 | lr: 12 | min: 1e-4 13 | max: 2e-3 14 | resume_dir: null 15 | 16 | hydra: 17 | run: 18 | dir: ray_results/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 19 | -------------------------------------------------------------------------------- /src/parseq/demo_images/art-01107.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/demo_images/art-01107.jpg -------------------------------------------------------------------------------- /src/parseq/demo_images/coco-1166773.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/demo_images/coco-1166773.jpg -------------------------------------------------------------------------------- /src/parseq/demo_images/cute-184.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/demo_images/cute-184.jpg -------------------------------------------------------------------------------- /src/parseq/demo_images/ic13_word_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/demo_images/ic13_word_256.png -------------------------------------------------------------------------------- /src/parseq/demo_images/ic15_word_26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/demo_images/ic15_word_26.png -------------------------------------------------------------------------------- /src/parseq/demo_images/uber-27491.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/demo_images/uber-27491.jpg -------------------------------------------------------------------------------- /src/parseq/hubconf.py: -------------------------------------------------------------------------------- 1 | from strhub.models.utils import create_model 2 | 3 | 4 | dependencies = ['torch', 'pytorch_lightning', 'timm'] 5 | 6 | 7 | def parseq_tiny(pretrained: bool = False, decode_ar: bool = True, refine_iters: int = 1, **kwargs): 8 | """ 9 | PARSeq tiny model (img_size=128x32, patch_size=8x4, d_model=192) 10 | @param pretrained: (bool) Use pretrained weights 11 | @param decode_ar: (bool) use AR decoding 12 | @param refine_iters: (int) number of refinement iterations to use 13 | """ 14 | return create_model('parseq-tiny', pretrained, decode_ar=decode_ar, refine_iters=refine_iters, **kwargs) 15 | 16 | 17 | def parseq(pretrained: bool = False, decode_ar: bool = True, refine_iters: int = 1, **kwargs): 18 | """ 19 | PARSeq base model (img_size=128x32, patch_size=8x4, d_model=384) 20 | @param pretrained: (bool) Use 
pretrained weights 21 | @param decode_ar: (bool) use AR decoding 22 | @param refine_iters: (int) number of refinement iterations to use 23 | """ 24 | return create_model('parseq', pretrained, decode_ar=decode_ar, refine_iters=refine_iters, **kwargs) 25 | 26 | 27 | def abinet(pretrained: bool = False, iter_size: int = 3, **kwargs): 28 | """ 29 | ABINet model (img_size=128x32) 30 | @param pretrained: (bool) Use pretrained weights 31 | @param iter_size: (int) number of refinement iterations to use 32 | """ 33 | return create_model('abinet', pretrained, iter_size=iter_size, **kwargs) 34 | 35 | 36 | def trba(pretrained: bool = False, **kwargs): 37 | """ 38 | TRBA model (img_size=128x32) 39 | @param pretrained: (bool) Use pretrained weights 40 | """ 41 | return create_model('trba', pretrained, **kwargs) 42 | 43 | 44 | def vitstr(pretrained: bool = False, **kwargs): 45 | """ 46 | ViTSTR small model (img_size=128x32, patch_size=8x4, d_model=384) 47 | @param pretrained: (bool) Use pretrained weights 48 | """ 49 | return create_model('vitstr', pretrained, **kwargs) 50 | 51 | 52 | def crnn(pretrained: bool = False, **kwargs): 53 | """ 54 | CRNN model (img_size=128x32) 55 | @param pretrained: (bool) Use pretrained weights 56 | """ 57 | return create_model('crnn', pretrained, **kwargs) 58 | -------------------------------------------------------------------------------- /src/parseq/read.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Scene Text Recognition Model Hub 3 | # Copyright 2022 Darwin Bautista 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import argparse 18 | 19 | import torch 20 | 21 | from PIL import Image 22 | 23 | from strhub.data.module import SceneTextDataModule 24 | from strhub.models.utils import load_from_checkpoint, parse_model_args 25 | 26 | 27 | @torch.inference_mode() 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('checkpoint', help="Model checkpoint (or 'pretrained=')") 31 | parser.add_argument('--images', nargs='+', help='Images to read') 32 | parser.add_argument('--device', default='cuda') 33 | args, unknown = parser.parse_known_args() 34 | kwargs = parse_model_args(unknown) 35 | print(f'Additional keyword arguments: {kwargs}') 36 | 37 | model = load_from_checkpoint(args.checkpoint, **kwargs).eval().to(args.device) 38 | img_transform = SceneTextDataModule.get_transform(model.hparams.img_size) 39 | 40 | for fname in args.images: 41 | # Load image and prepare for input 42 | image = Image.open(fname).convert('RGB') 43 | image = img_transform(image).unsqueeze(0).to(args.device) 44 | 45 | p = model(image).softmax(-1) 46 | pred, p = model.tokenizer.decode(p) 47 | print(f'{fname}: {pred[0]}') 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /src/parseq/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.10.2 2 | torchvision>=0.11.3 3 | pytorch-lightning~=1.6.5 4 | timm~=0.6.5 5 | nltk~=3.7.0 6 | lmdb~=1.3.0 7 | Pillow~=9.2.0 8 | imgaug~=0.4.0 9 | hydra-core~=1.2.0 10 | fvcore~=0.1.5.post20220512 11 | ray[tune]~=1.13.0 12 | ax-platform~=0.2.5.1 13 | PyYAML~=6.0.0 14 | tqdm~=4.64.0 15 | -------------------------------------------------------------------------------- /src/parseq/setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | norecursedirs = 3 | .git 4 | dist 5 | build 6 | addopts = 7 | --strict 8 | --doctest-modules 9 | --durations=0 10 | 11 | [coverage:report] 12 | exclude_lines = 13 | pragma: no-cover 14 | pass 15 | 16 | [flake8] 17 | max-line-length = 120 18 | exclude = .tox,*.egg,build,temp 19 | select = E,W,F 20 | doctests = True 21 | verbose = 2 22 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 23 | format = pylint 24 | # see: https://www.flake8rules.com/ 25 | ignore = 26 | E731 # Do not assign a lambda expression, use a def 27 | W504 # Line break occurred after a binary operator 28 | F401 # Module imported but unused 29 | F841 # Local variable name is assigned to but never used 30 | W605 # Invalid escape sequence 'x' 31 | 32 | # setup.cfg or tox.ini 33 | [check-manifest] 34 | ignore = 35 | *.yml 36 | .github 37 | .github/* 38 | 39 | [metadata] 40 | license_file = LICENSE 41 | description-file = README.md 42 | # long_description = file:README.md 43 | # long_description_content_type = text/markdown 44 | -------------------------------------------------------------------------------- /src/parseq/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name='strhub', 7 | version='1.1.0', 8 | description='Scene Text Recognition Model Hub: A collection of deep learning models for Scene Text Recognition', 9 | author='Darwin Bautista', 10 | author_email='baudm@users.noreply.github.com', 11 | url='https://github.com/baudm/parseq', 12 | install_requires=['torch~=1.12.1', 'pytorch-lightning~=1.6.5', 'timm~=0.6.5'], 13 
| packages=find_packages(), 14 | ) 15 | -------------------------------------------------------------------------------- /src/parseq/strhub/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/strhub/__init__.py -------------------------------------------------------------------------------- /src/parseq/strhub/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/strhub/data/__init__.py -------------------------------------------------------------------------------- /src/parseq/strhub/data/aa_overrides.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Extends default ops to accept optional parameters.""" 17 | from functools import partial 18 | 19 | from timm.data.auto_augment import _LEVEL_DENOM, _randomly_negate, LEVEL_TO_ARG, NAME_TO_OP, rotate 20 | 21 | 22 | def rotate_expand(img, degrees, **kwargs): 23 | """Rotate operation with expand=True to avoid cutting off the characters""" 24 | kwargs['expand'] = True 25 | return rotate(img, degrees, **kwargs) 26 | 27 | 28 | def _level_to_arg(level, hparams, key, default): 29 | magnitude = hparams.get(key, default) 30 | level = (level / _LEVEL_DENOM) * magnitude 31 | level = _randomly_negate(level) 32 | return level, 33 | 34 | 35 | def apply(): 36 | # Overrides 37 | NAME_TO_OP.update({ 38 | 'Rotate': rotate_expand 39 | }) 40 | LEVEL_TO_ARG.update({ 41 | 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.), 42 | 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3), 43 | 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3), 44 | 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45), 45 | 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45), 46 | }) 47 | -------------------------------------------------------------------------------- /src/parseq/strhub/data/augment.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from functools import partial 17 | 18 | import imgaug.augmenters as iaa 19 | import numpy as np 20 | from PIL import ImageFilter, Image 21 | from timm.data import auto_augment 22 | 23 | from strhub.data import aa_overrides 24 | 25 | aa_overrides.apply() 26 | 27 | _OP_CACHE = {} 28 | 29 | 30 | def _get_op(key, factory): 31 | try: 32 | op = _OP_CACHE[key] 33 | except KeyError: 34 | op = factory() 35 | _OP_CACHE[key] = op 36 | return op 37 | 38 | 39 | def _get_param(level, img, max_dim_factor, min_level=1): 40 | max_level = max(min_level, max_dim_factor * max(img.size)) 41 | return round(min(level, max_level)) 42 | 43 | 44 | def gaussian_blur(img, radius, **__): 45 | radius = _get_param(radius, img, 0.02) 46 | key = 'gaussian_blur_' + str(radius) 47 | op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius)) 48 | return img.filter(op) 49 | 50 | 51 | def motion_blur(img, k, **__): 52 | k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values 53 | key = 'motion_blur_' + str(k) 54 | op = _get_op(key, lambda: iaa.MotionBlur(k)) 55 | return Image.fromarray(op(image=np.asarray(img))) 56 | 57 | 58 | def gaussian_noise(img, scale, **_): 59 | scale = _get_param(scale, img, 0.25) | 1 # bin to odd values 60 | key = 'gaussian_noise_' + str(scale) 61 | op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale)) 62 | return Image.fromarray(op(image=np.asarray(img))) 63 | 64 | 65 | def poisson_noise(img, lam, **_): 66 | lam = _get_param(lam, img, 0.2) | 1 # bin to odd values 67 | key = 'poisson_noise_' + str(lam) 68 | op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam)) 69 | return Image.fromarray(op(image=np.asarray(img))) 70 | 71 | 72 | def _level_to_arg(level, _hparams, max): 73 | level = max * level / auto_augment._LEVEL_DENOM 74 | return level, 75 | 76 | 77 | _RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy() 78 | _RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops 79 | _RAND_TRANSFORMS.extend([ 80 | 'GaussianBlur', 81 | # 'MotionBlur', 82 | # 'GaussianNoise', 83 | 'PoissonNoise' 84 | ]) 85 | auto_augment.LEVEL_TO_ARG.update({ 86 | 'GaussianBlur': partial(_level_to_arg, max=4), 87 | 'MotionBlur': partial(_level_to_arg, max=20), 88 | 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255), 89 | 'PoissonNoise': partial(_level_to_arg, max=40) 90 | }) 91 | auto_augment.NAME_TO_OP.update({ 92 | 'GaussianBlur': gaussian_blur, 93 | 'MotionBlur': motion_blur, 94 | 'GaussianNoise': gaussian_noise, 95 | 'PoissonNoise': poisson_noise 96 | }) 97 | 98 | 99 | def rand_augment_transform(magnitude=5, num_layers=3): 100 | # These are tuned for magnitude=5, which means that effective magnitudes are half of these values. 101 | hparams = { 102 | 'rotate_deg': 30, 103 | 'shear_x_pct': 0.9, 104 | 'shear_y_pct': 0.2, 105 | 'translate_x_pct': 0.10, 106 | 'translate_y_pct': 0.30 107 | } 108 | ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS) 109 | # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice) 110 | choice_weights = [1. 
/ len(ra_ops) for _ in range(len(ra_ops))] 111 | return auto_augment.RandAugment(ra_ops, num_layers, choice_weights) 112 | -------------------------------------------------------------------------------- /src/parseq/strhub/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import glob 16 | import io 17 | import logging 18 | import unicodedata 19 | from pathlib import Path, PurePath 20 | from typing import Callable, Optional, Union 21 | 22 | import lmdb 23 | from PIL import Image 24 | from torch.utils.data import Dataset, ConcatDataset 25 | 26 | from strhub.data.utils import CharsetAdapter 27 | 28 | log = logging.getLogger(__name__) 29 | 30 | 31 | def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs): 32 | try: 33 | kwargs.pop('root') # prevent 'root' from being passed via kwargs 34 | except KeyError: 35 | pass 36 | root = Path(root).absolute() 37 | log.info(f'dataset root:\t{root}') 38 | datasets = [] 39 | for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True): 40 | mdb = Path(mdb) 41 | ds_name = str(mdb.parent.relative_to(root)) 42 | ds_root = str(mdb.parent.absolute()) 43 | dataset = LmdbDataset(ds_root, *args, **kwargs) 44 | log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}') 45 | datasets.append(dataset) 46 | return ConcatDataset(datasets) 47 | 48 | 49 | class LmdbDataset(Dataset): 50 | """Dataset interface to an LMDB database. 51 | 52 | It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned 53 | as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset. 54 | Labels are transformed according to the charset. 
55 | """ 56 | 57 | def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0, 58 | remove_whitespace: bool = True, normalize_unicode: bool = True, 59 | unlabelled: bool = False, transform: Optional[Callable] = None): 60 | self._env = None 61 | self.root = root 62 | self.unlabelled = unlabelled 63 | self.transform = transform 64 | self.labels = [] 65 | self.filtered_index_list = [] 66 | self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode, 67 | max_label_len, min_image_dim) 68 | 69 | def __del__(self): 70 | if self._env is not None: 71 | self._env.close() 72 | self._env = None 73 | 74 | def _create_env(self): 75 | return lmdb.open(self.root, max_readers=1, readonly=True, create=False, 76 | readahead=False, meminit=False, lock=False) 77 | 78 | @property 79 | def env(self): 80 | if self._env is None: 81 | self._env = self._create_env() 82 | return self._env 83 | 84 | def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim): 85 | charset_adapter = CharsetAdapter(charset) 86 | with self._create_env() as env, env.begin() as txn: 87 | num_samples = int(txn.get('num-samples'.encode())) 88 | if self.unlabelled: 89 | return num_samples 90 | for index in range(num_samples): 91 | index += 1 # lmdb starts with 1 92 | label_key = f'label-{index:09d}'.encode() 93 | label = txn.get(label_key).decode() 94 | # Normally, whitespace is removed from the labels. 95 | if remove_whitespace: 96 | label = ''.join(label.split()) 97 | # Normalize unicode composites (if any) and convert to compatible ASCII characters 98 | if normalize_unicode: 99 | label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode() 100 | # Filter by length before removing unsupported characters. The original label might be too long. 101 | if len(label) > max_label_len: 102 | continue 103 | label = charset_adapter(label) 104 | # We filter out samples which don't contain any supported characters 105 | if not label: 106 | continue 107 | # Filter images that are too small. 108 | if min_image_dim > 0: 109 | img_key = f'image-{index:09d}'.encode() 110 | buf = io.BytesIO(txn.get(img_key)) 111 | w, h = Image.open(buf).size 112 | if w < self.min_image_dim or h < self.min_image_dim: 113 | continue 114 | self.labels.append(label) 115 | self.filtered_index_list.append(index) 116 | return len(self.labels) 117 | 118 | def __len__(self): 119 | return self.num_samples 120 | 121 | def __getitem__(self, index): 122 | if self.unlabelled: 123 | label = index 124 | else: 125 | label = self.labels[index] 126 | index = self.filtered_index_list[index] 127 | 128 | img_key = f'image-{index:09d}'.encode() 129 | with self.env.begin() as txn: 130 | imgbuf = txn.get(img_key) 131 | buf = io.BytesIO(imgbuf) 132 | img = Image.open(buf).convert('RGB') 133 | 134 | if self.transform is not None: 135 | img = self.transform(img) 136 | 137 | return img, label 138 | -------------------------------------------------------------------------------- /src/parseq/strhub/data/module.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from pathlib import PurePath 17 | from typing import Optional, Callable, Sequence, Tuple 18 | 19 | import pytorch_lightning as pl 20 | from torch.utils.data import DataLoader 21 | from torchvision import transforms as T 22 | 23 | from .dataset import build_tree_dataset, LmdbDataset 24 | 25 | 26 | class SceneTextDataModule(pl.LightningDataModule): 27 | TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80') 28 | TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80') 29 | TEST_NEW = ('ArT', 'COCOv1.4', 'Uber') 30 | TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW)) 31 | 32 | def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int, 33 | charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool, 34 | remove_whitespace: bool = True, normalize_unicode: bool = True, 35 | min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None): 36 | super().__init__() 37 | self.root_dir = root_dir 38 | self.train_dir = train_dir 39 | self.img_size = tuple(img_size) 40 | self.max_label_length = max_label_length 41 | self.charset_train = charset_train 42 | self.charset_test = charset_test 43 | self.batch_size = batch_size 44 | self.num_workers = num_workers 45 | self.augment = augment 46 | self.remove_whitespace = remove_whitespace 47 | self.normalize_unicode = normalize_unicode 48 | self.min_image_dim = min_image_dim 49 | self.rotation = rotation 50 | self.collate_fn = collate_fn 51 | self._train_dataset = None 52 | self._val_dataset = None 53 | 54 | @staticmethod 55 | def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0): 56 | transforms = [] 57 | if augment: 58 | from .augment import rand_augment_transform 59 | transforms.append(rand_augment_transform()) 60 | if rotation: 61 | transforms.append(lambda img: img.rotate(rotation, expand=True)) 62 | transforms.extend([ 63 | T.Resize(img_size, T.InterpolationMode.BICUBIC), 64 | T.ToTensor(), 65 | T.Normalize(0.5, 0.5) 66 | ]) 67 | return T.Compose(transforms) 68 | 69 | @property 70 | def train_dataset(self): 71 | if self._train_dataset is None: 72 | transform = self.get_transform(self.img_size, self.augment) 73 | root = PurePath(self.root_dir, 'train', self.train_dir) 74 | self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length, 75 | self.min_image_dim, self.remove_whitespace, self.normalize_unicode, 76 | transform=transform) 77 | return self._train_dataset 78 | 79 | @property 80 | def val_dataset(self): 81 | if self._val_dataset is None: 82 | transform = self.get_transform(self.img_size) 83 | root = PurePath(self.root_dir, 'val') 84 | self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length, 85 | self.min_image_dim, self.remove_whitespace, self.normalize_unicode, 86 | transform=transform) 87 | return self._val_dataset 88 | 89 | def train_dataloader(self): 90 | return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, 91 | 
num_workers=self.num_workers, persistent_workers=self.num_workers > 0, 92 | pin_memory=True, collate_fn=self.collate_fn) 93 | 94 | def val_dataloader(self): 95 | return DataLoader(self.val_dataset, batch_size=self.batch_size, 96 | num_workers=self.num_workers, persistent_workers=self.num_workers > 0, 97 | pin_memory=True, collate_fn=self.collate_fn) 98 | 99 | def test_dataloaders(self, subset): 100 | transform = self.get_transform(self.img_size, rotation=self.rotation) 101 | root = PurePath(self.root_dir, 'test') 102 | datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length, 103 | self.min_image_dim, self.remove_whitespace, self.normalize_unicode, 104 | transform=transform) for s in subset} 105 | return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers, 106 | pin_memory=True, collate_fn=self.collate_fn) 107 | for k, v in datasets.items()} 108 | -------------------------------------------------------------------------------- /src/parseq/strhub/data/utils.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | from abc import ABC, abstractmethod 18 | from itertools import groupby 19 | from typing import List, Optional, Tuple 20 | 21 | import torch 22 | from torch import Tensor 23 | from torch.nn.utils.rnn import pad_sequence 24 | 25 | 26 | class CharsetAdapter: 27 | """Transforms labels according to the target charset.""" 28 | 29 | def __init__(self, target_charset) -> None: 30 | super().__init__() 31 | self.lowercase_only = target_charset == target_charset.lower() 32 | self.uppercase_only = target_charset == target_charset.upper() 33 | self.unsupported = f'[^{re.escape(target_charset)}]' 34 | 35 | def __call__(self, label): 36 | if self.lowercase_only: 37 | label = label.lower() 38 | elif self.uppercase_only: 39 | label = label.upper() 40 | # Remove unsupported characters 41 | label = re.sub(self.unsupported, '', label) 42 | return label 43 | 44 | 45 | class BaseTokenizer(ABC): 46 | 47 | def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: 48 | self._itos = specials_first + tuple(charset) + specials_last 49 | self._stoi = {s: i for i, s in enumerate(self._itos)} 50 | 51 | def __len__(self): 52 | return len(self._itos) 53 | 54 | def _tok2ids(self, tokens: str) -> List[int]: 55 | return [self._stoi[s] for s in tokens] 56 | 57 | def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: 58 | tokens = [self._itos[i] for i in token_ids] 59 | return ''.join(tokens) if join else tokens 60 | 61 | @abstractmethod 62 | def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: 63 | """Encode a batch of labels to a representation suitable for the model. 64 | 65 | Args: 66 | labels: List of labels. Each can be of arbitrary length. 
67 | device: Create tensor on this device. 68 | 69 | Returns: 70 | Batched tensor representation padded to the max label length. Shape: N, L 71 | """ 72 | raise NotImplementedError 73 | 74 | @abstractmethod 75 | def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: 76 | """Internal method which performs the necessary filtering prior to decoding.""" 77 | raise NotImplementedError 78 | 79 | def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: 80 | """Decode a batch of token distributions. 81 | 82 | Args: 83 | token_dists: softmax probabilities over the token distribution. Shape: N, L, C 84 | raw: return unprocessed labels (will return list of list of strings) 85 | 86 | Returns: 87 | list of string labels (arbitrary length) and 88 | their corresponding sequence probabilities as a list of Tensors 89 | """ 90 | batch_tokens = [] 91 | batch_probs = [] 92 | for dist in token_dists: 93 | probs, ids = dist.max(-1) # greedy selection 94 | if not raw: 95 | probs, ids = self._filter(probs, ids) 96 | tokens = self._ids2tok(ids, not raw) 97 | batch_tokens.append(tokens) 98 | batch_probs.append(probs) 99 | return batch_tokens, batch_probs 100 | 101 | 102 | class Tokenizer(BaseTokenizer): 103 | BOS = '[B]' 104 | EOS = '[E]' 105 | PAD = '[P]' 106 | 107 | def __init__(self, charset: str) -> None: 108 | specials_first = (self.EOS,) 109 | specials_last = (self.BOS, self.PAD) 110 | super().__init__(charset, specials_first, specials_last) 111 | self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] 112 | 113 | def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: 114 | batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device) 115 | for y in labels] 116 | 117 | return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) 118 | 119 | def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: 120 | ids = ids.tolist() 121 | try: 122 | eos_idx = ids.index(self.eos_id) 123 | except ValueError: 124 | eos_idx = len(ids) # Nothing to truncate. 125 | # Truncate after EOS 126 | ids = ids[:eos_idx] 127 | probs = probs[:eos_idx + 1] # but include prob. 
for EOS (if it exists) 128 | return probs, ids 129 | 130 | 131 | class CTCTokenizer(BaseTokenizer): 132 | BLANK = '[B]' 133 | 134 | def __init__(self, charset: str) -> None: 135 | # BLANK uses index == 0 by default 136 | super().__init__(charset, specials_first=(self.BLANK,)) 137 | self.blank_id = self._stoi[self.BLANK] 138 | 139 | def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: 140 | # We use a padded representation since we don't want to use CUDNN's CTC implementation 141 | batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels] 142 | return pad_sequence(batch, batch_first=True, padding_value=self.blank_id) 143 | 144 | def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: 145 | # Best path decoding: 146 | ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens 147 | ids = [x for x in ids if x != self.blank_id] # Remove BLANKs 148 | # `probs` is just pass-through since all positions are considered part of the path 149 | return probs, ids 150 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/strhub/models/__init__.py -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/LICENSE: -------------------------------------------------------------------------------- 1 | ABINet for non-commercial purposes 2 | 3 | Copyright (c) 2021, USTC 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Fang, Shancheng, Hongtao, Xie, Yuxin, Wang, Zhendong, Mao, and Yongdong, Zhang. 3 | "Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition." . 4 | In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 
7098-7107).2021. 5 | 6 | https://arxiv.org/abs/2103.06495 7 | 8 | All source files, except `system.py`, are based on the implementation listed below, 9 | and hence are released under the license of the original. 10 | 11 | Source: https://github.com/FangShancheng/ABINet 12 | License: 2-clause BSD License (see included LICENSE file) 13 | """ 14 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .transformer import PositionalEncoding 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, in_channels=512, max_length=25, n_feature=256): 9 | super().__init__() 10 | self.max_length = max_length 11 | 12 | self.f0_embedding = nn.Embedding(max_length, in_channels) 13 | self.w0 = nn.Linear(max_length, n_feature) 14 | self.wv = nn.Linear(in_channels, in_channels) 15 | self.we = nn.Linear(in_channels, max_length) 16 | 17 | self.active = nn.Tanh() 18 | self.softmax = nn.Softmax(dim=2) 19 | 20 | def forward(self, enc_output): 21 | enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) 22 | reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) 23 | reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) 24 | reading_order_embed = self.f0_embedding(reading_order) # b,25,512 25 | 26 | t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 27 | t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 28 | 29 | attn = self.we(t) # b,256,25 30 | attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 31 | g_output = torch.bmm(attn, enc_output) # b,25,512 32 | return g_output, attn.view(*attn.shape[:2], 8, 32) 33 | 34 | 35 | def encoder_layer(in_c, out_c, k=3, s=2, p=1): 36 | return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), 37 | nn.BatchNorm2d(out_c), 38 | nn.ReLU(True)) 39 | 40 | 41 | def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): 42 | align_corners = None if mode == 'nearest' else True 43 | return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, 44 | mode=mode, align_corners=align_corners), 45 | nn.Conv2d(in_c, out_c, k, s, p), 46 | nn.BatchNorm2d(out_c), 47 | nn.ReLU(True)) 48 | 49 | 50 | class PositionAttention(nn.Module): 51 | def __init__(self, max_length, in_channels=512, num_channels=64, 52 | h=8, w=32, mode='nearest', **kwargs): 53 | super().__init__() 54 | self.max_length = max_length 55 | self.k_encoder = nn.Sequential( 56 | encoder_layer(in_channels, num_channels, s=(1, 2)), 57 | encoder_layer(num_channels, num_channels, s=(2, 2)), 58 | encoder_layer(num_channels, num_channels, s=(2, 2)), 59 | encoder_layer(num_channels, num_channels, s=(2, 2)) 60 | ) 61 | self.k_decoder = nn.Sequential( 62 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 63 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 64 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 65 | decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) 66 | ) 67 | 68 | self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length) 69 | self.project = nn.Linear(in_channels, in_channels) 70 | 71 | def forward(self, x): 72 | N, E, H, W = x.size() 73 | k, v = x, x # (N, E, H, W) 74 | 75 | # calculate key vector 76 | features = [] 77 | for i in range(0, len(self.k_encoder)): 78 | k = 
self.k_encoder[i](k) 79 | features.append(k) 80 | for i in range(0, len(self.k_decoder) - 1): 81 | k = self.k_decoder[i](k) 82 | k = k + features[len(self.k_decoder) - 2 - i] 83 | k = self.k_decoder[-1](k) 84 | 85 | # calculate query vector 86 | # TODO q=f(q,k) 87 | zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) 88 | q = self.pos_encoder(zeros) # (T, N, E) 89 | q = q.permute(1, 0, 2) # (N, T, E) 90 | q = self.project(q) # (N, T, E) 91 | 92 | # calculate attention 93 | attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) 94 | attn_scores = attn_scores / (E ** 0.5) 95 | attn_scores = torch.softmax(attn_scores, dim=-1) 96 | 97 | v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) 98 | attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) 99 | 100 | return attn_vecs, attn_scores.view(N, -1, H, W) 101 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import TransformerEncoderLayer, TransformerEncoder 3 | 4 | from .resnet import resnet45 5 | from .transformer import PositionalEncoding 6 | 7 | 8 | class ResTranformer(nn.Module): 9 | def __init__(self, d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', backbone_ln=2): 10 | super().__init__() 11 | self.resnet = resnet45() 12 | self.pos_encoder = PositionalEncoding(d_model, max_len=8 * 32) 13 | encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, 14 | dim_feedforward=d_inner, dropout=dropout, activation=activation) 15 | self.transformer = TransformerEncoder(encoder_layer, backbone_ln) 16 | 17 | def forward(self, images): 18 | feature = self.resnet(images) 19 | n, c, h, w = feature.shape 20 | feature = feature.view(n, c, -1).permute(2, 0, 1) 21 | feature = self.pos_encoder(feature) 22 | feature = self.transformer(feature) 23 | feature = feature.permute(1, 2, 0).view(n, c, h, w) 24 | return feature 25 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Model(nn.Module): 6 | 7 | def __init__(self, dataset_max_length: int, null_label: int): 8 | super().__init__() 9 | self.max_length = dataset_max_length + 1 # additional stop token 10 | self.null_label = null_label 11 | 12 | def _get_length(self, logit, dim=-1): 13 | """ Greed decoder to obtain length from logit""" 14 | out = (logit.argmax(dim=-1) == self.null_label) 15 | abn = out.any(dim) 16 | out = ((out.cumsum(dim) == 1) & out).max(dim)[1] 17 | out = out + 1 # additional end token 18 | out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device)) 19 | return out 20 | 21 | @staticmethod 22 | def _get_padding_mask(length, max_length): 23 | length = length.unsqueeze(-1) 24 | grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) 25 | return grid >= length 26 | 27 | @staticmethod 28 | def _get_location_mask(sz, device=None): 29 | mask = torch.eye(sz, device=device) 30 | mask = mask.float().masked_fill(mask == 1, float('-inf')) 31 | return mask 32 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/model_abinet_iter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 
from .model_alignment import BaseAlignment 5 | from .model_language import BCNLanguage 6 | from .model_vision import BaseVision 7 | 8 | 9 | class ABINetIterModel(nn.Module): 10 | def __init__(self, dataset_max_length, null_label, num_classes, iter_size=1, 11 | d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', 12 | v_loss_weight=1., v_attention='position', v_attention_mode='nearest', 13 | v_backbone='transformer', v_num_layers=2, 14 | l_loss_weight=1., l_num_layers=4, l_detach=True, l_use_self_attn=False, 15 | a_loss_weight=1.): 16 | super().__init__() 17 | self.iter_size = iter_size 18 | self.vision = BaseVision(dataset_max_length, null_label, num_classes, v_attention, v_attention_mode, 19 | v_loss_weight, d_model, nhead, d_inner, dropout, activation, v_backbone, v_num_layers) 20 | self.language = BCNLanguage(dataset_max_length, null_label, num_classes, d_model, nhead, d_inner, dropout, 21 | activation, l_num_layers, l_detach, l_use_self_attn, l_loss_weight) 22 | self.alignment = BaseAlignment(dataset_max_length, null_label, num_classes, d_model, a_loss_weight) 23 | 24 | def forward(self, images): 25 | v_res = self.vision(images) 26 | a_res = v_res 27 | all_l_res, all_a_res = [], [] 28 | for _ in range(self.iter_size): 29 | tokens = torch.softmax(a_res['logits'], dim=-1) 30 | lengths = a_res['pt_lengths'] 31 | lengths.clamp_(2, self.language.max_length) # TODO:move to langauge model 32 | l_res = self.language(tokens, lengths) 33 | all_l_res.append(l_res) 34 | a_res = self.alignment(l_res['feature'], v_res['feature']) 35 | all_a_res.append(a_res) 36 | if self.training: 37 | return all_a_res, all_l_res, v_res 38 | else: 39 | return a_res, all_l_res[-1], v_res 40 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/model_alignment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .model import Model 5 | 6 | 7 | class BaseAlignment(Model): 8 | def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, loss_weight=1.0): 9 | super().__init__(dataset_max_length, null_label) 10 | self.loss_weight = loss_weight 11 | self.w_att = nn.Linear(2 * d_model, d_model) 12 | self.cls = nn.Linear(d_model, num_classes) 13 | 14 | def forward(self, l_feature, v_feature): 15 | """ 16 | Args: 17 | l_feature: (N, T, E) where T is length, N is batch size and d is dim of model 18 | v_feature: (N, T, E) shape the same as l_feature 19 | """ 20 | f = torch.cat((l_feature, v_feature), dim=2) 21 | f_att = torch.sigmoid(self.w_att(f)) 22 | output = f_att * v_feature + (1 - f_att) * l_feature 23 | 24 | logits = self.cls(output) # (N, T, C) 25 | pt_lengths = self._get_length(logits) 26 | 27 | return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight, 28 | 'name': 'alignment'} 29 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/model_language.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import TransformerDecoder 3 | 4 | from .model import Model 5 | from .transformer import PositionalEncoding, TransformerDecoderLayer 6 | 7 | 8 | class BCNLanguage(Model): 9 | def __init__(self, dataset_max_length, null_label, num_classes, d_model=512, nhead=8, d_inner=2048, dropout=0.1, 10 | activation='relu', num_layers=4, detach=True, use_self_attn=False, loss_weight=1.0, 11 | 
global_debug=False): 12 | super().__init__(dataset_max_length, null_label) 13 | self.detach = detach 14 | self.loss_weight = loss_weight 15 | self.proj = nn.Linear(num_classes, d_model, False) 16 | self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length) 17 | self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) 18 | decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, 19 | activation, self_attn=use_self_attn, debug=global_debug) 20 | self.model = TransformerDecoder(decoder_layer, num_layers) 21 | self.cls = nn.Linear(d_model, num_classes) 22 | 23 | def forward(self, tokens, lengths): 24 | """ 25 | Args: 26 | tokens: (N, T, C) where T is length, N is batch size and C is classes number 27 | lengths: (N,) 28 | """ 29 | if self.detach: 30 | tokens = tokens.detach() 31 | embed = self.proj(tokens) # (N, T, E) 32 | embed = embed.permute(1, 0, 2) # (T, N, E) 33 | embed = self.token_encoder(embed) # (T, N, E) 34 | padding_mask = self._get_padding_mask(lengths, self.max_length) 35 | 36 | zeros = embed.new_zeros(*embed.shape) 37 | qeury = self.pos_encoder(zeros) 38 | location_mask = self._get_location_mask(self.max_length, tokens.device) 39 | output = self.model(qeury, embed, 40 | tgt_key_padding_mask=padding_mask, 41 | memory_mask=location_mask, 42 | memory_key_padding_mask=padding_mask) # (T, N, E) 43 | output = output.permute(1, 0, 2) # (N, T, E) 44 | 45 | logits = self.cls(output) # (N, T, C) 46 | pt_lengths = self._get_length(logits) 47 | 48 | res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, 49 | 'loss_weight': self.loss_weight, 'name': 'language'} 50 | return res 51 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/model_vision.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .attention import PositionAttention, Attention 4 | from .backbone import ResTranformer 5 | from .model import Model 6 | from .resnet import resnet45 7 | 8 | 9 | class BaseVision(Model): 10 | def __init__(self, dataset_max_length, null_label, num_classes, 11 | attention='position', attention_mode='nearest', loss_weight=1.0, 12 | d_model=512, nhead=8, d_inner=2048, dropout=0.1, activation='relu', 13 | backbone='transformer', backbone_ln=2): 14 | super().__init__(dataset_max_length, null_label) 15 | self.loss_weight = loss_weight 16 | self.out_channels = d_model 17 | 18 | if backbone == 'transformer': 19 | self.backbone = ResTranformer(d_model, nhead, d_inner, dropout, activation, backbone_ln) 20 | else: 21 | self.backbone = resnet45() 22 | 23 | if attention == 'position': 24 | self.attention = PositionAttention( 25 | max_length=self.max_length, 26 | mode=attention_mode 27 | ) 28 | elif attention == 'attention': 29 | self.attention = Attention( 30 | max_length=self.max_length, 31 | n_feature=8 * 32, 32 | ) 33 | else: 34 | raise ValueError(f'invalid attention: {attention}') 35 | 36 | self.cls = nn.Linear(self.out_channels, num_classes) 37 | 38 | def forward(self, images): 39 | features = self.backbone(images) # (N, E, H, W) 40 | attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) 41 | logits = self.cls(attn_vecs) # (N, T, C) 42 | pt_lengths = self._get_length(logits) 43 | 44 | return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, 45 | 'attn_scores': attn_scores, 'loss_weight': self.loss_weight, 'name': 'vision'} 46 | 
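A quick orientation note on the file above: `BaseVision` wraps the ResNet/Transformer backbone and the positional attention module, and maps an image directly to per-position character logits plus a greedy length estimate. The following is only a minimal usage sketch, not part of the repository; it assumes `src/parseq` is on the Python path, the standard 32x128 ABINet input resolution, and a hypothetical 37-class output (36 characters plus the null/end label at index 0).

# Hedged usage sketch for BaseVision (assumptions listed above; not from the original sources).
import torch
from strhub.models.abinet.model_vision import BaseVision

vision = BaseVision(dataset_max_length=25, null_label=0, num_classes=37)
images = torch.randn(2, 3, 32, 128)   # (N, C, H, W); resnet45 downsamples this to an 8x32 feature map
out = vision(images)
print(out['logits'].shape)            # torch.Size([2, 26, 37]) -> (N, T, C), where T = dataset_max_length + 1
print(out['pt_lengths'])              # per-sample greedy lengths taken from the first null-label position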
-------------------------------------------------------------------------------- /src/parseq/strhub/models/abinet/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Callable 3 | 4 | import torch.nn as nn 5 | from torchvision.models import resnet 6 | 7 | 8 | class BasicBlock(resnet.BasicBlock): 9 | 10 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, 11 | groups: int = 1, base_width: int = 64, dilation: int = 1, 12 | norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: 13 | super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer) 14 | self.conv1 = resnet.conv1x1(inplanes, planes) 15 | self.conv2 = resnet.conv3x3(planes, planes, stride) 16 | 17 | 18 | class ResNet(nn.Module): 19 | 20 | def __init__(self, block, layers): 21 | super().__init__() 22 | self.inplanes = 32 23 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, 24 | bias=False) 25 | self.bn1 = nn.BatchNorm2d(32) 26 | self.relu = nn.ReLU(inplace=True) 27 | 28 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 29 | self.layer2 = self._make_layer(block, 64, layers[1], stride=1) 30 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 31 | self.layer4 = self._make_layer(block, 256, layers[3], stride=1) 32 | self.layer5 = self._make_layer(block, 512, layers[4], stride=1) 33 | 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | m.weight.data.normal_(0, math.sqrt(2. / n)) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | m.weight.data.fill_(1) 40 | m.bias.data.zero_() 41 | 42 | def _make_layer(self, block, planes, blocks, stride=1): 43 | downsample = None 44 | if stride != 1 or self.inplanes != planes * block.expansion: 45 | downsample = nn.Sequential( 46 | nn.Conv2d(self.inplanes, planes * block.expansion, 47 | kernel_size=1, stride=stride, bias=False), 48 | nn.BatchNorm2d(planes * block.expansion), 49 | ) 50 | 51 | layers = [] 52 | layers.append(block(self.inplanes, planes, stride, downsample)) 53 | self.inplanes = planes * block.expansion 54 | for i in range(1, blocks): 55 | layers.append(block(self.inplanes, planes)) 56 | 57 | return nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | x = self.conv1(x) 61 | x = self.bn1(x) 62 | x = self.relu(x) 63 | x = self.layer1(x) 64 | x = self.layer2(x) 65 | x = self.layer3(x) 66 | x = self.layer4(x) 67 | x = self.layer5(x) 68 | return x 69 | 70 | 71 | def resnet45(): 72 | return ResNet(BasicBlock, [3, 4, 6, 6, 3]) 73 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/crnn/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017 Jieru Mei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/crnn/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Shi, Baoguang, Xiang Bai, and Cong Yao. 3 | "An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition." 4 | IEEE transactions on pattern analysis and machine intelligence 39, no. 11 (2016): 2298-2304. 5 | 6 | https://arxiv.org/abs/1507.05717 7 | 8 | All source files, except `system.py`, are based on the implementation listed below, 9 | and hence are released under the license of the original. 10 | 11 | Source: https://github.com/meijieru/crnn.pytorch 12 | License: MIT License (see included LICENSE file) 13 | """ 14 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/crnn/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from strhub.models.modules import BidirectionalLSTM 4 | 5 | 6 | class CRNN(nn.Module): 7 | 8 | def __init__(self, img_h, nc, nclass, nh, leaky_relu=False): 9 | super().__init__() 10 | assert img_h % 16 == 0, 'img_h has to be a multiple of 16' 11 | 12 | ks = [3, 3, 3, 3, 3, 3, 2] 13 | ps = [1, 1, 1, 1, 1, 1, 0] 14 | ss = [1, 1, 1, 1, 1, 1, 1] 15 | nm = [64, 128, 256, 256, 512, 512, 512] 16 | 17 | cnn = nn.Sequential() 18 | 19 | def convRelu(i, batchNormalization=False): 20 | nIn = nc if i == 0 else nm[i - 1] 21 | nOut = nm[i] 22 | cnn.add_module('conv{0}'.format(i), 23 | nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i], bias=not batchNormalization)) 24 | if batchNormalization: 25 | cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) 26 | if leaky_relu: 27 | cnn.add_module('relu{0}'.format(i), 28 | nn.LeakyReLU(0.2, inplace=True)) 29 | else: 30 | cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) 31 | 32 | convRelu(0) 33 | cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 34 | convRelu(1) 35 | cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 36 | convRelu(2, True) 37 | convRelu(3) 38 | cnn.add_module('pooling{0}'.format(2), 39 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 40 | convRelu(4, True) 41 | convRelu(5) 42 | cnn.add_module('pooling{0}'.format(3), 43 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 44 | convRelu(6, True) # 512x1x16 45 | 46 | self.cnn = cnn 47 | self.rnn = nn.Sequential( 48 | BidirectionalLSTM(512, nh, nh), 49 | BidirectionalLSTM(nh, nh, nclass)) 50 | 51 | def forward(self, input): 52 | # conv features 53 | conv = self.cnn(input) 54 | b, c, h, w = conv.size() 55 | assert h == 1, 'the height of conv must be 1' 56 | conv = conv.squeeze(2) 57 | conv = conv.transpose(1, 2) # [b, w, c] 58 | 59 | # rnn features 60 | output = self.rnn(conv) 61 | 62 | return output 63 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/crnn/system.py: 
-------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Sequence, Optional 17 | 18 | from pytorch_lightning.utilities.types import STEP_OUTPUT 19 | from torch import Tensor 20 | 21 | from strhub.models.base import CTCSystem 22 | from strhub.models.utils import init_weights 23 | from .model import CRNN as Model 24 | 25 | 26 | class CRNN(CTCSystem): 27 | 28 | def __init__(self, charset_train: str, charset_test: str, max_label_length: int, 29 | batch_size: int, lr: float, warmup_pct: float, weight_decay: float, 30 | img_size: Sequence[int], hidden_size: int, leaky_relu: bool, **kwargs) -> None: 31 | super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) 32 | self.save_hyperparameters() 33 | self.model = Model(img_size[0], 3, len(self.tokenizer), hidden_size, leaky_relu) 34 | self.model.apply(init_weights) 35 | 36 | def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: 37 | return self.model.forward(images) 38 | 39 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 40 | images, labels = batch 41 | loss = self.forward_logits_loss(images, labels)[1] 42 | self.log('loss', loss) 43 | return loss 44 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/modules.py: -------------------------------------------------------------------------------- 1 | r"""Shared modules used by CRNN and TRBA""" 2 | from torch import nn 3 | 4 | 5 | class BidirectionalLSTM(nn.Module): 6 | """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py""" 7 | 8 | def __init__(self, input_size, hidden_size, output_size): 9 | super().__init__() 10 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 11 | self.linear = nn.Linear(hidden_size * 2, output_size) 12 | 13 | def forward(self, input): 14 | """ 15 | input : visual feature [batch_size x T x input_size], T = num_steps. 
16 | output : contextual feature [batch_size x T x output_size] 17 | """ 18 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 19 | output = self.linear(recurrent) # batch_size x T x output_size 20 | return output 21 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/parseq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeGoat24/DreamText/5b3b7e8fa6cec94a10138fa181644bc4192f9778/src/parseq/strhub/models/parseq/__init__.py -------------------------------------------------------------------------------- /src/parseq/strhub/models/parseq/modules.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import math 17 | from typing import Optional 18 | 19 | import torch 20 | from torch import nn as nn, Tensor 21 | from torch.nn import functional as F 22 | from torch.nn.modules import transformer 23 | 24 | from timm.models.vision_transformer import VisionTransformer, PatchEmbed 25 | 26 | 27 | class DecoderLayer(nn.Module): 28 | """A Transformer decoder layer supporting two-stream attention (XLNet) 29 | This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" 30 | 31 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', 32 | layer_norm_eps=1e-5): 33 | super().__init__() 34 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 35 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 36 | # Implementation of Feedforward model 37 | self.linear1 = nn.Linear(d_model, dim_feedforward) 38 | self.dropout = nn.Dropout(dropout) 39 | self.linear2 = nn.Linear(dim_feedforward, d_model) 40 | 41 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 42 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 43 | self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) 44 | self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) 45 | self.dropout1 = nn.Dropout(dropout) 46 | self.dropout2 = nn.Dropout(dropout) 47 | self.dropout3 = nn.Dropout(dropout) 48 | 49 | self.activation = transformer._get_activation_fn(activation) 50 | 51 | def __setstate__(self, state): 52 | if 'activation' not in state: 53 | state['activation'] = F.gelu 54 | super().__setstate__(state) 55 | 56 | def forward_stream(self, tgt: Tensor, tgt_norm: Tensor, tgt_kv: Tensor, memory: Tensor, tgt_mask: Optional[Tensor], 57 | tgt_key_padding_mask: Optional[Tensor]): 58 | """Forward pass for a single stream (i.e. content or query) 59 | tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. 60 | Both tgt_kv and memory are expected to be LayerNorm'd too. 61 | memory is LayerNorm'd by ViT. 
62 | """ 63 | tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, 64 | key_padding_mask=tgt_key_padding_mask) 65 | tgt = tgt + self.dropout1(tgt2) 66 | 67 | tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) 68 | tgt = tgt + self.dropout2(tgt2) 69 | 70 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt))))) 71 | tgt = tgt + self.dropout3(tgt2) 72 | return tgt, sa_weights, ca_weights 73 | 74 | def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, 75 | content_key_padding_mask: Optional[Tensor] = None, update_content: bool = True): 76 | query_norm = self.norm_q(query) 77 | content_norm = self.norm_c(content) 78 | query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0] 79 | if update_content: 80 | content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, 81 | content_key_padding_mask)[0] 82 | return query, content 83 | 84 | 85 | class Decoder(nn.Module): 86 | __constants__ = ['norm'] 87 | 88 | def __init__(self, decoder_layer, num_layers, norm): 89 | super().__init__() 90 | self.layers = transformer._get_clones(decoder_layer, num_layers) 91 | self.num_layers = num_layers 92 | self.norm = norm 93 | 94 | def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, 95 | content_key_padding_mask: Optional[Tensor] = None): 96 | for i, mod in enumerate(self.layers): 97 | last = i == len(self.layers) - 1 98 | query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, 99 | update_content=not last) 100 | query = self.norm(query) 101 | return query 102 | 103 | 104 | class Encoder(VisionTransformer): 105 | 106 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 107 | qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed): 108 | super().__init__(img_size, patch_size, in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads, 109 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 110 | drop_path_rate=drop_path_rate, embed_layer=embed_layer, 111 | num_classes=0, global_pool='', class_token=False) # these disable the classifier head 112 | 113 | def forward(self, x): 114 | # Return all tokens 115 | return self.forward_features(x) 116 | 117 | 118 | class TokenEmbedding(nn.Module): 119 | 120 | def __init__(self, charset_size: int, embed_dim: int): 121 | super().__init__() 122 | self.embedding = nn.Embedding(charset_size, embed_dim) 123 | self.embed_dim = embed_dim 124 | 125 | def forward(self, tokens: torch.Tensor): 126 | return math.sqrt(self.embed_dim) * self.embedding(tokens) 127 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/trba/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Baek, Jeonghun, Geewook Kim, Junyeop Lee, Sungrae Park, Dongyoon Han, Sangdoo Yun, Seong Joon Oh, and Hwalsuk Lee. 3 | "What is wrong with scene text recognition model comparisons? dataset and model analysis." 4 | In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4715-4723. 2019. 
5 | 6 | https://arxiv.org/abs/1904.01906 7 | 8 | All source files, except `system.py`, are based on the implementation listed below, 9 | and hence are released under the license of the original. 10 | 11 | Source: https://github.com/clovaai/deep-text-recognition-benchmark 12 | License: Apache License 2.0 (see LICENSE file in project root) 13 | """ 14 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/trba/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torchvision.models.resnet import BasicBlock 4 | 5 | 6 | class ResNet_FeatureExtractor(nn.Module): 7 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 8 | 9 | def __init__(self, input_channel, output_channel=512): 10 | super().__init__() 11 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 12 | 13 | def forward(self, input): 14 | return self.ConvNet(input) 15 | 16 | 17 | class ResNet(nn.Module): 18 | 19 | def __init__(self, input_channel, output_channel, block, layers): 20 | super().__init__() 21 | 22 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 23 | 24 | self.inplanes = int(output_channel / 8) 25 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 26 | kernel_size=3, stride=1, padding=1, bias=False) 27 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 28 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 29 | kernel_size=3, stride=1, padding=1, bias=False) 30 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 31 | self.relu = nn.ReLU(inplace=True) 32 | 33 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 34 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 35 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 36 | 0], kernel_size=3, stride=1, padding=1, bias=False) 37 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 38 | 39 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 40 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 41 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 42 | 1], kernel_size=3, stride=1, padding=1, bias=False) 43 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 44 | 45 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 46 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 47 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 48 | 2], kernel_size=3, stride=1, padding=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 50 | 51 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 52 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 53 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 54 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 55 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 56 | 3], kernel_size=2, stride=1, padding=0, bias=False) 57 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 58 | 59 | def _make_layer(self, block, planes, blocks, stride=1): 60 | downsample = None 61 | if stride != 1 or 
self.inplanes != planes * block.expansion: 62 | downsample = nn.Sequential( 63 | nn.Conv2d(self.inplanes, planes * block.expansion, 64 | kernel_size=1, stride=stride, bias=False), 65 | nn.BatchNorm2d(planes * block.expansion), 66 | ) 67 | 68 | layers = [] 69 | layers.append(block(self.inplanes, planes, stride, downsample)) 70 | self.inplanes = planes * block.expansion 71 | for i in range(1, blocks): 72 | layers.append(block(self.inplanes, planes)) 73 | 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | x = self.conv0_1(x) 78 | x = self.bn0_1(x) 79 | x = self.relu(x) 80 | x = self.conv0_2(x) 81 | x = self.bn0_2(x) 82 | x = self.relu(x) 83 | 84 | x = self.maxpool1(x) 85 | x = self.layer1(x) 86 | x = self.conv1(x) 87 | x = self.bn1(x) 88 | x = self.relu(x) 89 | 90 | x = self.maxpool2(x) 91 | x = self.layer2(x) 92 | x = self.conv2(x) 93 | x = self.bn2(x) 94 | x = self.relu(x) 95 | 96 | x = self.maxpool3(x) 97 | x = self.layer3(x) 98 | x = self.conv3(x) 99 | x = self.bn3(x) 100 | x = self.relu(x) 101 | 102 | x = self.layer4(x) 103 | x = self.conv4_1(x) 104 | x = self.bn4_1(x) 105 | x = self.relu(x) 106 | x = self.conv4_2(x) 107 | x = self.bn4_2(x) 108 | x = self.relu(x) 109 | 110 | return x 111 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/trba/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from strhub.models.modules import BidirectionalLSTM 4 | from .feature_extraction import ResNet_FeatureExtractor 5 | from .prediction import Attention 6 | from .transformation import TPS_SpatialTransformerNetwork 7 | 8 | 9 | class TRBA(nn.Module): 10 | 11 | def __init__(self, img_h, img_w, num_class, num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256, 12 | use_ctc=False): 13 | super().__init__() 14 | """ Transformation """ 15 | self.Transformation = TPS_SpatialTransformerNetwork( 16 | F=num_fiducial, I_size=(img_h, img_w), I_r_size=(img_h, img_w), 17 | I_channel_num=input_channel) 18 | 19 | """ FeatureExtraction """ 20 | self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel) 21 | self.FeatureExtraction_output = output_channel 22 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 23 | 24 | """ Sequence modeling""" 25 | self.SequenceModeling = nn.Sequential( 26 | BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size), 27 | BidirectionalLSTM(hidden_size, hidden_size, hidden_size)) 28 | self.SequenceModeling_output = hidden_size 29 | 30 | """ Prediction """ 31 | if use_ctc: 32 | self.Prediction = nn.Linear(self.SequenceModeling_output, num_class) 33 | else: 34 | self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class) 35 | 36 | def forward(self, image, max_label_length, text=None): 37 | """ Transformation stage """ 38 | image = self.Transformation(image) 39 | 40 | """ Feature extraction stage """ 41 | visual_feature = self.FeatureExtraction(image) 42 | visual_feature = visual_feature.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h] 43 | visual_feature = self.AdaptiveAvgPool(visual_feature) # [b, w, c, h] -> [b, w, c, 1] 44 | visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c] 45 | 46 | """ Sequence modeling stage """ 47 | contextual_feature = self.SequenceModeling(visual_feature) # [b, num_steps, hidden_size] 48 | 49 | """ Prediction stage """ 50 | if isinstance(self.Prediction, Attention): 51 | prediction 
= self.Prediction(contextual_feature.contiguous(), text, max_label_length) 52 | else: 53 | prediction = self.Prediction(contextual_feature.contiguous()) # CTC 54 | 55 | return prediction # [b, num_steps, num_class] 56 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/trba/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | 8 | def __init__(self, input_size, hidden_size, num_class, num_char_embeddings=256): 9 | super().__init__() 10 | self.attention_cell = AttentionCell(input_size, hidden_size, num_char_embeddings) 11 | self.hidden_size = hidden_size 12 | self.num_class = num_class 13 | self.generator = nn.Linear(hidden_size, num_class) 14 | self.char_embeddings = nn.Embedding(num_class, num_char_embeddings) 15 | 16 | def forward(self, batch_H, text, max_label_length=25): 17 | """ 18 | input: 19 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_class] 20 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS]. 21 | output: probability distribution at each step [batch_size x num_steps x num_class] 22 | """ 23 | batch_size = batch_H.size(0) 24 | num_steps = max_label_length + 1 # +1 for [EOS] at end of sentence. 25 | 26 | output_hiddens = batch_H.new_zeros((batch_size, num_steps, self.hidden_size), dtype=torch.float) 27 | hidden = (batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float), 28 | batch_H.new_zeros((batch_size, self.hidden_size), dtype=torch.float)) 29 | 30 | if self.training: 31 | for i in range(num_steps): 32 | char_embeddings = self.char_embeddings(text[:, i]) 33 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1}) 34 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 35 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 36 | probs = self.generator(output_hiddens) 37 | 38 | else: 39 | targets = text[0].expand(batch_size) # should be fill with [SOS] token 40 | probs = batch_H.new_zeros((batch_size, num_steps, self.num_class), dtype=torch.float) 41 | 42 | for i in range(num_steps): 43 | char_embeddings = self.char_embeddings(targets) 44 | hidden, alpha = self.attention_cell(hidden, batch_H, char_embeddings) 45 | probs_step = self.generator(hidden[0]) 46 | probs[:, i, :] = probs_step 47 | _, next_input = probs_step.max(1) 48 | targets = next_input 49 | 50 | return probs # batch_size x num_steps x num_class 51 | 52 | 53 | class AttentionCell(nn.Module): 54 | 55 | def __init__(self, input_size, hidden_size, num_embeddings): 56 | super().__init__() 57 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 58 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 59 | self.score = nn.Linear(hidden_size, 1, bias=False) 60 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 61 | self.hidden_size = hidden_size 62 | 63 | def forward(self, prev_hidden, batch_H, char_embeddings): 64 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 65 | batch_H_proj = self.i2h(batch_H) 66 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 67 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 68 | 69 | alpha = F.softmax(e, dim=1) 70 | 
context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 71 | concat_context = torch.cat([context, char_embeddings], 1) # batch_size x (num_channel + num_embedding) 72 | cur_hidden = self.rnn(concat_context, prev_hidden) 73 | return cur_hidden, alpha 74 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/trba/system.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from functools import partial 17 | from typing import Sequence, Any, Optional 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | from pytorch_lightning.utilities.types import STEP_OUTPUT 22 | from timm.models.helpers import named_apply 23 | from torch import Tensor 24 | 25 | from strhub.models.base import CrossEntropySystem, CTCSystem 26 | from strhub.models.utils import init_weights 27 | from .model import TRBA as Model 28 | 29 | 30 | class TRBA(CrossEntropySystem): 31 | 32 | def __init__(self, charset_train: str, charset_test: str, max_label_length: int, 33 | batch_size: int, lr: float, warmup_pct: float, weight_decay: float, 34 | img_size: Sequence[int], num_fiducial: int, output_channel: int, hidden_size: int, 35 | **kwargs: Any) -> None: 36 | super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) 37 | self.save_hyperparameters() 38 | self.max_label_length = max_label_length 39 | img_h, img_w = img_size 40 | self.model = Model(img_h, img_w, len(self.tokenizer), num_fiducial, 41 | output_channel=output_channel, hidden_size=hidden_size, use_ctc=False) 42 | named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model) 43 | 44 | @torch.jit.ignore 45 | def no_weight_decay(self): 46 | return {'model.Prediction.char_embeddings.weight'} 47 | 48 | def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: 49 | max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) 50 | text = images.new_full([1], self.bos_id, dtype=torch.long) 51 | return self.model.forward(images, max_length, text) 52 | 53 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 54 | images, labels = batch 55 | encoded = self.tokenizer.encode(labels, self.device) 56 | inputs = encoded[:, :-1] # remove [EOS] 57 | targets = encoded[:, 1:] # remove [BOS] 58 | max_length = encoded.shape[1] - 2 # exclude [BOS] and [EOS] from count 59 | logits = self.model.forward(images, max_length, inputs) 60 | loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id) 61 | self.log('loss', loss) 62 | return loss 63 | 64 | 65 | class TRBC(CTCSystem): 66 | 67 | def __init__(self, charset_train: str, charset_test: str, max_label_length: int, 68 | batch_size: int, lr: float, warmup_pct: float, weight_decay: float, 69 | 
img_size: Sequence[int], num_fiducial: int, output_channel: int, hidden_size: int, 70 | **kwargs: Any) -> None: 71 | super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) 72 | self.save_hyperparameters() 73 | self.max_label_length = max_label_length 74 | img_h, img_w = img_size 75 | self.model = Model(img_h, img_w, len(self.tokenizer), num_fiducial, 76 | output_channel=output_channel, hidden_size=hidden_size, use_ctc=True) 77 | named_apply(partial(init_weights, exclude=['Transformation.LocalizationNetwork.localization_fc2']), self.model) 78 | 79 | def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: 80 | # max_label_length is unused in CTC prediction 81 | return self.model.forward(images, None) 82 | 83 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 84 | images, labels = batch 85 | loss = self.forward_logits_loss(images, labels)[1] 86 | self.log('loss', loss) 87 | return loss 88 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import PurePath 2 | from typing import Sequence 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import yaml 8 | 9 | 10 | class InvalidModelError(RuntimeError): 11 | """Exception raised for any model-related error (creation, loading)""" 12 | 13 | 14 | _WEIGHTS_URL = { 15 | 'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt', 16 | 'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', 17 | 'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt', 18 | 'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt', 19 | 'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt', 20 | 'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt', 21 | } 22 | 23 | 24 | def _get_config(experiment: str, **kwargs): 25 | """Emulates hydra config resolution""" 26 | root = PurePath(__file__).parents[2] 27 | with open(root / 'configs/main.yaml', 'r') as f: 28 | config = yaml.load(f, yaml.Loader)['model'] 29 | with open(root / f'configs/charset/94_full.yaml', 'r') as f: 30 | config.update(yaml.load(f, yaml.Loader)['model']) 31 | with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f: 32 | exp = yaml.load(f, yaml.Loader) 33 | # Apply base model config 34 | model = exp['defaults'][0]['override /model'] 35 | with open(root / f'configs/model/{model}.yaml', 'r') as f: 36 | config.update(yaml.load(f, yaml.Loader)) 37 | # Apply experiment config 38 | if 'model' in exp: 39 | config.update(exp['model']) 40 | config.update(kwargs) 41 | # Workaround for now: manually cast the lr to the correct type. 
42 | config['lr'] = float(config['lr']) 43 | return config 44 | 45 | 46 | def _get_model_class(key): 47 | if 'abinet' in key: 48 | from .abinet.system import ABINet as ModelClass 49 | elif 'crnn' in key: 50 | from .crnn.system import CRNN as ModelClass 51 | elif 'parseq' in key: 52 | from .parseq.system import PARSeq as ModelClass 53 | elif 'trba' in key: 54 | from .trba.system import TRBA as ModelClass 55 | elif 'trbc' in key: 56 | from .trba.system import TRBC as ModelClass 57 | elif 'vitstr' in key: 58 | from .vitstr.system import ViTSTR as ModelClass 59 | else: 60 | raise InvalidModelError("Unable to find model class for '{}'".format(key)) 61 | return ModelClass 62 | 63 | 64 | def get_pretrained_weights(experiment): 65 | try: 66 | url = _WEIGHTS_URL[experiment] 67 | except KeyError: 68 | raise InvalidModelError("No pretrained weights found for '{}'".format(experiment)) from None 69 | return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True) 70 | 71 | 72 | def create_model(experiment: str, pretrained: bool = False, **kwargs): 73 | try: 74 | config = _get_config(experiment, **kwargs) 75 | except FileNotFoundError: 76 | raise InvalidModelError("No configuration found for '{}'".format(experiment)) from None 77 | ModelClass = _get_model_class(experiment) 78 | model = ModelClass(**config) 79 | if pretrained: 80 | model.load_state_dict(get_pretrained_weights(experiment)) 81 | return model 82 | 83 | 84 | def load_from_checkpoint(checkpoint_path: str, **kwargs): 85 | if checkpoint_path.startswith('pretrained='): 86 | model_id = checkpoint_path.split('=', maxsplit=1)[1] 87 | model = create_model(model_id, True, **kwargs) 88 | else: 89 | ModelClass = _get_model_class(checkpoint_path) 90 | model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs) 91 | return model 92 | 93 | 94 | def parse_model_args(args): 95 | kwargs = {} 96 | arg_types = {t.__name__: t for t in [int, float, str]} 97 | arg_types['bool'] = lambda v: v.lower() == 'true' # special handling for bool 98 | for arg in args: 99 | name, value = arg.split('=', maxsplit=1) 100 | name, arg_type = name.split(':', maxsplit=1) 101 | kwargs[name] = arg_types[arg_type](value) 102 | return kwargs 103 | 104 | 105 | def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()): 106 | """Initialize the weights using the typical initialization schemes used in SOTA models.""" 107 | if any(map(name.startswith, exclude)): 108 | return 109 | if isinstance(module, nn.Linear): 110 | nn.init.trunc_normal_(module.weight, std=.02) 111 | if module.bias is not None: 112 | nn.init.zeros_(module.bias) 113 | elif isinstance(module, nn.Embedding): 114 | nn.init.trunc_normal_(module.weight, std=.02) 115 | if module.padding_idx is not None: 116 | module.weight.data[module.padding_idx].zero_() 117 | elif isinstance(module, nn.Conv2d): 118 | nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') 119 | if module.bias is not None: 120 | nn.init.zeros_(module.bias) 121 | elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): 122 | nn.init.ones_(module.weight) 123 | nn.init.zeros_(module.bias) 124 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/vitstr/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Atienza, Rowel. "Vision Transformer for Fast and Efficient Scene Text Recognition." 3 | In International Conference on Document Analysis and Recognition (ICDAR). 2021. 
4 | 5 | https://arxiv.org/abs/2105.08582 6 | 7 | All source files, except `system.py`, are based on the implementation listed below, 8 | and hence are released under the license of the original. 9 | 10 | Source: https://github.com/roatienza/deep-text-recognition-benchmark 11 | License: Apache License 2.0 (see LICENSE file in project root) 12 | """ 13 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/vitstr/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of ViTSTR based on timm VisionTransformer. 3 | 4 | TODO: 5 | 1) distilled deit backbone 6 | 2) base deit backbone 7 | 8 | Copyright 2021 Rowel Atienza 9 | """ 10 | 11 | from timm.models.vision_transformer import VisionTransformer 12 | 13 | 14 | class ViTSTR(VisionTransformer): 15 | """ 16 | ViTSTR is basically a ViT that uses DeiT weights. 17 | Modified head to support a sequence of characters prediction for STR. 18 | """ 19 | 20 | def forward(self, x, seqlen: int = 25): 21 | x = self.forward_features(x) 22 | x = x[:, :seqlen] 23 | 24 | # batch, seqlen, embsize 25 | b, s, e = x.size() 26 | x = x.reshape(b * s, e) 27 | x = self.head(x).view(b, s, self.num_classes) 28 | return x 29 | -------------------------------------------------------------------------------- /src/parseq/strhub/models/vitstr/system.py: -------------------------------------------------------------------------------- 1 | # Scene Text Recognition Model Hub 2 | # Copyright 2022 Darwin Bautista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Sequence, Any, Optional 17 | 18 | import torch 19 | from pytorch_lightning.utilities.types import STEP_OUTPUT 20 | from torch import Tensor 21 | 22 | from strhub.models.base import CrossEntropySystem 23 | from strhub.models.utils import init_weights 24 | from .model import ViTSTR as Model 25 | 26 | 27 | class ViTSTR(CrossEntropySystem): 28 | 29 | def __init__(self, charset_train: str, charset_test: str, max_label_length: int, 30 | batch_size: int, lr: float, warmup_pct: float, weight_decay: float, 31 | img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int, num_heads: int, 32 | **kwargs: Any) -> None: 33 | super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) 34 | self.save_hyperparameters() 35 | self.max_label_length = max_label_length 36 | # We don't predict [BOS] nor [PAD] 37 | self.model = Model(img_size=img_size, patch_size=patch_size, depth=12, mlp_ratio=4, qkv_bias=True, 38 | embed_dim=embed_dim, num_heads=num_heads, num_classes=len(self.tokenizer) - 2) 39 | # Non-zero weight init for the head 40 | self.model.head.apply(init_weights) 41 | 42 | @torch.jit.ignore 43 | def no_weight_decay(self): 44 | return {'model.' 
+ n for n in self.model.no_weight_decay()} 45 | 46 | def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: 47 | max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) 48 | logits = self.model.forward(images, max_length + 2) # +2 tokens for [GO] and [s] 49 | # Truncate to conform to other models. [GO] in ViTSTR is actually used as the padding (therefore, ignored). 50 | # First position corresponds to the class token, which is unused and ignored in the original work. 51 | logits = logits[:, 1:] 52 | return logits 53 | 54 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 55 | images, labels = batch 56 | loss = self.forward_logits_loss(images, labels)[1] 57 | self.log('loss', loss) 58 | return loss 59 | -------------------------------------------------------------------------------- /src/parseq/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Scene Text Recognition Model Hub 3 | # Copyright 2022 Darwin Bautista 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | import string 19 | import sys 20 | from dataclasses import dataclass 21 | from typing import List 22 | 23 | import torch 24 | 25 | from tqdm import tqdm 26 | 27 | from strhub.data.module import SceneTextDataModule 28 | from strhub.models.utils import load_from_checkpoint, parse_model_args 29 | 30 | 31 | @dataclass 32 | class Result: 33 | dataset: str 34 | num_samples: int 35 | accuracy: float 36 | ned: float 37 | confidence: float 38 | label_length: float 39 | 40 | 41 | def print_results_table(results: List[Result], file=None): 42 | w = max(map(len, map(getattr, results, ['dataset'] * len(results)))) 43 | w = max(w, len('Dataset'), len('Combined')) 44 | print('| {:<{w}} | # samples | Accuracy | 1 - NED | Confidence | Label Length |'.format('Dataset', w=w), file=file) 45 | print('|:{:-<{w}}:|----------:|---------:|--------:|-----------:|-------------:|'.format('----', w=w), file=file) 46 | c = Result('Combined', 0, 0, 0, 0, 0) 47 | for res in results: 48 | c.num_samples += res.num_samples 49 | c.accuracy += res.num_samples * res.accuracy 50 | c.ned += res.num_samples * res.ned 51 | c.confidence += res.num_samples * res.confidence 52 | c.label_length += res.num_samples * res.label_length 53 | print(f'| {res.dataset:<{w}} | {res.num_samples:>9} | {res.accuracy:>8.2f} | {res.ned:>7.2f} ' 54 | f'| {res.confidence:>10.2f} | {res.label_length:>12.2f} |', file=file) 55 | c.accuracy /= c.num_samples 56 | c.ned /= c.num_samples 57 | c.confidence /= c.num_samples 58 | c.label_length /= c.num_samples 59 | print('|-{:-<{w}}-|-----------|----------|---------|------------|--------------|'.format('----', w=w), file=file) 60 | print(f'| {c.dataset:<{w}} | {c.num_samples:>9} | {c.accuracy:>8.2f} | {c.ned:>7.2f} ' 61 | f'| {c.confidence:>10.2f} | {c.label_length:>12.2f} |', file=file) 62 | 63 | 64 | @torch.inference_mode() 65 | def 
main(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('checkpoint', help="Model checkpoint (or 'pretrained=')") 68 | parser.add_argument('--data_root', default='data') 69 | parser.add_argument('--batch_size', type=int, default=512) 70 | parser.add_argument('--num_workers', type=int, default=4) 71 | parser.add_argument('--cased', action='store_true', default=False, help='Cased comparison') 72 | parser.add_argument('--punctuation', action='store_true', default=False, help='Check punctuation') 73 | parser.add_argument('--new', action='store_true', default=False, help='Evaluate on new benchmark datasets') 74 | parser.add_argument('--rotation', type=int, default=0, help='Angle of rotation (counter clockwise) in degrees.') 75 | parser.add_argument('--device', default='cuda') 76 | args, unknown = parser.parse_known_args() 77 | kwargs = parse_model_args(unknown) 78 | 79 | charset_test = string.digits + string.ascii_lowercase 80 | if args.cased: 81 | charset_test += string.ascii_uppercase 82 | if args.punctuation: 83 | charset_test += string.punctuation 84 | kwargs.update({'charset_test': charset_test}) 85 | print(f'Additional keyword arguments: {kwargs}') 86 | 87 | model = load_from_checkpoint(args.checkpoint, **kwargs).eval().to(args.device) 88 | hp = model.hparams 89 | datamodule = SceneTextDataModule(args.data_root, '_unused_', hp.img_size, hp.max_label_length, hp.charset_train, 90 | hp.charset_test, args.batch_size, args.num_workers, False, rotation=args.rotation) 91 | 92 | test_set = SceneTextDataModule.TEST_BENCHMARK_SUB + SceneTextDataModule.TEST_BENCHMARK 93 | if args.new: 94 | test_set += SceneTextDataModule.TEST_NEW 95 | test_set = sorted(set(test_set)) 96 | 97 | results = {} 98 | max_width = max(map(len, test_set)) 99 | for name, dataloader in datamodule.test_dataloaders(test_set).items(): 100 | total = 0 101 | correct = 0 102 | ned = 0 103 | confidence = 0 104 | label_length = 0 105 | for imgs, labels in tqdm(iter(dataloader), desc=f'{name:>{max_width}}'): 106 | res = model.test_step((imgs.to(model.device), labels), -1)['output'] 107 | total += res.num_samples 108 | correct += res.correct 109 | ned += res.ned 110 | confidence += res.confidence 111 | label_length += res.label_length 112 | accuracy = 100 * correct / total 113 | mean_ned = 100 * (1 - ned / total) 114 | mean_conf = 100 * confidence / total 115 | mean_label_length = label_length / total 116 | results[name] = Result(name, total, accuracy, mean_ned, mean_conf, mean_label_length) 117 | 118 | result_groups = { 119 | 'Benchmark (Subset)': SceneTextDataModule.TEST_BENCHMARK_SUB, 120 | 'Benchmark': SceneTextDataModule.TEST_BENCHMARK 121 | } 122 | if args.new: 123 | result_groups.update({'New': SceneTextDataModule.TEST_NEW}) 124 | with open(args.checkpoint + '.log.txt', 'w') as f: 125 | for out in [f, sys.stdout]: 126 | for group, subset in result_groups.items(): 127 | print(f'{group} set:', file=out) 128 | print_results_table([results[s] for s in subset], out) 129 | print('\n', file=out) 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /src/parseq/tools/art_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | 5 | with open('train_task2_labels.json', 'r', encoding='utf8') as f: 6 | d = json.load(f) 7 | 8 | with open('gt.txt', 'w', encoding='utf8') as f: 9 | for k, v in d.items(): 10 | if len(v) != 1: 11 | print('error', v) 12 | 
v = v[0] 13 | if v['language'].lower() != 'latin': 14 | # print('Skipping non-Latin:', v) 15 | continue 16 | if v['illegibility']: 17 | # print('Skipping unreadable:', v) 18 | continue 19 | label = v['transcription'].strip() 20 | if not label: 21 | # print('Skipping blank label') 22 | continue 23 | if '#' in label and label != 'LocaL#3': 24 | # print('Skipping corrupted label') 25 | continue 26 | f.write('\t'.join(['train_task2_images/' + k + '.jpg', label]) + '\n') 27 | -------------------------------------------------------------------------------- /src/parseq/tools/case_sensitive_str_datasets_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os.path 4 | import sys 5 | from pathlib import Path 6 | 7 | d = sys.argv[1] 8 | p = Path(d) 9 | 10 | gt = [] 11 | 12 | num_samples = len(list(p.glob('label/*.txt'))) 13 | ext = 'jpg' if p.joinpath('IMG', '1.jpg').is_file() else 'png' 14 | 15 | for i in range(1, num_samples + 1): 16 | img = p.joinpath('IMG', f'{i}.{ext}') 17 | name = os.path.splitext(img.name)[0] 18 | 19 | with open(p.joinpath('label', f'{i}.txt'), 'r') as f: 20 | label = f.readline() 21 | gt.append((os.path.join('IMG', img.name), label)) 22 | 23 | with open(d + '/lmdb.txt', 'w', encoding='utf-8') as f: 24 | for line in gt: 25 | fname, label = line 26 | fname = fname.strip() 27 | label = label.strip() 28 | f.write('\t'.join([fname, label]) + '\n') 29 | -------------------------------------------------------------------------------- /src/parseq/tools/coco_2_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import html 4 | import math 5 | import os 6 | import os.path as osp 7 | from functools import partial 8 | 9 | import mmcv 10 | from PIL import Image 11 | from mmocr.utils.fileio import list_to_file 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Generate training and validation set of TextOCR ' 17 | 'by cropping box image.') 18 | parser.add_argument('root_path', help='Root dir path of TextOCR') 19 | parser.add_argument( 20 | 'n_proc', default=1, type=int, help='Number of processes to run') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def process_img(args, src_image_root, dst_image_root): 26 | # Dirty hack for multiprocessing 27 | img_idx, img_info, anns = args 28 | src_img = Image.open(osp.join(src_image_root, 'train2014', img_info['file_name'])) 29 | src_w, src_h = src_img.size 30 | labels = [] 31 | for ann_idx, ann in enumerate(anns): 32 | text_label = html.unescape(ann['utf8_string'].strip()) 33 | 34 | # Ignore empty labels 35 | if not text_label or ann['class'] != 'machine printed' or ann['language'] != 'english' or \ 36 | ann['legibility'] != 'legible': 37 | continue 38 | 39 | # Some labels and images with '#' in the middle are actually good, but some aren't, so we just filter them all. 
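# Note: a label that is exactly '#' passes the check below; only labels that contain '#' alongside other characters are dropped.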
40 | if text_label != '#' and '#' in text_label: 41 | continue 42 | 43 | # Some labels use '*' to denote unreadable characters 44 | if text_label.startswith('*') or text_label.endswith('*'): 45 | continue 46 | 47 | pad = 2 48 | x, y, w, h = ann['bbox'] 49 | x, y = max(0, math.floor(x) - pad), max(0, math.floor(y) - pad) 50 | w, h = math.ceil(w), math.ceil(h) 51 | x2, y2 = min(src_w, x + w + 2 * pad), min(src_h, y + h + 2 * pad) 52 | dst_img = src_img.crop((x, y, x2, y2)) 53 | dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' 54 | dst_img_path = osp.join(dst_image_root, dst_img_name) 55 | # Preserve JPEG quality 56 | dst_img.save(dst_img_path, qtables=src_img.quantization) 57 | labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' 58 | f' {text_label}') 59 | src_img.close() 60 | return labels 61 | 62 | 63 | def convert_textocr(root_path, 64 | dst_image_path, 65 | dst_label_filename, 66 | annotation_filename, 67 | img_start_idx=0, 68 | nproc=1): 69 | annotation_path = osp.join(root_path, annotation_filename) 70 | if not osp.exists(annotation_path): 71 | raise Exception( 72 | f'{annotation_path} not exists, please check and try again.') 73 | src_image_root = root_path 74 | 75 | # outputs 76 | dst_label_file = osp.join(root_path, dst_label_filename) 77 | dst_image_root = osp.join(root_path, dst_image_path) 78 | os.makedirs(dst_image_root, exist_ok=True) 79 | 80 | annotation = mmcv.load(annotation_path) 81 | split = 'train' if 'train' in dst_label_filename else 'val' 82 | 83 | process_img_with_path = partial( 84 | process_img, 85 | src_image_root=src_image_root, 86 | dst_image_root=dst_image_root) 87 | tasks = [] 88 | for img_idx, img_info in enumerate(annotation['imgs'].values()): 89 | if img_info['set'] != split: 90 | continue 91 | ann_ids = annotation['imgToAnns'][str(img_info['id'])] 92 | anns = [annotation['anns'][str(ann_id)] for ann_id in ann_ids] 93 | tasks.append((img_idx + img_start_idx, img_info, anns)) 94 | 95 | labels_list = mmcv.track_parallel_progress( 96 | process_img_with_path, tasks, keep_order=True, nproc=nproc) 97 | final_labels = [] 98 | for label_list in labels_list: 99 | final_labels += label_list 100 | list_to_file(dst_label_file, final_labels) 101 | return len(annotation['imgs']) 102 | 103 | 104 | def main(): 105 | args = parse_args() 106 | root_path = args.root_path 107 | print('Processing training set...') 108 | num_train_imgs = convert_textocr( 109 | root_path=root_path, 110 | dst_image_path='image', 111 | dst_label_filename='train_label.txt', 112 | annotation_filename='cocotext.v2.json', 113 | nproc=args.n_proc) 114 | print('Processing validation set...') 115 | convert_textocr( 116 | root_path=root_path, 117 | dst_image_path='image_val', 118 | dst_label_filename='val_label.txt', 119 | annotation_filename='cocotext.v2.json', 120 | img_start_idx=num_train_imgs, 121 | nproc=args.n_proc) 122 | print('Finish') 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /src/parseq/tools/coco_text_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | for s in ['train', 'val']: 4 | with open('{}_words_gt.txt'.format(s), 'r', encoding='utf8') as f: 5 | d = f.readlines() 6 | 7 | with open('{}_lmdb.txt'.format(s), 'w', encoding='utf8') as f: 8 | for line in d: 9 | try: 10 | fname, label = line.split(',', maxsplit=1) 11 | except ValueError: 12 | continue 13 | fname = '{}_words/{}.jpg'.format(s, fname.strip()) 14 | 
label = label.strip().strip('|') 15 | f.write('\t'.join([fname, label]) + '\n') 16 | -------------------------------------------------------------------------------- /src/parseq/tools/create_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ 3 | import io 4 | import os 5 | 6 | import fire 7 | import lmdb 8 | import numpy as np 9 | from PIL import Image 10 | 11 | 12 | def checkImageIsValid(imageBin): 13 | if imageBin is None: 14 | return False 15 | img = Image.open(io.BytesIO(imageBin)).convert('RGB') 16 | return np.prod(img.size) > 0 17 | 18 | 19 | def writeCache(env, cache): 20 | with env.begin(write=True) as txn: 21 | for k, v in cache.items(): 22 | txn.put(k, v) 23 | 24 | 25 | def createDataset(inputPath, gtFile, outputPath, checkValid=True): 26 | """ 27 | Create LMDB dataset for training and evaluation. 28 | ARGS: 29 | inputPath : root folder that the image paths in gtFile are relative to 30 | outputPath : LMDB output path 31 | gtFile : text file listing image paths and labels, one pair per line 32 | checkValid : if true, check the validity of every image 33 | """ 34 | os.makedirs(outputPath, exist_ok=True) 35 | env = lmdb.open(outputPath, map_size=1099511627776) 36 | 37 | cache = {} 38 | cnt = 1 39 | 40 | with open(gtFile, 'r', encoding='utf-8') as f: 41 | data = f.readlines() 42 | 43 | nSamples = len(data) 44 | for i, line in enumerate(data): 45 | imagePath, label = line.strip().split(maxsplit=1) 46 | imagePath = os.path.join(inputPath, imagePath) 47 | with open(imagePath, 'rb') as f: 48 | imageBin = f.read() 49 | if checkValid: 50 | try: 51 | img = Image.open(io.BytesIO(imageBin)).convert('RGB') 52 | except IOError as e: 53 | with open(outputPath + '/error_image_log.txt', 'a') as log: 54 | log.write('error occurred on {}-th image: {}, {}\n'.format(i, imagePath, e)) 55 | continue 56 | if np.prod(img.size) == 0: 57 | print('%s is not a valid image' % imagePath) 58 | continue 59 | 60 | imageKey = 'image-%09d'.encode() % cnt 61 | labelKey = 'label-%09d'.encode() % cnt 62 | cache[imageKey] = imageBin 63 | cache[labelKey] = label.encode() 64 | 65 | if cnt % 1000 == 0: 66 | writeCache(env, cache) 67 | cache = {} 68 | print('Written %d / %d' % (cnt, nSamples)) 69 | cnt += 1 70 | nSamples = cnt - 1 71 | cache['num-samples'.encode()] = str(nSamples).encode() 72 | writeCache(env, cache) 73 | env.close() 74 | print('Created dataset with %d samples' % nSamples) 75 | 76 | 77 | if __name__ == '__main__': 78 | fire.Fire(createDataset) 79 | -------------------------------------------------------------------------------- /src/parseq/tools/filter_lmdb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import io 3 | import os 4 | from argparse import ArgumentParser 5 | 6 | import numpy as np 7 | import lmdb 8 | from PIL import Image 9 | 10 | 11 | def main(): 12 | parser = ArgumentParser() 13 | parser.add_argument('inputs', nargs='+', help='Path to input LMDBs') 14 | parser.add_argument('--output', help='Path to output LMDB') 15 | parser.add_argument('--min_image_dim', type=int, default=8) 16 | args = parser.parse_args() 17 | 18 | os.makedirs(args.output, exist_ok=True) 19 | with lmdb.open(args.output, map_size=1099511627776) as env_out: 20 | in_samples = 0 21 | out_samples = 0 22 | samples_per_chunk = 1000 23 | for lmdb_in in args.inputs: 24 | with lmdb.open(lmdb_in, readonly=True,
max_readers=1, lock=False) as env_in: 25 | with env_in.begin() as txn: 26 | num_samples = int(txn.get('num-samples'.encode())) 27 | in_samples += num_samples 28 | chunks = np.array_split(range(num_samples), num_samples // samples_per_chunk) 29 | for chunk in chunks: 30 | cache = {} 31 | with env_in.begin() as txn: 32 | for index in chunk: 33 | index += 1 # lmdb starts at 1 34 | image_key = f'image-{index:09d}'.encode() 35 | image_bin = txn.get(image_key) 36 | img = Image.open(io.BytesIO(image_bin)) 37 | w, h = img.size 38 | if w < args.min_image_dim or h < args.min_image_dim: 39 | print(f'Skipping: {index}, w = {w}, h = {h}') 40 | continue 41 | out_samples += 1 # increment. start at 1 42 | label_key = f'label-{index:09d}'.encode() 43 | out_label_key = f'label-{out_samples:09d}'.encode() 44 | out_image_key = f'image-{out_samples:09d}'.encode() 45 | cache[out_label_key] = txn.get(label_key) 46 | cache[out_image_key] = image_bin 47 | with env_out.begin(write=True) as txn: 48 | for k, v in cache.items(): 49 | txn.put(k, v) 50 | print(f'Written samples from {chunk[0]} to {chunk[-1]}') 51 | with env_out.begin(write=True) as txn: 52 | txn.put('num-samples'.encode(), str(out_samples).encode()) 53 | print(f'Written {out_samples} samples to {args.output} out of {in_samples} input samples.') 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /src/parseq/tools/lsvt_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import re 6 | from functools import partial 7 | 8 | import mmcv 9 | import numpy as np 10 | from PIL import Image 11 | from mmocr.utils.fileio import list_to_file 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Generate training set of LSVT ' 17 | 'by cropping box image.') 18 | parser.add_argument('root_path', help='Root dir path of LSVT') 19 | parser.add_argument( 20 | 'n_proc', default=1, type=int, help='Number of processes to run') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def process_img(args, src_image_root, dst_image_root): 26 | # Dirty hack for multiprocessing 27 | img_idx, img_info, anns = args 28 | try: 29 | src_img = Image.open(osp.join(src_image_root, 'train_full_images_0/{}.jpg'.format(img_info))) 30 | except IOError: 31 | src_img = Image.open(osp.join(src_image_root, 'train_full_images_1/{}.jpg'.format(img_info))) 32 | blacklist = ['LOFTINESS*'] 33 | whitelist = ['#Find YOUR Fun#', 'Story #', '*0#'] 34 | labels = [] 35 | for ann_idx, ann in enumerate(anns): 36 | text_label = ann['transcription'] 37 | 38 | # Ignore illegible or words with non-Latin characters 39 | if ann['illegibility'] or re.findall(r'[\u4e00-\u9fff]+', text_label) or text_label in blacklist or \ 40 | ('#' in text_label and text_label not in whitelist): 41 | continue 42 | 43 | points = np.asarray(ann['points']) 44 | x1, y1 = points.min(axis=0) 45 | x2, y2 = points.max(axis=0) 46 | 47 | dst_img = src_img.crop((x1, y1, x2, y2)) 48 | dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' 49 | dst_img_path = osp.join(dst_image_root, dst_img_name) 50 | # Preserve JPEG quality 51 | dst_img.save(dst_img_path, qtables=src_img.quantization) 52 | labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' 53 | f' {text_label}') 54 | src_img.close() 55 | return labels 56 | 57 | 58 | def convert_lsvt(root_path, 59 | dst_image_path, 60 | 
dst_label_filename, 61 | annotation_filename, 62 | img_start_idx=0, 63 | nproc=1): 64 | annotation_path = osp.join(root_path, annotation_filename) 65 | if not osp.exists(annotation_path): 66 | raise Exception( 67 | f'{annotation_path} not exists, please check and try again.') 68 | src_image_root = root_path 69 | 70 | # outputs 71 | dst_label_file = osp.join(root_path, dst_label_filename) 72 | dst_image_root = osp.join(root_path, dst_image_path) 73 | os.makedirs(dst_image_root, exist_ok=True) 74 | 75 | annotation = mmcv.load(annotation_path) 76 | 77 | process_img_with_path = partial( 78 | process_img, 79 | src_image_root=src_image_root, 80 | dst_image_root=dst_image_root) 81 | tasks = [] 82 | for img_idx, (img_info, anns) in enumerate(annotation.items()): 83 | tasks.append((img_idx + img_start_idx, img_info, anns)) 84 | labels_list = mmcv.track_parallel_progress( 85 | process_img_with_path, tasks, keep_order=True, nproc=nproc) 86 | final_labels = [] 87 | for label_list in labels_list: 88 | final_labels += label_list 89 | list_to_file(dst_label_file, final_labels) 90 | return len(annotation) 91 | 92 | 93 | def main(): 94 | args = parse_args() 95 | root_path = args.root_path 96 | print('Processing training set...') 97 | convert_lsvt( 98 | root_path=root_path, 99 | dst_image_path='image_train', 100 | dst_label_filename='train_label.txt', 101 | annotation_filename='train_full_labels.json', 102 | nproc=args.n_proc) 103 | print('Finish') 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /src/parseq/tools/mlt19_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | 5 | root = sys.argv[1] 6 | 7 | with open(root + '/gt.txt', 'r') as f: 8 | d = f.readlines() 9 | 10 | with open(root + '/lmdb.txt', 'w') as f: 11 | for line in d: 12 | img, script, label = line.split(',', maxsplit=2) 13 | label = label.strip() 14 | if label and script in ['Latin', 'Symbols']: 15 | f.write('\t'.join([img, label]) + '\n') 16 | -------------------------------------------------------------------------------- /src/parseq/tools/openvino_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | import os 4 | import os.path as osp 5 | from argparse import ArgumentParser 6 | from functools import partial 7 | 8 | import mmcv 9 | from PIL import Image 10 | 11 | from mmocr.utils.fileio import list_to_file 12 | 13 | 14 | def parse_args(): 15 | parser = ArgumentParser(description='Generate training and validation set ' 16 | 'of OpenVINO annotations for Open ' 17 | 'Images by cropping box image.') 18 | parser.add_argument( 19 | 'root_path', help='Root dir containing images and annotations') 20 | parser.add_argument( 21 | 'n_proc', default=1, type=int, help='Number of processes to run') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def process_img(args, src_image_root, dst_image_root): 27 | # Dirty hack for multiprocessing 28 | img_idx, img_info, anns = args 29 | src_img = Image.open(osp.join(src_image_root, img_info['file_name'])) 30 | labels = [] 31 | for ann_idx, ann in enumerate(anns): 32 | attrs = ann['attributes'] 33 | text_label = attrs['transcription'] 34 | 35 | # Ignore illegible or non-English words 36 | if not attrs['legible'] or attrs['language'] != 'english': 37 | continue 38 | 39 | x, y, w, h = ann['bbox'] 40 | x, y = max(0, math.floor(x)), 
max(0, math.floor(y)) 41 | w, h = math.ceil(w), math.ceil(h) 42 | dst_img = src_img.crop((x, y, x + w, y + h)) 43 | dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' 44 | dst_img_path = osp.join(dst_image_root, dst_img_name) 45 | # Preserve JPEG quality 46 | dst_img.save(dst_img_path, qtables=src_img.quantization) 47 | labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' 48 | f' {text_label}') 49 | src_img.close() 50 | return labels 51 | 52 | 53 | def convert_openimages(root_path, 54 | dst_image_path, 55 | dst_label_filename, 56 | annotation_filename, 57 | img_start_idx=0, 58 | nproc=1): 59 | annotation_path = osp.join(root_path, annotation_filename) 60 | if not osp.exists(annotation_path): 61 | raise Exception( 62 | f'{annotation_path} not exists, please check and try again.') 63 | src_image_root = root_path 64 | 65 | # outputs 66 | dst_label_file = osp.join(root_path, dst_label_filename) 67 | dst_image_root = osp.join(root_path, dst_image_path) 68 | os.makedirs(dst_image_root, exist_ok=True) 69 | 70 | annotation = mmcv.load(annotation_path) 71 | 72 | process_img_with_path = partial( 73 | process_img, 74 | src_image_root=src_image_root, 75 | dst_image_root=dst_image_root) 76 | tasks = [] 77 | anns = {} 78 | for ann in annotation['annotations']: 79 | anns.setdefault(ann['image_id'], []).append(ann) 80 | for img_idx, img_info in enumerate(annotation['images']): 81 | tasks.append((img_idx + img_start_idx, img_info, anns[img_info['id']])) 82 | labels_list = mmcv.track_parallel_progress( 83 | process_img_with_path, tasks, keep_order=True, nproc=nproc) 84 | final_labels = [] 85 | for label_list in labels_list: 86 | final_labels += label_list 87 | list_to_file(dst_label_file, final_labels) 88 | return len(annotation['images']) 89 | 90 | 91 | def main(): 92 | args = parse_args() 93 | root_path = args.root_path 94 | print('Processing training set...') 95 | num_train_imgs = 0 96 | for s in '125f': 97 | num_train_imgs = convert_openimages( 98 | root_path=root_path, 99 | dst_image_path=f'image_{s}', 100 | dst_label_filename=f'train_{s}_label.txt', 101 | annotation_filename=f'text_spotting_openimages_v5_train_{s}.json', 102 | img_start_idx=num_train_imgs, 103 | nproc=args.n_proc) 104 | print('Processing validation set...') 105 | convert_openimages( 106 | root_path=root_path, 107 | dst_image_path='image_val', 108 | dst_label_filename='val_label.txt', 109 | annotation_filename='text_spotting_openimages_v5_validation.json', 110 | img_start_idx=num_train_imgs, 111 | nproc=args.n_proc) 112 | print('Finish') 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /src/parseq/tools/test_abinet_lm_acc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import string 4 | import sys 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | from tqdm import tqdm 12 | 13 | from strhub.data.module import SceneTextDataModule 14 | from strhub.models.abinet.system import ABINet 15 | 16 | sys.path.insert(0, '.') 17 | from hubconf import _get_config 18 | from test import Result, print_results_table 19 | 20 | 21 | class ABINetLM(ABINet): 22 | 23 | def _encode(self, labels): 24 | targets = [torch.arange(self.max_label_length + 1)] # dummy target. 
used to set pad_sequence() length 25 | lengths = [] 26 | for label in labels: 27 | targets.append(torch.as_tensor([self.tokenizer._stoi[c] for c in label])) 28 | lengths.append(len(label) + 1) 29 | targets = pad_sequence(targets, batch_first=True, padding_value=0)[1:] # exclude dummy target 30 | lengths = torch.as_tensor(lengths, device=self.device) 31 | targets = F.one_hot(targets, len(self.tokenizer._stoi))[..., :len(self.tokenizer._stoi) - 2].float().to(self.device) 32 | return targets, lengths 33 | 34 | def forward(self, labels: Tensor, max_length: int = None) -> Tensor: 35 | targets, lengths = self._encode(labels) 36 | return self.model.language(targets, lengths)['logits'] 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser(description='Measure the word accuracy of ABINet LM using the ground truth as input') 41 | parser.add_argument('checkpoint', help='Official pretrained weights for ABINet-LV (best-train-abinet.pth)') 42 | parser.add_argument('--data_root', default='data') 43 | parser.add_argument('--batch_size', type=int, default=512) 44 | parser.add_argument('--num_workers', type=int, default=4) 45 | parser.add_argument('--new', action='store_true', default=False, help='Evaluate on new benchmark datasets') 46 | parser.add_argument('--device', default='cuda') 47 | args = parser.parse_args() 48 | 49 | # charset used by original ABINet 50 | charset = string.ascii_lowercase + '1234567890' 51 | ckpt = torch.load(args.checkpoint) 52 | 53 | config = _get_config('abinet', charset_train=charset, charset_test=charset) 54 | model = ABINetLM(**config) 55 | model.model.load_state_dict(ckpt['model']) 56 | 57 | model = model.eval().to(args.device) 58 | model.freeze() # disable autograd 59 | hp = model.hparams 60 | datamodule = SceneTextDataModule(args.data_root, '_unused_', hp.img_size, hp.max_label_length, hp.charset_train, 61 | hp.charset_test, args.batch_size, args.num_workers, False) 62 | 63 | test_set = SceneTextDataModule.TEST_BENCHMARK 64 | if args.new: 65 | test_set += SceneTextDataModule.TEST_NEW 66 | test_set = sorted(set(test_set)) 67 | 68 | results = {} 69 | max_width = max(map(len, test_set)) 70 | for name, dataloader in datamodule.test_dataloaders(test_set).items(): 71 | total = 0 72 | correct = 0 73 | ned = 0 74 | confidence = 0 75 | label_length = 0 76 | for _, labels in tqdm(iter(dataloader), desc=f'{name:>{max_width}}'): 77 | res = model.test_step((labels, labels), -1)['output'] 78 | total += res.num_samples 79 | correct += res.correct 80 | ned += res.ned 81 | confidence += res.confidence 82 | label_length += res.label_length 83 | accuracy = 100 * correct / total 84 | mean_ned = 100 * (1 - ned / total) 85 | mean_conf = 100 * confidence / total 86 | mean_label_length = label_length / total 87 | results[name] = Result(name, total, accuracy, mean_ned, mean_conf, mean_label_length) 88 | 89 | result_groups = { 90 | 'Benchmark': SceneTextDataModule.TEST_BENCHMARK 91 | } 92 | if args.new: 93 | result_groups.update({'New': SceneTextDataModule.TEST_NEW}) 94 | for group, subset in result_groups.items(): 95 | print(f'{group} set:') 96 | print_results_table([results[s] for s in subset]) 97 | print('\n') 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /src/parseq/tools/textocr_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) OpenMMLab. All rights reserved. 
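# Crops word-level images out of TextOCR using its bounding-box annotations and writes plain-text label lists (one "image_path label" pair per line); --rectify_pose additionally rotates crops of vertical or upside-down text back to a horizontal orientation.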
3 | import argparse 4 | import math 5 | import os 6 | import os.path as osp 7 | from functools import partial 8 | 9 | import mmcv 10 | import numpy as np 11 | from PIL import Image 12 | from mmocr.utils.fileio import list_to_file 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser( 17 | description='Generate training and validation set of TextOCR ' 18 | 'by cropping box image.') 19 | parser.add_argument('root_path', help='Root dir path of TextOCR') 20 | parser.add_argument( 21 | 'n_proc', default=1, type=int, help='Number of processes to run') 22 | parser.add_argument('--rectify_pose', action='store_true', 23 | help='Fix pose of rotated text to make them horizontal') 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def rectify_image_pose(image, top_left, points): 29 | # Points-based heuristics for determining text orientation w.r.t. bounding box 30 | points = np.asarray(points).reshape(-1, 2) 31 | dist = ((points - np.asarray(top_left)) ** 2).sum(axis=1) 32 | left_midpoint = (points[0] + points[-1]) / 2 33 | right_corner_points = ((points - left_midpoint) ** 2).sum(axis=1).argsort()[-2:] 34 | right_midpoint = points[right_corner_points].sum(axis=0) / 2 35 | d_x, d_y = abs(right_midpoint - left_midpoint) 36 | 37 | if dist[0] + dist[-1] <= dist[right_corner_points].sum(): 38 | if d_x >= d_y: 39 | rot = 0 40 | else: 41 | rot = 90 42 | else: 43 | if d_x >= d_y: 44 | rot = 180 45 | else: 46 | rot = -90 47 | if rot: 48 | image = image.rotate(rot, expand=True) 49 | return image 50 | 51 | 52 | def process_img(args, src_image_root, dst_image_root): 53 | # Dirty hack for multiprocessing 54 | img_idx, img_info, anns, rectify_pose = args 55 | src_img = Image.open(osp.join(src_image_root, img_info['file_name'])) 56 | labels = [] 57 | for ann_idx, ann in enumerate(anns): 58 | text_label = ann['utf8_string'] 59 | 60 | # Ignore illegible or non-English words 61 | if text_label == '.': 62 | continue 63 | 64 | x, y, w, h = ann['bbox'] 65 | x, y = max(0, math.floor(x)), max(0, math.floor(y)) 66 | w, h = math.ceil(w), math.ceil(h) 67 | dst_img = src_img.crop((x, y, x + w, y + h)) 68 | if rectify_pose: 69 | dst_img = rectify_image_pose(dst_img, (x, y), ann['points']) 70 | dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' 71 | dst_img_path = osp.join(dst_image_root, dst_img_name) 72 | # Preserve JPEG quality 73 | dst_img.save(dst_img_path, qtables=src_img.quantization) 74 | labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' 75 | f' {text_label}') 76 | src_img.close() 77 | return labels 78 | 79 | 80 | def convert_textocr(root_path, 81 | dst_image_path, 82 | dst_label_filename, 83 | annotation_filename, 84 | img_start_idx=0, 85 | nproc=1, 86 | rectify_pose=False): 87 | annotation_path = osp.join(root_path, annotation_filename) 88 | if not osp.exists(annotation_path): 89 | raise Exception( 90 | f'{annotation_path} not exists, please check and try again.') 91 | src_image_root = root_path 92 | 93 | # outputs 94 | dst_label_file = osp.join(root_path, dst_label_filename) 95 | dst_image_root = osp.join(root_path, dst_image_path) 96 | os.makedirs(dst_image_root, exist_ok=True) 97 | 98 | annotation = mmcv.load(annotation_path) 99 | 100 | process_img_with_path = partial( 101 | process_img, 102 | src_image_root=src_image_root, 103 | dst_image_root=dst_image_root) 104 | tasks = [] 105 | for img_idx, img_info in enumerate(annotation['imgs'].values()): 106 | ann_ids = annotation['imgToAnns'][img_info['id']] 107 | anns = [annotation['anns'][ann_id] for ann_id in ann_ids] 108 | 
tasks.append((img_idx + img_start_idx, img_info, anns, rectify_pose)) 109 | labels_list = mmcv.track_parallel_progress( 110 | process_img_with_path, tasks, keep_order=True, nproc=nproc) 111 | final_labels = [] 112 | for label_list in labels_list: 113 | final_labels += label_list 114 | list_to_file(dst_label_file, final_labels) 115 | return len(annotation['imgs']) 116 | 117 | 118 | def main(): 119 | args = parse_args() 120 | root_path = args.root_path 121 | print('Processing training set...') 122 | num_train_imgs = convert_textocr( 123 | root_path=root_path, 124 | dst_image_path='image', 125 | dst_label_filename='train_label.txt', 126 | annotation_filename='TextOCR_0.1_train.json', 127 | nproc=args.n_proc, 128 | rectify_pose=args.rectify_pose) 129 | print('Processing validation set...') 130 | convert_textocr( 131 | root_path=root_path, 132 | dst_image_path='image', 133 | dst_label_filename='val_label.txt', 134 | annotation_filename='TextOCR_0.1_val.json', 135 | img_start_idx=num_train_imgs, 136 | nproc=args.n_proc, 137 | rectify_pose=args.rectify_pose) 138 | print('Finish') 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | -------------------------------------------------------------------------------- /src/parseq/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Scene Text Recognition Model Hub 3 | # Copyright 2022 Darwin Bautista 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
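# Example invocation (assumed; exact override names depend on the Hydra configs under configs/ resolved by config_name='main'):
#   python train.py +experiment=parseq data.root_dir=/path/to/lmdb_data trainer.gpus=2
# Any config key (model.*, data.*, trainer.*, pretrained, ckpt_path) can be overridden on the command line the same way.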
16 | 17 | from pathlib import Path 18 | 19 | from omegaconf import DictConfig, open_dict 20 | import hydra 21 | from hydra.core.hydra_config import HydraConfig 22 | 23 | from pytorch_lightning import Trainer 24 | from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | from pytorch_lightning.strategies import DDPStrategy 27 | from pytorch_lightning.utilities.model_summary import summarize 28 | 29 | from strhub.data.module import SceneTextDataModule 30 | from strhub.models.base import BaseSystem 31 | from strhub.models.utils import get_pretrained_weights 32 | 33 | 34 | @hydra.main(config_path='configs', config_name='main', version_base='1.2') 35 | def main(config: DictConfig): 36 | trainer_strategy = None 37 | with open_dict(config): 38 | # Resolve absolute path to data.root_dir 39 | config.data.root_dir = hydra.utils.to_absolute_path(config.data.root_dir) 40 | # Special handling for GPU-affected config 41 | gpus = config.trainer.get('gpus', 0) 42 | if gpus: 43 | # Use mixed-precision training 44 | config.trainer.precision = 16 45 | if gpus > 1: 46 | # Use DDP 47 | config.trainer.strategy = 'ddp' 48 | # DDP optimizations 49 | trainer_strategy = DDPStrategy(find_unused_parameters=False, gradient_as_bucket_view=True) 50 | # Scale steps-based config 51 | config.trainer.val_check_interval //= gpus 52 | if config.trainer.get('max_steps', -1) > 0: 53 | config.trainer.max_steps //= gpus 54 | 55 | # Special handling for PARseq 56 | if config.model.get('perm_mirrored', False): 57 | assert config.model.perm_num % 2 == 0, 'perm_num should be even if perm_mirrored = True' 58 | 59 | model: BaseSystem = hydra.utils.instantiate(config.model) 60 | # If specified, use pretrained weights to initialize the model 61 | if config.pretrained is not None: 62 | model.load_state_dict(get_pretrained_weights(config.pretrained)) 63 | print(summarize(model, max_depth=1 if model.hparams.name.startswith('parseq') else 2)) 64 | 65 | datamodule: SceneTextDataModule = hydra.utils.instantiate(config.data) 66 | 67 | checkpoint = ModelCheckpoint(monitor='val_accuracy', mode='max', save_top_k=3, save_last=True, 68 | filename='{epoch}-{step}-{val_accuracy:.4f}-{val_NED:.4f}') 69 | swa = StochasticWeightAveraging(swa_epoch_start=0.75) 70 | cwd = HydraConfig.get().runtime.output_dir if config.ckpt_path is None else \ 71 | str(Path(config.ckpt_path).parents[1].absolute()) 72 | trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=TensorBoardLogger(cwd, '', '.'), 73 | strategy=trainer_strategy, enable_model_summary=False, 74 | callbacks=[checkpoint, swa]) 75 | trainer.fit(model, datamodule=datamodule, ckpt_path=config.ckpt_path) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import os 5 | 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from contextlib import nullcontext 9 | from os.path import join as ospj 10 | from torchvision.utils import save_image 11 | from omegaconf import OmegaConf 12 | from pytorch_lightning import seed_everything 13 | from dataset.dataloader import get_dataloader 14 | 15 | from util import * 16 | 17 | 18 | def predict(cfgs, model, sampler, batch): 19 | 20 | context = torch.no_grad 21 | 22 | with context(): 23 | 24 | batch, batch_uc_1 = prepare_batch(cfgs, 
batch) 25 | 26 | c, uc_1 = model.conditioner.get_unconditional_conditioning( 27 | batch, 28 | batch_uc=batch_uc_1, 29 | force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings, 30 | ) 31 | 32 | x = sampler.get_init_noise(cfgs, batch=batch) 33 | samples_z = sampler(model, x, cond=c, batch=batch, uc=uc_1, init_step=0, 34 | detailed = cfgs.detailed) 35 | 36 | samples_x = model.decode_first_stage(samples_z) 37 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) 38 | 39 | return samples, samples_z 40 | 41 | 42 | def test(model, sampler, dataloader, cfgs): 43 | 44 | output_dir = cfgs.output_dir 45 | os.system(f"rm -rf {output_dir}") 46 | os.makedirs(output_dir, exist_ok=True) 47 | real_dir = ospj(output_dir, "real") 48 | fake_dir = ospj(output_dir, "fake") 49 | os.makedirs(real_dir, exist_ok=True) 50 | os.makedirs(fake_dir, exist_ok=True) 51 | 52 | temp_dir = cfgs.temp_dir 53 | os.system(f"rm -rf {temp_dir}") 54 | os.makedirs(ospj(temp_dir, "attn_map"), exist_ok=True) 55 | os.makedirs(ospj(temp_dir, "seg_map"), exist_ok=True) 56 | 57 | if cfgs.ocr_enabled: 58 | predictor = instantiate_from_config(cfgs.predictor_config) 59 | predictor.parseq = predictor.parseq.to(sampler.device) 60 | 61 | correct_num = 0 62 | total_num = 0 63 | 64 | for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): 65 | 66 | 67 | name = batch["name"][0] 68 | results, results_z = predict(cfgs, model, sampler, batch) 69 | 70 | # run ocr 71 | if cfgs.ocr_enabled: 72 | 73 | r_bbox = batch["r_bbox"] 74 | gt_txt = batch["label"] 75 | results_crop = [] 76 | for i, bbox in enumerate(r_bbox): 77 | r_top, r_bottom, r_left, r_right = bbox 78 | results_crop.append(results[i, :, r_top:r_bottom, r_left:r_right]) 79 | pred_txt = predictor.img2txt(results_crop) 80 | 81 | correct_count = sum([int(pred_txt[i].lower()==gt_txt[i].lower()) for i in range(len(gt_txt))]) 82 | print(f"Expected text: {batch['label']}") 83 | if correct_count < len(gt_txt): 84 | print(f"\033[1;31m OCR Result: {pred_txt} \033[0m") 85 | else: 86 | print(f"\033[1;32m OCR Result: {pred_txt} \033[0m") 87 | correct_num += correct_count 88 | total_num += len(gt_txt) 89 | 90 | # save results 91 | result = results.cpu().numpy().transpose(0, 2, 3, 1) * 255 92 | result = np.concatenate(result, axis = -2) 93 | 94 | outputs = [] 95 | for key in ("image", "masked", "mask"): 96 | if key in batch: 97 | output = batch[key] 98 | if key != "mask": 99 | output = (output + 1.0) / 2.0 100 | output = output.cpu().numpy().transpose(0, 2, 3, 1) * 255 101 | output = np.concatenate(output, axis = -2) 102 | if key == "mask": 103 | output = np.tile(output, (1,1,3)) 104 | outputs.append(output) 105 | 106 | outputs.append(result) 107 | real = Image.fromarray(outputs[0].astype(np.uint8)) 108 | fake = Image.fromarray(outputs[-1].astype(np.uint8)) 109 | real.save(ospj(output_dir, "real", f"{name}.png")) 110 | fake.save(ospj(output_dir, "fake", f"{name}.png")) 111 | 112 | output = np.concatenate(outputs, axis = 0) 113 | output = Image.fromarray(output.astype(np.uint8)) 114 | output.save(ospj(output_dir, f"{name}.png")) 115 | 116 | if cfgs.ocr_enabled: 117 | print(f"OCR test completed. 
Mean accuracy: {correct_num/total_num}") 118 | 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | cfgs = OmegaConf.load("./configs/test.yaml") 124 | 125 | seed = random.randint(0, 2147483647) 126 | seed_everything(seed) 127 | 128 | model = init_model(cfgs) 129 | sampler = init_sampling(cfgs) 130 | dataloader = get_dataloader(cfgs, "val") 131 | 132 | test(model, sampler, dataloader, cfgs) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import random 4 | import pytorch_lightning as pl 5 | 6 | from omegaconf import OmegaConf 7 | from dataset.dataloader import get_dataloader 8 | from pytorch_lightning import seed_everything 9 | from pytorch_lightning.callbacks import ModelCheckpoint 10 | from torchvision.utils import save_image 11 | from pytorch_lightning.loggers import TensorBoardLogger 12 | 13 | from util import * 14 | 15 | 16 | def train(): 17 | 18 | sys.path.append(os.getcwd()) 19 | 20 | # torch settings 21 | torch.multiprocessing.set_start_method('spawn') # multiprocess mode 22 | torch.set_float32_matmul_precision('medium') # matrix multiply precision 23 | 24 | config_path = 'configs/train.yaml' 25 | cfgs = OmegaConf.load(config_path) 26 | 27 | seed = random.randint(0, 2147483647) 28 | seed_everything(seed, workers=True) 29 | 30 | dataloader = get_dataloader(cfgs) 31 | model = init_model(cfgs) 32 | model.learning_rate = cfgs.base_learning_rate 33 | 34 | checkpoint_callback_epoch = ModelCheckpoint(dirpath = cfgs.save_ckpt_dir, every_n_epochs = cfgs.save_ckpt_freq, save_top_k=-1) 35 | 36 | trainer = pl.Trainer(callbacks = [checkpoint_callback_epoch], **cfgs.lightning) 37 | 38 | trainer.fit(model = model, train_dataloaders = dataloader) 39 | 40 | 41 | 42 | if __name__=='__main__': 43 | 44 | train() 45 | 46 | 47 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf 3 | from sgm.util import instantiate_from_config 4 | from sgm.modules.diffusionmodules.sampling import * 5 | 6 | 7 | def init_model(cfgs): 8 | 9 | model_cfg = OmegaConf.load(cfgs.model_cfg_path) 10 | ckpt = cfgs.load_ckpt_path 11 | 12 | model = instantiate_from_config(model_cfg.model) 13 | model.init_from_ckpt(ckpt) 14 | 15 | if cfgs.type == "train": 16 | model.train() 17 | else: 18 | model.to(torch.device("cuda", index=cfgs.gpu)) 19 | model.eval() 20 | model.freeze() 21 | 22 | return model 23 | 24 | def init_sampling(cfgs): 25 | 26 | discretization_config = { 27 | "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", 28 | } 29 | 30 | guider_config = { 31 | "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", 32 | "params": {"scale": cfgs.scale[0]}, 33 | } 34 | 35 | sampler = EulerEDMSampler( 36 | num_steps=cfgs.steps, 37 | discretization_config=discretization_config, 38 | guider_config=guider_config, 39 | s_churn=0.0, 40 | s_tmin=0.0, 41 | s_tmax=999.0, 42 | s_noise=1.0, 43 | verbose=True, 44 | device=torch.device("cuda", index=cfgs.gpu) 45 | ) 46 | 47 | return sampler 48 | 49 | def deep_copy(batch): 50 | 51 | c_batch = {} 52 | for key in batch: 53 | if isinstance(batch[key], torch.Tensor): 54 | c_batch[key] = torch.clone(batch[key]) 55 | elif isinstance(batch[key], (tuple, list)): 56 | c_batch[key] = batch[key].copy() 57 | else: 58 | c_batch[key] = batch[key] 
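# Anything that is neither a tensor nor a list/tuple (e.g. strings or ints) is shared by reference rather than duplicated.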
59 | 60 | return c_batch 61 | 62 | def prepare_batch(cfgs, batch): 63 | 64 | for key in batch: 65 | if isinstance(batch[key], torch.Tensor): 66 | batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu)) 67 | 68 | batch_uc = deep_copy(batch) 69 | 70 | if "ntxt" in batch: 71 | batch_uc["txt"] = batch["ntxt"] 72 | else: 73 | batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))] 74 | 75 | if "label" in batch: 76 | batch_uc["label"] = ["" for _ in range(len(batch["label"]))] 77 | 78 | return batch, batch_uc --------------------------------------------------------------------------------
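# Usage sketch for prepare_batch (mirrors predict() in test.py above; assumes model, sampler and a
# dataloader batch were already built via init_model, init_sampling and get_dataloader):
#   batch, batch_uc = prepare_batch(cfgs, batch)
#   c, uc = model.conditioner.get_unconditional_conditioning(
#       batch, batch_uc=batch_uc, force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings)
#   x = sampler.get_init_noise(cfgs, batch=batch)
#   samples_z = sampler(model, x, cond=c, batch=batch, uc=uc, init_step=0)
#   samples = torch.clamp((model.decode_first_stage(samples_z) + 1.0) / 2.0, min=0.0, max=1.0)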