├── .gitignore ├── LICENSE ├── README.md ├── Rethinking-Text-Segmentation ├── .gitignore ├── README.md ├── configs │ ├── __init__.py │ ├── cfg_base.py │ ├── cfg_dataset.py │ └── cfg_model.py ├── eval_utils.py ├── hrnet_code │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── requirements.txt │ └── tools │ │ ├── _init_paths.py │ │ ├── test.py │ │ └── train.py ├── inference.py ├── main.py ├── requirement.txt └── train_utils.py ├── banner.png ├── bert_tokenizer.py ├── cldm ├── cldm.py ├── ddim_hacked.py ├── embedding_manager.py ├── hack.py ├── logger.py ├── model.py └── recognizer.py ├── dataset_util.py ├── demo.py ├── docs ├── demo.jpg ├── eval.jpg ├── framework.jpg ├── gallery.png ├── reproduce.jpg └── sample.jpg ├── environment.yml ├── eval ├── anytext_multiGPUs.py ├── anytext_singleGPU.py ├── controlnet_multiGPUs.py ├── controlnet_singleGPU.py ├── cut_bounding_box.py ├── eval_dgocr.py ├── eval_fid.sh ├── eval_font.py ├── eval_ocr.sh ├── gen_glyph.sh ├── gen_imgs_anytext.sh ├── gen_imgs_controlnet_canny.sh ├── gen_imgs_controltext.sh ├── gen_imgs_glyphcontrol.sh ├── gen_imgs_textdiffuser.sh ├── get_glyph_lines.py ├── glyphcontrol_multiGPUs.py ├── glyphcontrol_singleGPU.py ├── render_glyph_imgs.py ├── textdiffuser_multiGPUs.py └── textdiffuser_singleGPU.py ├── example_images ├── banner.png ├── edit1.png ├── edit10.png ├── edit11.png ├── edit12.png ├── edit13.png ├── edit14.png ├── edit15.png ├── edit16.png ├── edit2.png ├── edit3.png ├── edit4.png ├── edit5.png ├── edit6.png ├── edit7.png ├── edit8.png ├── edit9.png ├── gen1.png ├── gen10.png ├── gen11.png ├── gen12.png ├── gen13.png ├── gen14.png ├── gen15.png ├── gen16.png ├── gen17.png ├── gen18.png ├── gen19.png ├── gen2.png ├── gen20.png ├── gen21.png ├── gen3.png ├── gen4.png ├── gen5.png ├── gen6.png ├── gen7.png ├── gen8.png ├── gen9.png ├── ref1.jpg ├── ref10.jpg ├── ref11.jpg ├── ref12.png ├── ref13.jpg ├── ref14.png ├── ref15.jpeg ├── ref16.jpeg ├── ref2.jpg ├── ref3.jpg ├── ref4.jpg ├── ref5.jpg ├── ref6.jpg ├── ref7.jpg ├── ref8.jpg └── ref9.jpg ├── flows.png ├── inference.py ├── javascript └── bboxHint.js ├── ldm ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── lora_util.py ├── models_yaml ├── anytext_sd15.yaml ├── anytext_sd15_conv.yaml ├── anytext_sd15_perloss.yaml └── anytext_sd15_vit.yaml ├── ocr_recog ├── RNN.py ├── RecCTCHead.py ├── RecModel.py ├── RecMv1_enhance.py ├── RecSVTR.py ├── common.py ├── en_dict.txt └── ppocr_keys_v1.txt ├── ocr_weights ├── en_dict.txt ├── ppocr_keys_v1.txt ├── ppv3_rec.pth └── ppv3_rec_en.pth ├── preprocess_conditions.py ├── proj_3d_surface.py ├── requirements.txt ├── style.css ├── synthetic_dataset ├── 
generate_prompt_json.py ├── generate_text_transformation_pairs.py ├── nejm_test_en.txt ├── nejm_test_zh.txt ├── prompt.json ├── requirements.txt ├── restore_from_transformations.py ├── test_rectify.py ├── unet_dataloader.py ├── unet_inference.py ├── unet_models.py ├── unet_ocr.py ├── unet_train.py └── unet_train_config.yaml ├── t3_dataset.py ├── tool_add_anytext.py ├── train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | training/ 4 | lightning_logs/ 5 | image_log/ 6 | 7 | *.pth 8 | *.pt 9 | *.ckpt 10 | *.safetensors 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | .vscode 142 | /show_results/ 143 | *-ori.py 144 | *.tar.gz 145 | /tmp_dir/ 146 | /tmp_files/ 147 | /SaveImages/ 148 | /*.png 149 | font/*.ttf 150 | 151 | *.jpg 152 | *.png 153 | *.jpeg 154 | synthetic_dataset/ 155 | frontend/ 156 | *.csv 157 | ccd/ 158 | fonts/ 159 | *.txt 160 | models/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## This is the official implementation of the paper [ControlText: Unlocking Controllable Fonts in Multilingual Text Rendering without Font Annotations](https://arxiv.org/abs/2502.10999) in PyTorch. 2 | [![Arxiv](https://img.shields.io/badge/ArXiv-Paper-B31B1B)](https://arxiv.org/abs/2502.10999) 3 | [![Google Scholar](https://img.shields.io/badge/Google_Scholar-Cite_Our_Paper-4085F4)](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C39&q=ControlText%3A+Unlocking+Controllable+Fonts+in+Multilingual+Text+Rendering+without+Font+Annotations&btnG=) 4 | 5 | ## ✨ Overview 6 | 7 | Visual text rendering is a challenging task, especially when precise font control is desired. This work demonstrates that diffusion models can achieve **font-controllable** multilingual text rendering using **just raw images without font label annotations**. 8 | 9 | --- 10 | 11 | ## 🚀 Key Takeaways 12 | 13 | - **Font controls require no font label annotations:** 14 | A text segmentation model can capture nuanced font information in pixel space without requiring font label annotations in the dataset, enabling zero-shot generation on unseen languages and fonts, as well as scalable training on web-scale image datasets as long as they contain text. 15 | 16 | - **Evaluating ambiguous fonts in the open world:** 17 | Fuzzy font accuracy can be measured in the embedding space of a pretrained font classification model, utilizing our proposed metrics `l2@k` and `cos@k` (an illustrative sketch of these metrics appears below the banner). 18 | 19 | - **Supporting user-driven design flexibility:** 20 | Random perturbations can be applied to segmented glyphs. While this won't affect the rendered text quality, it accounts for users not aligning text precisely to the best locations and prevents models from rigidly replicating the pixel locations in glyphs. 21 | 22 | - **Working with foundation models:** 23 | With limited computational resources, we can still copilot foundational image generation models to perform localized text and font editing. 24 | 25 | ![Banner](banner.png) 26 | --- 27 | 
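The sketch below is an illustration only of how retrieval-style font accuracy can be measured in an embedding space, in the spirit of `l2@k` and `cos@k` above. The authoritative definitions live in the paper and in `eval/eval_font.py`; the function name `topk_accuracy` and the pairing convention used here are assumptions made for illustration, not this repository's API. Embeddings are assumed to come from a pretrained font-classification model applied to segmented glyph crops.

```python
# Illustrative sketch only: retrieval-style font accuracy in an embedding space.
# The exact metric definitions used by this repo live in eval/eval_font.py.
import numpy as np

def topk_accuracy(gen_embs: np.ndarray, gt_embs: np.ndarray, k: int = 5, metric: str = "l2") -> float:
    """Row i of gen_embs / gt_embs (shape (N, D)) is the embedding of the i-th
    generated glyph crop and of its paired ground-truth glyph crop."""
    if metric == "cos":
        gen = gen_embs / np.linalg.norm(gen_embs, axis=1, keepdims=True)
        gt = gt_embs / np.linalg.norm(gt_embs, axis=1, keepdims=True)
        ranks = (-(gen @ gt.T)).argsort(axis=1)          # higher similarity = closer
    else:
        dists = np.linalg.norm(gen_embs[:, None, :] - gt_embs[None, :, :], axis=-1)
        ranks = dists.argsort(axis=1)                    # smaller distance = closer
    # A sample counts as correct if its paired ground truth is among the k nearest.
    hits = [i in ranks[i, :k] for i in range(len(gen_embs))]
    return float(np.mean(hits))

# Usage sketch: l2_at_5 = topk_accuracy(gen, gt, k=5, metric="l2")
#               cos_at_5 = topk_accuracy(gen, gt, k=5, metric="cos")
```
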
28 | ## Citation 29 | If you find our work inspiring, please consider citing it. Thank you! 30 | 31 | @article{jiang2025controltext, 32 | title={ControlText: Unlocking Controllable Fonts in Multilingual Text Rendering without Font Annotations}, 33 | author={Jiang, Bowen and Yuan, Yuan and Bai, Xinyi and Hao, Zhuoqun and Yin, Alyson and Hu, Yaojie and Liao, Wenyu and Ungar, Lyle and Taylor, Camillo J}, 34 | journal={arXiv preprint arXiv:2502.10999}, 35 | year={2025} 36 | } 37 | 38 | ## 🔧 How to Train 39 | 40 | Our repository is based on the code of [AnyText](https://github.com/tyxsspa/AnyText). We build upon and extend it to enable user-controllable fonts in a zero-shot manner. Below is a brief walkthrough: 41 | 42 | 1. **Prerequisites:** 43 | We use a conda environment to manage all required packages. 44 | ``` 45 | conda env create -f environment.yml 46 | conda activate controltext 47 | ``` 48 | 49 | 2. **Preprocess Glyphs:** 50 | 51 | 52 | 3. **Configuration:** 53 | - Adjust hyperparameters such as `batch_size`, `grad_accum`, `learning_rate`, `logger_freq`, and `max_epochs` in the training script `train.py`. Please keep `mask_ratio = 1`. 54 | - Set paths for GPUs, checkpoints, the model configuration file, image datasets, and preprocessed glyphs accordingly. 55 | 56 | 4. **Training Command:** 57 | Run the training script: 58 | ```bash 59 | python train.py 60 | ``` 61 | 62 | ## 🔮 Inference & Front-End 63 | 64 | The front-end code for user-friendly text and font editing is coming soon! Stay tuned for updates as we continue to enhance the project. 65 | 66 | ## 👩‍💻 Evaluation 67 | 1. **Our Generated Data** 68 | 69 | laion_controltext [Google Drive](https://drive.google.com/file/d/1sxzAENTWDAixkMFMHyOeXcyhZOq7WY2B/view?usp=sharing), laion_controltext_gly_lines (cropped regions for each line of text from the entire image) [Google Drive](https://drive.google.com/file/d/1JrJTkJ8oePXUo9d8E5QOVsh0DBWi82P_/view?usp=sharing), laion_controltext_gly_lines_grayscale (laion_controltext_gly_lines after text segmentation) [Google Drive](https://drive.google.com/file/d/1qSQs_NB3jUe08YZLaKmM42iJWqT7mmjA/view?usp=drive_link), laion_gly_lines_gt (cropped regions from input glyphs after text segmentation) [Google Drive](https://drive.google.com/file/d/1XiRu24gRiYwpODyjuJnW1XyJ9kd-1f9U/view?usp=drive_link) 70 | 71 | 72 | wukong_controltext [Google Drive](https://drive.google.com/file/d/1ZCeEsD4aCeK0OePNUHQ96Pp3Xq_f4pW2/view?usp=drive_link), wukong_controltext_gly_line [Google Drive](https://drive.google.com/file/d/1weseRPN5mNA2NNeOjUxFQxA7Fu6K4CuZ/view?usp=drive_link), wukong_controltext_glylines_grayscale [Google Drive](https://drive.google.com/file/d/1uyWyF_FwMhyAyVRsBQsTx9G7Ar5dZBMb/view?usp=drive_link), wukong_gly_lines_gt [Google Drive](https://drive.google.com/file/d/1XKsliU0-XVxj7YUyfbGaAq1PCED18s1a/view?usp=drive_link) 73 | 74 | 2. **Our Model Checkpoint** 75 | 76 | [Google Drive](https://drive.google.com/file/d/1fUNeKqoGhGutkcCFTHa3USkhChlfE_kQ/view?usp=sharing) 77 | 78 | 3. **Script for evaluating text accuracy:** 79 | 80 | Run the following script to calculate SenACC and NED scores for text accuracy (an illustrative sketch of these two metrics follows this item), which will evaluate ```laion_controltext_gly_lines``` and ```wukong_controltext_gly_line```. 81 | ``` 82 | bash eval/eval_ocr.sh 83 | ``` 84 | Run the following script to calculate the FID score for overall image quality, which will evaluate ```laion_controltext``` and ```wukong_controltext```. 85 | ``` 86 | bash eval/eval_fid.sh 87 | ``` 88 | 
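For reference only: SenACC is conventionally the fraction of text lines whose OCR result exactly matches the ground truth, and NED is typically reported as one minus the normalized Levenshtein distance between the OCR output and the ground-truth string. The sketch below illustrates these conventional definitions under that assumption; the repository's actual implementation is `eval/eval_dgocr.py`, and the helper names here are hypothetical.

```python
# Illustrative-only definitions of sentence accuracy (SenACC) and normalized
# edit distance (NED) between OCR predictions and ground-truth strings.
# The repo's actual implementation lives in eval/eval_dgocr.py.

def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]

def sen_acc(preds: list[str], gts: list[str]) -> float:
    # Fraction of lines recognized exactly.
    return sum(p == g for p, g in zip(preds, gts)) / len(gts)

def ned(preds: list[str], gts: list[str]) -> float:
    # Often reported as 1 - normalized edit distance (higher is better).
    scores = [1 - levenshtein(p, g) / max(len(p), len(g), 1) for p, g in zip(preds, gts)]
    return sum(scores) / len(scores)
```
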
89 | 4. **Script for evaluating font accuracy in the open world:** 90 | 91 | Run the following script to calculate the font accuracy: 92 | ``` 93 | python eval/eval_font.py --generated_folder path/to/your/generated_folder --gt_folder path/to/your/gt_folder 94 | ``` 95 | In the arguments, ```path/to/your/generated_folder``` should point to the directory containing your generated images, for example, ```laion_controltext_gly_lines_grayscale``` or ```wukong_controltext_glylines_grayscale```. Similarly, ```path/to/your/gt_folder``` should refer to the directory containing the ground-truth glyph images or the segmented glyphs used as input conditions, where we use ```laion_gly_lines_gt``` or ```wukong_gly_lines_gt```. 96 | 97 | --- 98 | ![Flows](flows.png) 99 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | /models 3 | /log 4 | /scripts 5 | /pretrained 6 | **/__pycache__ 7 | **/.idea/* 8 | **/.vscode/* 9 | **/.ipynb_checkpoints/* 10 | **/old/* 11 | **/veryold/* 12 | *.ipynb 13 | **/build 14 | **/*.out 15 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/Rethinking-Text-Segmentation/configs/__init__.py -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/configs/cfg_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import copy 5 | import socket 6 | 7 | from easydict import EasyDict as edict 8 | 9 | cfg = edict() 10 | 11 | # -----------------------------BASE----------------------------- 12 | 13 | cfg.DEBUG = False 14 | cfg.EXPERIMENT_ID = 0 15 | cfg.GPU_DEVICE = 'all' 16 | cfg.CUDA = False 17 | cfg.MISC_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', 'log')) 18 | cfg.LOG_FILE = None 19 | cfg.RND_SEED = None 20 | cfg.RND_RECORDING = False 21 | # cfg.USE_FLOAT16 = False 22 | cfg.MATPLOTLIB_MODE = 'Agg' 23 | cfg.MAINLOOP_EXECUTE = True 24 | cfg.MAIN_CODE_PATH = None 25 | cfg.MAIN_CODE = [] 26 | cfg.SAVE_CODE = True 27 | cfg.COMPUTER_NAME = socket.gethostname() 28 | cfg.TORCH_VERSION = 'unknown' 29 | 30 | cfg.DIST_URL = 'tcp://127.0.0.1:11233' 31 | cfg.DIST_BACKEND = 'nccl' 32 | 33 | cfg_train = copy.deepcopy(cfg) 34 | cfg_test = copy.deepcopy(cfg) 35 | 36 | # -----------------------------TRAIN----------------------------- 37 | 38 | cfg_train.TRAIN = edict() 39 | cfg_train.TRAIN.BATCH_SIZE = None 40 | cfg_train.TRAIN.BATCH_SIZE_PER_GPU = None 41 | cfg_train.TRAIN.MAX_STEP = 0 42 | cfg_train.TRAIN.MAX_STEP_TYPE = None 43 | cfg_train.TRAIN.SKIP_PARTIAL = True 44 | # cfg_train.TRAIN.LR_ADJUST_MODE = None 45 | cfg_train.TRAIN.LR_ITER_BY = None 46 | cfg_train.TRAIN.OPTIMIZER = None 47 | cfg_train.TRAIN.DISPLAY = 0 48 | cfg_train.TRAIN.VISUAL = None 49 | cfg_train.TRAIN.SAVE_INIT_MODEL = True 50 | cfg_train.TRAIN.SAVE_CODE = True 51 | 52 | # -----------------------------TEST----------------------------- 53 | 54 | cfg_test.TEST = edict() 55 | cfg_test.TEST.BATCH_SIZE = None 56 | cfg_test.TEST.BATCH_SIZE_PER_GPU = None 57 | cfg_test.TEST.VISUAL = None 58 | 59 | # 
-----------------------------COMBINED----------------------------- 60 | 61 | cfg.update(cfg_train) 62 | cfg.update(cfg_test) 63 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/configs/cfg_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import copy 5 | 6 | from easydict import EasyDict as edict 7 | 8 | cfg = edict() 9 | cfg.MODEL_NAME = None 10 | # cfg.CONV_TYPE = 'conv' 11 | # cfg.BN_TYPE = 'bn' 12 | # cfg.RELU_TYPE = 'relu' 13 | 14 | # resnet 15 | cfg_resnet = copy.deepcopy(cfg) 16 | cfg_resnet.MODEL_NAME = 'resnet' 17 | cfg_resnet.RESNET = edict() 18 | cfg_resnet.RESNET.MODEL_TAGS = None 19 | cfg_resnet.RESNET.PRETRAINED_PTH = None 20 | cfg_resnet.RESNET.BN_TYPE = 'bn' 21 | cfg_resnet.RESNET.RELU_TYPE = 'relu' 22 | 23 | # deeplab 24 | cfg_deeplab = copy.deepcopy(cfg) 25 | cfg_deeplab.MODEL_NAME = 'deeplab' 26 | cfg_deeplab.DEEPLAB = edict() 27 | cfg_deeplab.DEEPLAB.MODEL_TAGS = None 28 | cfg_deeplab.DEEPLAB.PRETRAINED_PTH = None 29 | cfg_deeplab.DEEPLAB.FREEZE_BACKBONE_BN = False 30 | cfg_deeplab.DEEPLAB.BN_TYPE = 'bn' 31 | cfg_deeplab.DEEPLAB.RELU_TYPE = 'relu' 32 | # cfg_deeplab.DEEPLAB.ASPP_DROPOUT_TYPE = 'dropout|0.5' 33 | cfg_deeplab.DEEPLAB.ASPP_WITH_GAP = True 34 | # cfg_deeplab.DEEPLAB.DECODER_DROPOUT2_TYPE = 'dropout|0.5' 35 | # cfg_deeplab.DEEPLAB.DECODER_DROPOUT3_TYPE = 'dropout|0.1' 36 | cfg_deeplab.RESNET = cfg_resnet.RESNET 37 | 38 | # hrnet 39 | cfg_hrnet = copy.deepcopy(cfg) 40 | cfg_hrnet.MODEL_NAME = 'hrnet' 41 | cfg_hrnet.HRNET = edict() 42 | cfg_hrnet.HRNET.MODEL_TAGS = None 43 | cfg_hrnet.HRNET.PRETRAINED_PTH = None 44 | cfg_hrnet.HRNET.BN_TYPE = 'bn' 45 | cfg_hrnet.HRNET.RELU_TYPE = 'relu' 46 | 47 | # texrnet 48 | cfg_texrnet = copy.deepcopy(cfg) 49 | cfg_texrnet.MODEL_NAME = 'texrnet' 50 | cfg_texrnet.TEXRNET = edict() 51 | cfg_texrnet.TEXRNET.MODEL_TAGS = None 52 | cfg_texrnet.TEXRNET.PRETRAINED_PTH = None 53 | cfg_texrnet.RESNET = cfg_resnet.RESNET 54 | cfg_texrnet.DEEPLAB = cfg_deeplab.DEEPLAB 55 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/hrnet_code/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__/ 3 | *.py[co] 4 | data/ 5 | log/ 6 | output/ 7 | pretrained_models 8 | scripts/ 9 | detail-api/ 10 | data/list -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/hrnet_code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2019] [Microsoft] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ======================================================================================= 24 | 3-clause BSD licenses 25 | ======================================================================================= 26 | 1. syncbn - For details, see lib/models/syncbn/LICENSE 27 | Copyright (c) 2017 mapillary 28 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/hrnet_code/requirements.txt: -------------------------------------------------------------------------------- 1 | EasyDict==1.7 2 | opencv-python==3.4.2.17 3 | shapely==1.6.4 4 | Cython 5 | scipy 6 | pandas 7 | pyyaml 8 | json_tricks 9 | scikit-image 10 | yacs>=0.1.5 11 | tensorboardX>=1.6 12 | tqdm 13 | ninja 14 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/hrnet_code/tools/_init_paths.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path as osp 12 | import sys 13 | 14 | 15 | def add_path(path): 16 | if path not in sys.path: 17 | sys.path.insert(0, path) 18 | 19 | this_dir = osp.dirname(__file__) 20 | 21 | lib_path = osp.join(this_dir, '..', 'lib') 22 | add_path(lib_path) 23 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/hrnet_code/tools/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import argparse 8 | import os 9 | import pprint 10 | import shutil 11 | import sys 12 | 13 | import logging 14 | import time 15 | import timeit 16 | from pathlib import Path 17 | 18 | import numpy as np 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.backends.cudnn as cudnn 23 | 24 | import _init_paths 25 | import models 26 | import datasets 27 | from config import config 28 | from config import update_config 29 | from core.function import testval, test 30 | from utils.modelsummary import get_model_summary 31 | from utils.utils import create_logger, FullModel 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser(description='Train segmentation network') 35 | 36 | parser.add_argument('--cfg', 37 | help='experiment configure file name', 38 | required=True, 39 | type=str) 40 | parser.add_argument('opts', 41 | help="Modify config options using the command-line", 42 | default=None, 43 | nargs=argparse.REMAINDER) 44 | 45 | args = parser.parse_args() 46 | update_config(config, args) 47 | 48 | return args 49 | 50 | def main(): 51 | args = parse_args() 52 | 53 | logger, final_output_dir, _ = create_logger( 54 | config, args.cfg, 'test') 55 | 56 | logger.info(pprint.pformat(args)) 57 | logger.info(pprint.pformat(config)) 58 | 59 | # cudnn related setting 60 | cudnn.benchmark = config.CUDNN.BENCHMARK 61 | cudnn.deterministic = config.CUDNN.DETERMINISTIC 62 | cudnn.enabled = config.CUDNN.ENABLED 63 | 64 | # build model 65 | if torch.__version__.startswith('1'): 66 | module = eval('models.'+config.MODEL.NAME) 67 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d 68 | model = eval('models.'+config.MODEL.NAME + 69 | '.get_seg_model')(config) 70 | 71 | dump_input = torch.rand( 72 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]) 73 | ) 74 | logger.info(get_model_summary(model.cuda(), dump_input.cuda())) 75 | 76 | if config.TEST.MODEL_FILE: 77 | model_state_file = config.TEST.MODEL_FILE 78 | else: 79 | model_state_file = os.path.join(final_output_dir, 'final_state.pth') 80 | logger.info('=> loading model from {}'.format(model_state_file)) 81 | 82 | pretrained_dict = torch.load(model_state_file) 83 | if 'state_dict' in pretrained_dict: 84 | pretrained_dict = pretrained_dict['state_dict'] 85 | model_dict = model.state_dict() 86 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() 87 | if k[6:] in model_dict.keys()} 88 | for k, _ in pretrained_dict.items(): 89 | logger.info( 90 | '=> loading {} from pretrained model'.format(k)) 91 | model_dict.update(pretrained_dict) 92 | model.load_state_dict(model_dict) 93 | 94 | gpus = list(config.GPUS) 95 | model = nn.DataParallel(model, device_ids=gpus).cuda() 96 | 97 | # prepare data 98 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0]) 99 | test_dataset = eval('datasets.'+config.DATASET.DATASET)( 100 | root=config.DATASET.ROOT, 101 | list_path=config.DATASET.TEST_SET, 102 | num_samples=None, 103 | num_classes=config.DATASET.NUM_CLASSES, 104 | multi_scale=False, 105 | flip=False, 106 | ignore_label=config.TRAIN.IGNORE_LABEL, 107 | base_size=config.TEST.BASE_SIZE, 108 | crop_size=test_size, 109 | downsample_rate=1) 110 | 111 | testloader = torch.utils.data.DataLoader( 112 | test_dataset, 113 | batch_size=1, 114 | shuffle=False, 115 | num_workers=config.WORKERS, 116 | pin_memory=True) 117 | 118 | start = timeit.default_timer() 119 | if 
'val' in config.DATASET.TEST_SET: 120 | mean_IoU, IoU_array, pixel_acc, mean_acc = testval(config, 121 | test_dataset, 122 | testloader, 123 | model) 124 | 125 | msg = 'MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, \ 126 | Mean_Acc: {: 4.4f}, Class IoU: '.format(mean_IoU, 127 | pixel_acc, mean_acc) 128 | logging.info(msg) 129 | logging.info(IoU_array) 130 | elif 'test' in config.DATASET.TEST_SET: 131 | test(config, 132 | test_dataset, 133 | testloader, 134 | model, 135 | sv_dir=final_output_dir) 136 | 137 | end = timeit.default_timer() 138 | logger.info('Mins: %d' % np.int((end-start)/60)) 139 | logger.info('Done') 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /Rethinking-Text-Segmentation/requirement.txt: -------------------------------------------------------------------------------- 1 | torch==1.6 2 | torchvision==0.7 3 | matplotlib==3.3.2 4 | opencv-python==4.5.1.48 5 | easydict==1.9 6 | flash_attn 7 | -------------------------------------------------------------------------------- /banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/banner.png -------------------------------------------------------------------------------- /cldm/embedding_manager.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) Alibaba, Inc. and its affiliates. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from functools import partial 8 | from ldm.modules.diffusionmodules.util import conv_nd, linear 9 | 10 | 11 | def get_clip_token_for_string(tokenizer, string): 12 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, 13 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 14 | tokens = batch_encoding["input_ids"] 15 | assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" 16 | return tokens[0, 1] 17 | 18 | 19 | def get_bert_token_for_string(tokenizer, string): 20 | token = tokenizer(string) 21 | assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. 
Please use another string" 22 | token = token[0, 1] 23 | return token 24 | 25 | 26 | def get_clip_vision_emb(encoder, processor, img): 27 | _img = img.repeat(1, 3, 1, 1)*255 28 | inputs = processor(images=_img, return_tensors="pt") 29 | inputs['pixel_values'] = inputs['pixel_values'].to(img.device) 30 | outputs = encoder(**inputs) 31 | emb = outputs.image_embeds 32 | return emb 33 | 34 | 35 | def get_recog_emb(encoder, img_list): 36 | _img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list] 37 | encoder.predictor.eval() 38 | _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) 39 | return preds_neck 40 | 41 | 42 | def pad_H(x): 43 | _, _, H, W = x.shape 44 | p_top = (W - H) // 2 45 | p_bot = W - H - p_top 46 | return F.pad(x, (0, 0, p_top, p_bot)) 47 | 48 | 49 | class EncodeNet(nn.Module): 50 | def __init__(self, in_channels, out_channels): 51 | super(EncodeNet, self).__init__() 52 | chan = 16 53 | n_layer = 4 # downsample 54 | 55 | self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1) 56 | self.conv_list = nn.ModuleList([]) 57 | _c = chan 58 | for i in range(n_layer): 59 | self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2)) 60 | _c *= 2 61 | self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1) 62 | self.avgpool = nn.AdaptiveAvgPool2d(1) 63 | self.act = nn.SiLU() 64 | 65 | def forward(self, x): 66 | x = self.act(self.conv1(x)) 67 | for layer in self.conv_list: 68 | x = self.act(layer(x)) 69 | x = self.act(self.conv2(x)) 70 | x = self.avgpool(x) 71 | x = x.view(x.size(0), -1) 72 | return x 73 | 74 | 75 | class EmbeddingManager(nn.Module): 76 | def __init__( 77 | self, 78 | embedder, 79 | valid=True, 80 | glyph_channels=20, 81 | position_channels=1, 82 | placeholder_string='*', 83 | add_pos=False, 84 | emb_type='ocr', 85 | **kwargs 86 | ): 87 | super().__init__() 88 | if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder 89 | get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) 90 | token_dim = 768 91 | if hasattr(embedder, 'vit'): 92 | assert emb_type == 'vit' 93 | self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor) 94 | self.get_recog_emb = None 95 | else: # using LDM's BERT encoder 96 | get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) 97 | token_dim = 1280 98 | self.token_dim = token_dim 99 | self.emb_type = emb_type 100 | 101 | self.add_pos = add_pos 102 | if add_pos: 103 | self.position_encoder = EncodeNet(position_channels, token_dim) 104 | if emb_type == 'ocr': 105 | self.proj = linear(40*64, token_dim) 106 | if emb_type == 'conv': 107 | self.glyph_encoder = EncodeNet(glyph_channels, token_dim) 108 | 109 | self.placeholder_token = get_token_for_string(placeholder_string) 110 | 111 | def encode_text(self, text_info): 112 | if self.get_recog_emb is None and self.emb_type == 'ocr': 113 | self.get_recog_emb = partial(get_recog_emb, self.recog) 114 | 115 | gline_list = [] 116 | pos_list = [] 117 | for i in range(len(text_info['n_lines'])): # sample index in a batch 118 | n_lines = text_info['n_lines'][i] 119 | for j in range(n_lines): # line 120 | gline_list += [text_info['gly_line'][j][i:i+1]] 121 | if self.add_pos: 122 | pos_list += [text_info['positions'][j][i:i+1]] 123 | 124 | if len(gline_list) > 0: 125 | if self.emb_type == 'ocr': 126 | recog_emb = self.get_recog_emb(gline_list) 127 | enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1)) 128 | elif self.emb_type == 'vit': 129 | enc_glyph = 
self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0))) 130 | elif self.emb_type == 'conv': 131 | enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0))) 132 | if self.add_pos: 133 | enc_pos = self.position_encoder(torch.cat(gline_list, dim=0)) 134 | enc_glyph = enc_glyph+enc_pos 135 | 136 | self.text_embs_all = [] 137 | n_idx = 0 138 | for i in range(len(text_info['n_lines'])): # sample index in a batch 139 | n_lines = text_info['n_lines'][i] 140 | text_embs = [] 141 | for j in range(n_lines): # line 142 | text_embs += [enc_glyph[n_idx:n_idx+1]] 143 | n_idx += 1 144 | self.text_embs_all += [text_embs] 145 | 146 | def forward( 147 | self, 148 | tokenized_text, 149 | embedded_text, 150 | ): 151 | b, device = tokenized_text.shape[0], tokenized_text.device 152 | for i in range(b): 153 | idx = tokenized_text[i] == self.placeholder_token.to(device) 154 | if sum(idx) > 0: 155 | if i >= len(self.text_embs_all): 156 | print('truncation for log images...') 157 | break 158 | text_emb = torch.cat(self.text_embs_all[i], dim=0) 159 | if sum(idx) != len(text_emb): 160 | print('truncation for long caption...') 161 | embedded_text[i][idx] = text_emb[:sum(idx)] 162 | return embedded_text 163 | 164 | def embedding_parameters(self): 165 | return self.parameters() 166 | -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 
| y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | 
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path, cond_stage_path=None, use_fp16=False): 25 | config = OmegaConf.load(config_path) 26 | if cond_stage_path: 27 | config.model.params.cond_stage_config.params.version = cond_stage_path # use pre-downloaded ckpts, in case blocked 28 | if use_fp16: 29 | config.model.params.use_fp16 = True 30 | config.model.params.control_stage_config.params.use_fp16 = True 31 | config.model.params.unet_config.params.use_fp16 = True 32 | model = instantiate_from_config(config.model).cpu() 33 | print(f'Loaded model config from [{config_path}]') 34 | return model 35 | -------------------------------------------------------------------------------- /dataset_util.py: -------------------------------------------------------------------------------- 1 | import ujson 2 | import json 3 | import pathlib 4 | 5 | __all__ = ['load', 'save', 'show_bbox_on_image'] 6 | 7 | 8 | def load(file_path: str): 9 | file_path = pathlib.Path(file_path) 10 | func_dict = {'.txt': load_txt, '.json': load_json, '.list': load_txt} 11 | assert file_path.suffix in func_dict 12 | return func_dict[file_path.suffix](file_path) 13 | 14 | 15 | def load_txt(file_path: str): 16 | with open(file_path, 'r', encoding='utf8') as f: 17 | content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()] 18 | return content 19 | 20 | 21 | def load_json(file_path: str): 22 | with open(file_path, 'rb') as f: 23 | content = f.read() 24 | return ujson.loads(content) 25 | 26 | 27 | def save(data, file_path): 28 | file_path = 
pathlib.Path(file_path) 29 | func_dict = {'.txt': save_txt, '.json': save_json} 30 | assert file_path.suffix in func_dict 31 | return func_dict[file_path.suffix](data, file_path) 32 | 33 | 34 | def save_txt(data, file_path): 35 | if not isinstance(data, list): 36 | data = [data] 37 | with open(file_path, mode='w', encoding='utf8') as f: 38 | f.write('\n'.join(data)) 39 | 40 | 41 | def save_json(data, file_path): 42 | with open(file_path, 'w', encoding='utf-8') as json_file: 43 | json.dump(data, json_file, ensure_ascii=False, indent=4) 44 | 45 | 46 | def show_bbox_on_image(image, polygons=None, txt=None, color=None, font_path='./font/Arial_Unicode.ttf'): 47 | from PIL import ImageDraw, ImageFont 48 | image = image.convert('RGB') 49 | draw = ImageDraw.Draw(image) 50 | if len(txt) == 0: 51 | txt = None 52 | if color is None: 53 | color = (255, 0, 0) 54 | if txt is not None: 55 | font = ImageFont.truetype(font_path, 20) 56 | for i, box in enumerate(polygons): 57 | box = box[0] 58 | if txt is not None: 59 | draw.text((int(box[0][0]) + 20, int(box[0][1]) - 20), str(txt[i]), fill='red', font=font) 60 | for j in range(len(box) - 1): 61 | draw.line((box[j][0], box[j][1], box[j + 1][0], box[j + 1][1]), fill=color, width=2) 62 | draw.line((box[-1][0], box[-1][1], box[0][0], box[0][1]), fill=color, width=2) 63 | return image 64 | 65 | 66 | def show_glyphs(glyphs, name): 67 | import numpy as np 68 | import cv2 69 | size = 64 70 | gap = 5 71 | n_char = 20 72 | canvas = np.ones((size, size*n_char + gap*(n_char-1), 1))*0.5 73 | x = 0 74 | for i in range(glyphs.shape[-1]): 75 | canvas[:, x:x + size, :] = glyphs[..., i:i+1] 76 | x += size+gap 77 | cv2.imwrite(name, canvas*255) 78 | -------------------------------------------------------------------------------- /docs/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/docs/demo.jpg -------------------------------------------------------------------------------- /docs/eval.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/docs/eval.jpg -------------------------------------------------------------------------------- /docs/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/docs/framework.jpg -------------------------------------------------------------------------------- /docs/gallery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/docs/gallery.png -------------------------------------------------------------------------------- /docs/reproduce.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/docs/reproduce.jpg -------------------------------------------------------------------------------- /docs/sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/docs/sample.jpg -------------------------------------------------------------------------------- 
/environment.yml: -------------------------------------------------------------------------------- 1 | name: controltext 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.10.6 7 | - pip=23.0.1 8 | - cudatoolkit=11.8 9 | - numpy=1.23.3 10 | - cython==0.29.33 11 | - pip: 12 | - Pillow==9.5.0 13 | - gradio==3.50.0 14 | - albumentations==0.4.3 15 | - opencv-python==4.7.0.72 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.2.3 20 | - test-tube==0.7.5 21 | - streamlit==1.20.0 22 | - einops==0.4.1 23 | - transformers==4.30.2 24 | - webdataset==0.2.5 25 | - kornia==0.6.7 26 | - open_clip_torch==2.7.0 27 | - torchmetrics==0.11.4 28 | - timm==0.6.7 29 | - addict==2.4.0 30 | - yapf==0.32.0 31 | - safetensors==0.4.0 32 | - basicsr==1.4.2 33 | - jieba==0.42.1 34 | - modelscope==1.10.0 35 | - tensorflow==2.13.0 36 | - torch==2.0.1 37 | - torchvision==0.15.2 38 | - easydict==1.10 39 | - xformers==0.0.20 40 | - subword-nmt==0.3.8 41 | - sacremoses==0.0.53 42 | - sentencepiece==0.1.99 43 | - fsspec 44 | - diffusers==0.10.2 45 | - ujson -------------------------------------------------------------------------------- /eval/anytext_multiGPUs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import argparse 5 | import pathlib 6 | import json 7 | 8 | 9 | def load(file_path: str): 10 | file_path = pathlib.Path(file_path) 11 | func_dict = {'.json': load_json} 12 | assert file_path.suffix in func_dict 13 | return func_dict[file_path.suffix](file_path) 14 | 15 | 16 | def load_json(file_path: str): 17 | with open(file_path, 'r', encoding='utf8') as f: 18 | content = json.load(f) 19 | return content 20 | 21 | 22 | def save(data, file_path): 23 | file_path = pathlib.Path(file_path) 24 | func_dict = {'.json': save_json} 25 | assert file_path.suffix in func_dict 26 | return func_dict[file_path.suffix](data, file_path) 27 | 28 | 29 | def save_json(data, file_path): 30 | with open(file_path, 'w', encoding='utf-8') as json_file: 31 | json.dump(data, json_file, ensure_ascii=False, indent=4) 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--model_path", 38 | type=str, 39 | default='models/anytext_v1.1.ckpt', 40 | help='path of model' 41 | ) 42 | parser.add_argument( 43 | "--gpus", 44 | type=str, 45 | default='0,1,2,3,4,5,6,7', 46 | help='gpus for inference' 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default='./anytext_v1.1_laion_generated/', 52 | help="output path" 53 | ) 54 | parser.add_argument( 55 | "--json_path", 56 | type=str, 57 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json', 58 | help="json path for evaluation dataset" 59 | ) 60 | args = parser.parse_args() 61 | return args 62 | 63 | 64 | if __name__ == "__main__": 65 | args = parse_args() 66 | ckpt_path = args.model_path 67 | gpus = args.gpus 68 | output_dir = args.output_dir 69 | json_path = args.json_path 70 | 71 | USING_DLC = False 72 | if USING_DLC: 73 | json_path = json_path.replace('/data/vdb', '/mnt/data', 1) 74 | output_dir = output_dir.replace('/data/vdb', '/mnt/data', 1) 75 | 76 | exec_path = './eval/anytext_singleGPU.py' 77 | continue_gen = True # if True, not clear output_dir, and generate rest images. 
78 | tmp_dir = './tmp_dir' 79 | if os.path.exists(tmp_dir): 80 | shutil.rmtree(tmp_dir) 81 | os.makedirs(tmp_dir) 82 | 83 | if not continue_gen: 84 | if os.path.exists(output_dir): 85 | shutil.rmtree(output_dir) 86 | os.makedirs(output_dir) 87 | else: 88 | if not os.path.exists(output_dir): 89 | os.makedirs(output_dir) 90 | 91 | os.system('sleep 1') 92 | 93 | gpu_ids = [int(i) for i in gpus.split(',')] 94 | nproc = len(gpu_ids) 95 | all_lines = load(json_path) 96 | split_file = [] 97 | length = len(all_lines['data_list']) // nproc 98 | cmds = [] 99 | for i in range(nproc): 100 | start, end = i*length, (i+1)*length 101 | if i == nproc - 1: 102 | end = len(all_lines['data_list']) 103 | temp_lines = copy.deepcopy(all_lines) 104 | temp_lines['data_list'] = temp_lines['data_list'][start:end] 105 | tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json') 106 | save(temp_lines, tmp_file) 107 | os.system('sleep 1') 108 | cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --input_json {tmp_file} --output_dir {output_dir} --ckpt_path {ckpt_path} && echo proc-{i} done!'] 109 | cmds = ' & '.join(cmds) 110 | os.system(cmds) 111 | print('Done.') 112 | os.system('sleep 2') 113 | shutil.rmtree(tmp_dir) 114 | 115 | ''' 116 | command to kill the task after running: 117 | $ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9 118 | ''' 119 | -------------------------------------------------------------------------------- /eval/controlnet_multiGPUs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import argparse 5 | import pathlib 6 | import json 7 | 8 | 9 | def load(file_path: str): 10 | file_path = pathlib.Path(file_path) 11 | func_dict = {'.json': load_json} 12 | assert file_path.suffix in func_dict 13 | return func_dict[file_path.suffix](file_path) 14 | 15 | 16 | def load_json(file_path: str): 17 | with open(file_path, 'r', encoding='utf8') as f: 18 | content = json.load(f) 19 | return content 20 | 21 | 22 | def save(data, file_path): 23 | file_path = pathlib.Path(file_path) 24 | func_dict = {'.json': save_json} 25 | assert file_path.suffix in func_dict 26 | return func_dict[file_path.suffix](data, file_path) 27 | 28 | 29 | def save_json(data, file_path): 30 | with open(file_path, 'w', encoding='utf-8') as json_file: 31 | json.dump(data, json_file, ensure_ascii=False, indent=4) 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--model_path", 38 | type=str, 39 | default='/home/yuxiang.tyx/projects/AnyText/models/control_sd15_canny.pth', 40 | help='path of model' 41 | ) 42 | parser.add_argument( 43 | "--gpus", 44 | type=str, 45 | default='0,1,2,3,4,5,6,7', 46 | help='gpus for inference' 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default='./controlnet_laion_generated/', 52 | help="output path" 53 | ) 54 | parser.add_argument( 55 | "--glyph_dir", 56 | type=str, 57 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion', 58 | help="path of glyph images from anytext evaluation dataset" 59 | ) 60 | parser.add_argument( 61 | "--json_path", 62 | type=str, 63 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json', 64 | help="json path for evaluation dataset" 65 | ) 66 | args = parser.parse_args() 67 | return args 68 | 69 | 70 | if __name__ == "__main__": 71 | args = parse_args() 72 | output_dir = args.output_dir 73 | 74 | tmp_dir 
= './tmp_dir' 75 | exec_path = './controlnet_singleGPU.py' 76 | continue_gen = True # if True, not clear output_dir, and generate rest images. 77 | 78 | if os.path.exists(tmp_dir): 79 | shutil.rmtree(tmp_dir) 80 | os.makedirs(tmp_dir) 81 | 82 | if not continue_gen: 83 | if os.path.exists(output_dir): 84 | shutil.rmtree(output_dir) 85 | os.makedirs(output_dir) 86 | else: 87 | if not os.path.exists(output_dir): 88 | os.makedirs(output_dir) 89 | 90 | os.system('sleep 1') 91 | gpu_ids = [int(i) for i in args.gpus.split(',')] 92 | nproc = len(gpu_ids) 93 | all_lines = load(args.json_path) 94 | split_file = [] 95 | length = len(all_lines['data_list']) // nproc 96 | cmds = [] 97 | for i in range(nproc): 98 | start, end = i*length, (i+1)*length 99 | if i == nproc - 1: 100 | end = len(all_lines['data_list']) 101 | temp_lines = copy.deepcopy(all_lines) 102 | temp_lines['data_list'] = temp_lines['data_list'][start:end] 103 | tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json') 104 | save(temp_lines, tmp_file) 105 | os.system('sleep 1') 106 | cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!'] 107 | cmds = ' & '.join(cmds) 108 | os.system(cmds) 109 | print('Done.') 110 | os.system('sleep 2') 111 | shutil.rmtree(tmp_dir) 112 | 113 | 114 | ''' 115 | command to kill the task after running: 116 | $ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9 117 | ''' 118 | -------------------------------------------------------------------------------- /eval/eval_fid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python -m pytorch_fid \ 4 | /tmp/datasets/AnyWord-3M/AnyText-Benchmark/FID/laion-40k \ 5 | ./eval/laion_controltext -------------------------------------------------------------------------------- /eval/eval_ocr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=7 3 | python eval/eval_dgocr.py \ 4 | --img_dir ./eval/wukong_ablation_gly_lines \ 5 | --json_path /tmp/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/wukong_word/test1k.json \ 6 | --glyph_path ./Rethinking-Text-Segmentation/log/images/output/anytext_benchmark/wukong_word 7 | 8 | # #!/bin/bash 9 | # export CUDA_VISIBLE_DEVICES=7 10 | # python eval/eval_dgocr.py \ 11 | # --img_dir ./eval/textdiffuser_laion_generated_gly_lines \ 12 | # --json_path /tmp/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json \ 13 | # --glyph_path ./Rethinking-Text-Segmentation/log/images/output/anytext_benchmark/laion_word 14 | 15 | # #!/bin/bash 16 | # export CUDA_VISIBLE_DEVICES=7 17 | # python eval/eval_dgocr.py \ 18 | # --img_dir ./eval/glyphcontrol_laion_generated_gly_lines \ 19 | # --json_path /tmp/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json \ 20 | # --glyph_path ./Rethinking-Text-Segmentation/log/images/output/anytext_benchmark/laion_word 21 | 22 | # #!/bin/bash 23 | # export CUDA_VISIBLE_DEVICES=7 24 | # python eval/eval_dgocr.py \ 25 | # --img_dir ./eval/controlnet_wukong_generated_gly_lines \ 26 | # --json_path /tmp/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/wukong_word/test1k.json \ 27 | # --glyph_path ./Rethinking-Text-Segmentation/log/images/output/anytext_benchmark/wukong_word 
-------------------------------------------------------------------------------- /eval/gen_glyph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python eval/render_glyph_imgs.py \ 3 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \ 4 | --output_dir /data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion 5 | -------------------------------------------------------------------------------- /eval/gen_imgs_anytext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python eval/anytext_singleGPU.py \ 3 | --ckpt_path ./models/anytext_v1.1.ckpt \ 4 | --json_path /tmp/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json \ 5 | --output_dir ./eval/laion_anytext \ 6 | -------------------------------------------------------------------------------- /eval/gen_imgs_controlnet_canny.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python controlnet_multiGPUs.py \ 3 | --model_path /home/yuxiang.tyx/projects/AnyText/models/control_sd15_canny.pth \ 4 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json \ 5 | --glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong \ 6 | --output_dir ./controlnet_wukong_generated \ 7 | --gpus 0,1,2,3,4,5,6,7 8 | -------------------------------------------------------------------------------- /eval/gen_imgs_controltext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python eval/anytext_singleGPU.py \ 3 | --ckpt_path ./models/lightning_logs/version_9/checkpoints/last.ckpt \ 4 | --json_path /tmp/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json \ 5 | --output_dir ./eval/laion_controltext \ 6 | --glyph_path ./Rethinking-Text-Segmentation/log/images/output/anytext_benchmark/laion_word -------------------------------------------------------------------------------- /eval/gen_imgs_glyphcontrol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python glyphcontrol_multiGPUs.py \ 3 | --model_path checkpoints/laion10M_epoch_6_model_ema_only.ckpt \ 4 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \ 5 | --glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion \ 6 | --output_dir ./glyphcontrol_laion_generated \ 7 | --gpus 0,1,2,3,4,5,6,7 8 | -------------------------------------------------------------------------------- /eval/gen_imgs_textdiffuser.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python textdiffuser_multiGPUs.py \ 3 | --model_path textdiffuser-ckpt/diffusion_backbone \ 4 | --json_path /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json \ 5 | --glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong \ 6 | --output_dir ./textdiffuser_wukong_generated \ 7 | --gpus 0,1,2,3,4,5,6,7 8 | -------------------------------------------------------------------------------- /eval/get_glyph_lines.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import numpy as np 5 | 6 | # # JSON path 7 | # wukong_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/wukong_word/test1k.json' 8 | # laion_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json' 9 | # 10 | # # Image path 
for laion 11 | # laion_image_dir = '/pool/bwjiang/controltext/eval/laion_controltext_ploss2' 12 | # laion_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/laion_ploss2' 13 | # 14 | # # output directory for laion glyphs 15 | # laion_glyph_dir = '/pool/bwjiang/controltext/eval/laion_controltext_ploss2_gly_lines' 16 | # laion_bw_glyph_dir = '/pool/bwjiang/controltext/eval/laion_controltext_ploss2_gly_lines_black_and_white' 17 | # 18 | # # Image path for wukong 19 | # wukong_image_dir = '/pool/bwjiang/controltext/eval/wukong_controltext_ploss2' 20 | # wukong_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/wukong_ploss2' 21 | # 22 | # # output directory for wukong glyphs 23 | # wukong_glyph_dir = '/pool/bwjiang/controltext/eval/wukong_controltext_ploss2_gly_lines' 24 | # wukong_bw_glyph_dir = '/pool/bwjiang/controltext/eval/wukong_controltext_ploss2_gly_lines_black_and_white' 25 | 26 | 27 | # # JSON path 28 | # wukong_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/wukong_word/test1k.json' 29 | # laion_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json' 30 | # 31 | # # Image path for laion 32 | # laion_image_dir = '/pool/bwjiang/controltext/eval/anytext_eval_imgs/textdiffuser_laion_generated/' 33 | # laion_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/textdiffuser_laion' 34 | # 35 | # # output directory for laion glyphs 36 | # laion_glyph_dir = '/pool/bwjiang/controltext/eval/textdiffuser_laion_generated_gly_lines' 37 | # laion_bw_glyph_dir = '/pool/bwjiang/controltext/eval/textdiffuser_laion_generated_gly_lines_black_and_white' 38 | # 39 | # # Image path for wukong 40 | # wukong_image_dir = '/pool/bwjiang/controltext/eval/anytext_eval_imgs/textdiffuser_wukong_generated/' 41 | # wukong_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/textdiffuser_wukong' 42 | # 43 | # # output directory for wukong glyphs 44 | # wukong_glyph_dir = '/pool/bwjiang/controltext/eval/textdiffuser_wukong_generated_gly_lines' 45 | # wukong_bw_glyph_dir = '/pool/bwjiang/controltext/eval/textdiffuser_wukong_generated_gly_lines_black_and_white' 46 | 47 | 48 | # JSON path 49 | wukong_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/wukong_word/test1k.json' 50 | laion_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json' 51 | 52 | # Image path for laion 53 | laion_image_dir = '/pool/bwjiang/controltext/eval/anytext_eval_imgs/glyphcontrol_laion_generated/' 54 | laion_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/glyphcontrol_laion' 55 | 56 | # output directory for laion glyphs 57 | laion_glyph_dir = '/pool/bwjiang/controltext/eval/glyphcontrol_laion_generated_gly_lines' 58 | laion_bw_glyph_dir = '/pool/bwjiang/controltext/eval/glyphcontrol_laion_generated_gly_lines_black_and_white' 59 | 60 | # Image path for wukong 61 | wukong_image_dir = '/pool/bwjiang/controltext/eval/anytext_eval_imgs/glyphcontrol_wukong_generated/' 62 | wukong_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/glyphcontrol_wukong' 63 | 64 | # output directory for wukong glyphs 65 | wukong_glyph_dir = '/pool/bwjiang/controltext/eval/glyphcontrol_wukong_generated_gly_lines' 66 | wukong_bw_glyph_dir = 
'/pool/bwjiang/controltext/eval/glyphcontrol_wukong_generated_gly_lines_black_and_white' 67 | 68 | 69 | # # JSON path 70 | # wukong_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/wukong_word/test1k.json' 71 | # laion_test1k_json_path = '/pool/bwjiang/datasets/AnyWord-3M/AnyText-Benchmark/benchmark/laion_word/test1k.json' 72 | # 73 | # # Image path for laion 74 | # laion_image_dir = '/pool/bwjiang/controltext/eval/anytext_eval_imgs/controlnet_laion_generated/' 75 | # laion_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/controlnet_laion' 76 | # 77 | # # output directory for laion glyphs 78 | # laion_glyph_dir = '/pool/bwjiang/controltext/eval/controlnet_laion_generated_gly_lines' 79 | # laion_bw_glyph_dir = '/pool/bwjiang/controltext/eval/controlnet_laion_generated_gly_lines_black_and_white' 80 | # 81 | # # Image path for wukong 82 | # wukong_image_dir = '/pool/bwjiang/controltext/eval/anytext_eval_imgs/controlnet_wukong_generated/' 83 | # wukong_bw_image_dir = '/pool/bwjiang/controltext/Rethinking-Text-Segmentation/log/images/output/controlnet_wukong' 84 | # 85 | # # output directory for wukong glyphs 86 | # wukong_glyph_dir = '/pool/bwjiang/controltext/eval/controlnet_wukong_generated_gly_lines' 87 | # wukong_bw_glyph_dir = '/pool/bwjiang/controltext/eval/controlnet_wukong_generated_gly_lines_black_and_white' 88 | 89 | 90 | def get_glyphs(json_path, image_dir, glyph_dir): 91 | # Ensure glyph directory exists 92 | os.makedirs(glyph_dir, exist_ok=True) 93 | 94 | with open(json_path, 'r', encoding='utf-8') as f: 95 | data = json.load(f) 96 | 97 | if not data: 98 | print('Empty JSON file') 99 | return 100 | 101 | data_list = data.get('data_list', []) 102 | 103 | for entry in data_list: 104 | img_name_prefix = entry['img_name'].split('.')[0] 105 | img_annotations = entry.get('annotations', []) 106 | 107 | # Find all images with the given prefix 108 | matching_images = [f for f in os.listdir(image_dir) if f.startswith(img_name_prefix)] 109 | 110 | if not matching_images: 111 | print(f'No matching image found for {img_name_prefix}') 112 | continue 113 | 114 | for img_file in matching_images: 115 | img_path = os.path.join(image_dir, img_file) 116 | img = cv2.imread(img_path) 117 | 118 | if img is None: 119 | print(f'Error: Unable to read image {img_file}') 120 | continue 121 | 122 | for idx, annotation in enumerate(img_annotations): 123 | polygon = annotation.get('polygon', []) 124 | valid = annotation.get('valid', False) 125 | text_content = annotation.get('text', "") 126 | 127 | if not valid or len(polygon) != 4: 128 | continue # Skip invalid annotations or incorrect polygon format 129 | 130 | # Convert polygon to bounding box 131 | x, y, w, h = cv2.boundingRect(np.array(polygon, dtype=np.int32)) 132 | 133 | # Crop the glyph region 134 | cropped_img = img[y:y + h, x:x + w] 135 | 136 | # Save cropped glyph image 137 | if not cropped_img.size: 138 | continue 139 | output_filename = f"{os.path.splitext(img_file)[0]}_{idx}_{text_content}.jpg" 140 | output_path = os.path.join(glyph_dir, output_filename) 141 | cv2.imwrite(output_path, cropped_img) 142 | print(f"Saved cropped glyph: {output_path}") 143 | 144 | # get_glyphs(laion_test1k_json_path, laion_image_dir, laion_glyph_dir) 145 | # get_glyphs(wukong_test1k_json_path, wukong_image_dir, wukong_glyph_dir) 146 | get_glyphs(laion_test1k_json_path, laion_bw_image_dir, laion_bw_glyph_dir) 147 | get_glyphs(wukong_test1k_json_path, wukong_bw_image_dir, wukong_bw_glyph_dir) 148 | 
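# Note on the calls above: only the black-and-white (text-segmentation) outputs are cropped
# here; the two commented-out calls do the same for the raw generated images. get_glyphs()
# keeps annotations that are valid and have a 4-point polygon, converts each polygon to an
# axis-aligned box with cv2.boundingRect, crops that region, and saves it as
# <image>_<idx>_<text>.jpg (presumably consumed by the OCR evaluation scripts in this folder).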
149 | print("Glyph Processing Done") -------------------------------------------------------------------------------- /eval/glyphcontrol_multiGPUs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import argparse 5 | import pathlib 6 | import json 7 | 8 | 9 | def load(file_path: str): 10 | file_path = pathlib.Path(file_path) 11 | func_dict = {'.json': load_json} 12 | assert file_path.suffix in func_dict 13 | return func_dict[file_path.suffix](file_path) 14 | 15 | 16 | def load_json(file_path: str): 17 | with open(file_path, 'r', encoding='utf8') as f: 18 | content = json.load(f) 19 | return content 20 | 21 | 22 | def save(data, file_path): 23 | file_path = pathlib.Path(file_path) 24 | func_dict = {'.json': save_json} 25 | assert file_path.suffix in func_dict 26 | return func_dict[file_path.suffix](data, file_path) 27 | 28 | 29 | def save_json(data, file_path): 30 | with open(file_path, 'w', encoding='utf-8') as json_file: 31 | json.dump(data, json_file, ensure_ascii=False, indent=4) 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--model_path", 38 | type=str, 39 | default='checkpoints/laion10M_epoch_6_model_ema_only.ckpt', 40 | help='path to checkpoint of model' 41 | ) 42 | parser.add_argument( 43 | "--gpus", 44 | type=str, 45 | default='0,1,2,3,4,5,6,7', 46 | help='gpus for inference' 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default='./glyphcontrol_laion_generated/', 52 | help="output path" 53 | ) 54 | parser.add_argument( 55 | "--glyph_dir", 56 | type=str, 57 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion', 58 | help="path of glyph images from anytext evaluation dataset" 59 | ) 60 | parser.add_argument( 61 | "--json_path", 62 | type=str, 63 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json', 64 | help="json path for evaluation dataset" 65 | ) 66 | args = parser.parse_args() 67 | return args 68 | 69 | 70 | if __name__ == "__main__": 71 | args = parse_args() 72 | output_dir = args.output_dir 73 | 74 | tmp_dir = './tmp_dir' 75 | exec_path = './glyphcontrol_singleGPU.py' 76 | continue_gen = True # if True, not clear output_dir, and generate rest images. 
77 | 78 | if os.path.exists(tmp_dir): 79 | shutil.rmtree(tmp_dir) 80 | os.makedirs(tmp_dir) 81 | 82 | if not continue_gen: 83 | if os.path.exists(output_dir): 84 | shutil.rmtree(output_dir) 85 | os.makedirs(output_dir) 86 | else: 87 | if not os.path.exists(output_dir): 88 | os.makedirs(output_dir) 89 | 90 | os.system('sleep 1') 91 | gpu_ids = [int(i) for i in args.gpus.split(',')] 92 | nproc = len(gpu_ids) 93 | all_lines = load(args.json_path) 94 | split_file = [] 95 | length = len(all_lines['data_list']) // nproc 96 | cmds = [] 97 | for i in range(nproc): 98 | start, end = i*length, (i+1)*length 99 | if i == nproc - 1: 100 | end = len(all_lines['data_list']) 101 | temp_lines = copy.deepcopy(all_lines) 102 | temp_lines['data_list'] = temp_lines['data_list'][start:end] 103 | tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json') 104 | save(temp_lines, tmp_file) 105 | os.system('sleep 1') 106 | cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!'] 107 | cmds = ' & '.join(cmds) 108 | os.system(cmds) 109 | print('Done.') 110 | os.system('sleep 2') 111 | shutil.rmtree(tmp_dir) 112 | 113 | 114 | ''' 115 | command to kill the task after running: 116 | $ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9 117 | ''' 118 | -------------------------------------------------------------------------------- /eval/render_glyph_imgs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | from tqdm import tqdm 5 | import shutil 6 | import numpy as np 7 | import cv2 8 | from PIL import Image, ImageFont 9 | from torch.utils.data import DataLoader 10 | from dataset_util import show_bbox_on_image 11 | import argparse 12 | from t3_dataset import T3DataSet 13 | max_lines = 20 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--json_path", 20 | type=str, 21 | default='/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json', 22 | help="json path for evaluation dataset", 23 | ) 24 | parser.add_argument( 25 | "--output_dir", 26 | type=str, 27 | default='/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong', 28 | help="output path, clear the folder if exist", 29 | ) 30 | parser.add_argument( 31 | "--img_count", 32 | type=int, 33 | default=1000, 34 | help="image count", 35 | ) 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | if __name__ == '__main__': 41 | args = parse_args() 42 | if os.path.exists(args.output_dir): 43 | shutil.rmtree(args.output_dir) 44 | os.makedirs(args.output_dir) 45 | dataset = T3DataSet(args.json_path, for_show=True, max_lines=max_lines, glyph_scale=2, mask_img_prob=1.0, caption_pos_prob=0.0) 46 | train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 47 | pbar = tqdm(total=args.img_count) 48 | for i, data in enumerate(train_loader): 49 | if i == args.img_count: 50 | break 51 | all_glyphs = [] 52 | for k, glyphs in enumerate(data['glyphs']): 53 | all_glyphs += [glyphs[0].numpy().astype(np.int32)*255] 54 | glyph_img = cv2.resize(255.0-np.sum(all_glyphs, axis=0), (512, 512)) 55 | cv2.imwrite(os.path.join(args.output_dir, data['img_name'][0]), glyph_img) 56 | pbar.update(1) 57 | pbar.close() 58 | 
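# --- Added sketch (illustration only; not a file in this repository) ---
# The loop above builds one glyph image per benchmark sample: each per-line glyph mask is
# scaled to 0/255, the lines are summed, the sum is inverted so text renders dark on a white
# background, and the result is resized to 512x512. The helper below reproduces that step in
# isolation; the explicit clip to [0, 255] is an addition here, not in the original script.
import numpy as np
import cv2


def compose_glyph_image(glyph_masks, size=512):
    # glyph_masks: list of HxW arrays with values in {0, 1}, one mask per text line
    stacked = [m.astype(np.int32) * 255 for m in glyph_masks]
    composite = 255.0 - np.sum(stacked, axis=0)   # dark glyphs on a white background
    composite = np.clip(composite, 0, 255).astype(np.uint8)
    return cv2.resize(composite, (size, size))
# --- End of added sketch ---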
-------------------------------------------------------------------------------- /eval/textdiffuser_multiGPUs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import copy 4 | import argparse 5 | import pathlib 6 | import json 7 | 8 | 9 | def load(file_path: str): 10 | file_path = pathlib.Path(file_path) 11 | func_dict = {'.json': load_json} 12 | assert file_path.suffix in func_dict 13 | return func_dict[file_path.suffix](file_path) 14 | 15 | 16 | def load_json(file_path: str): 17 | with open(file_path, 'r', encoding='utf8') as f: 18 | content = json.load(f) 19 | return content 20 | 21 | 22 | def save(data, file_path): 23 | file_path = pathlib.Path(file_path) 24 | func_dict = {'.json': save_json} 25 | assert file_path.suffix in func_dict 26 | return func_dict[file_path.suffix](data, file_path) 27 | 28 | 29 | def save_json(data, file_path): 30 | with open(file_path, 'w', encoding='utf-8') as json_file: 31 | json.dump(data, json_file, ensure_ascii=False, indent=4) 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--model_path", 38 | type=str, 39 | default='textdiffuser-ckpt/diffusion_backbone', 40 | help='path to model' 41 | ) 42 | parser.add_argument( 43 | "--gpus", 44 | type=str, 45 | default='0,1,2,3,4,5,6,7', 46 | help='gpus for inference' 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default='./textdiffuser_laion_generated/', 52 | help="output path" 53 | ) 54 | parser.add_argument( 55 | "--glyph_dir", 56 | type=str, 57 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion', 58 | help="path of glyph images from anytext evaluation dataset" 59 | ) 60 | parser.add_argument( 61 | "--json_path", 62 | type=str, 63 | default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json', 64 | help="json path for evaluation dataset" 65 | ) 66 | args = parser.parse_args() 67 | return args 68 | 69 | 70 | if __name__ == "__main__": 71 | args = parse_args() 72 | output_dir = args.output_dir 73 | 74 | tmp_dir = './tmp_dir' 75 | exec_path = './textdiffuser_singleGPU.py' 76 | continue_gen = True # if True, not clear output_dir, and generate rest images. 
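# Same sharding-and-launch pattern as eval/glyphcontrol_multiGPUs.py above; only the worker
# script (textdiffuser_singleGPU.py) and the argument defaults differ.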
77 | 78 | if os.path.exists(tmp_dir): 79 | shutil.rmtree(tmp_dir) 80 | os.makedirs(tmp_dir) 81 | 82 | if not continue_gen: 83 | if os.path.exists(output_dir): 84 | shutil.rmtree(output_dir) 85 | os.makedirs(output_dir) 86 | else: 87 | if not os.path.exists(output_dir): 88 | os.makedirs(output_dir) 89 | 90 | os.system('sleep 1') 91 | gpu_ids = [int(i) for i in args.gpus.split(',')] 92 | nproc = len(gpu_ids) 93 | all_lines = load(args.json_path) 94 | split_file = [] 95 | length = len(all_lines['data_list']) // nproc 96 | cmds = [] 97 | for i in range(nproc): 98 | start, end = i*length, (i+1)*length 99 | if i == nproc - 1: 100 | end = len(all_lines['data_list']) 101 | temp_lines = copy.deepcopy(all_lines) 102 | temp_lines['data_list'] = temp_lines['data_list'][start:end] 103 | tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json') 104 | save(temp_lines, tmp_file) 105 | os.system('sleep 1') 106 | cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!'] 107 | cmds = ' & '.join(cmds) 108 | os.system(cmds) 109 | print('Done.') 110 | os.system('sleep 2') 111 | shutil.rmtree(tmp_dir) 112 | 113 | 114 | ''' 115 | command to kill the task after running: 116 | $ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9 117 | ''' 118 | -------------------------------------------------------------------------------- /example_images/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/banner.png -------------------------------------------------------------------------------- /example_images/edit1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit1.png -------------------------------------------------------------------------------- /example_images/edit10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit10.png -------------------------------------------------------------------------------- /example_images/edit11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit11.png -------------------------------------------------------------------------------- /example_images/edit12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit12.png -------------------------------------------------------------------------------- /example_images/edit13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit13.png -------------------------------------------------------------------------------- /example_images/edit14.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit14.png -------------------------------------------------------------------------------- /example_images/edit15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit15.png -------------------------------------------------------------------------------- /example_images/edit16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit16.png -------------------------------------------------------------------------------- /example_images/edit2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit2.png -------------------------------------------------------------------------------- /example_images/edit3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit3.png -------------------------------------------------------------------------------- /example_images/edit4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit4.png -------------------------------------------------------------------------------- /example_images/edit5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit5.png -------------------------------------------------------------------------------- /example_images/edit6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit6.png -------------------------------------------------------------------------------- /example_images/edit7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit7.png -------------------------------------------------------------------------------- /example_images/edit8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit8.png -------------------------------------------------------------------------------- /example_images/edit9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/edit9.png -------------------------------------------------------------------------------- /example_images/gen1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen1.png -------------------------------------------------------------------------------- /example_images/gen10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen10.png -------------------------------------------------------------------------------- /example_images/gen11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen11.png -------------------------------------------------------------------------------- /example_images/gen12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen12.png -------------------------------------------------------------------------------- /example_images/gen13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen13.png -------------------------------------------------------------------------------- /example_images/gen14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen14.png -------------------------------------------------------------------------------- /example_images/gen15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen15.png -------------------------------------------------------------------------------- /example_images/gen16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen16.png -------------------------------------------------------------------------------- /example_images/gen17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen17.png -------------------------------------------------------------------------------- /example_images/gen18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen18.png -------------------------------------------------------------------------------- /example_images/gen19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen19.png -------------------------------------------------------------------------------- /example_images/gen2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen2.png -------------------------------------------------------------------------------- /example_images/gen20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen20.png -------------------------------------------------------------------------------- /example_images/gen21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen21.png -------------------------------------------------------------------------------- /example_images/gen3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen3.png -------------------------------------------------------------------------------- /example_images/gen4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen4.png -------------------------------------------------------------------------------- /example_images/gen5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen5.png -------------------------------------------------------------------------------- /example_images/gen6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen6.png -------------------------------------------------------------------------------- /example_images/gen7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen7.png -------------------------------------------------------------------------------- /example_images/gen8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen8.png -------------------------------------------------------------------------------- /example_images/gen9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/gen9.png -------------------------------------------------------------------------------- /example_images/ref1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref1.jpg -------------------------------------------------------------------------------- /example_images/ref10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref10.jpg 
-------------------------------------------------------------------------------- /example_images/ref11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref11.jpg -------------------------------------------------------------------------------- /example_images/ref12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref12.png -------------------------------------------------------------------------------- /example_images/ref13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref13.jpg -------------------------------------------------------------------------------- /example_images/ref14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref14.png -------------------------------------------------------------------------------- /example_images/ref15.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref15.jpeg -------------------------------------------------------------------------------- /example_images/ref16.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref16.jpeg -------------------------------------------------------------------------------- /example_images/ref2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref2.jpg -------------------------------------------------------------------------------- /example_images/ref3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref3.jpg -------------------------------------------------------------------------------- /example_images/ref4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref4.jpg -------------------------------------------------------------------------------- /example_images/ref5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref5.jpg -------------------------------------------------------------------------------- /example_images/ref6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref6.jpg -------------------------------------------------------------------------------- /example_images/ref7.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref7.jpg -------------------------------------------------------------------------------- /example_images/ref8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref8.jpg -------------------------------------------------------------------------------- /example_images/ref9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/example_images/ref9.jpg -------------------------------------------------------------------------------- /flows.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/flows.png -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 
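# maps a float array in [0, 1] back to a tensor in [-1, 1] (inverse of pt2np above)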
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", 
timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 
| def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
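# Closed form computed below, elementwise:
#   KL(N(m1, exp(l1)) || N(m2, exp(l2))) = 0.5 * (l2 - l1 + exp(l1 - l2) + (m1 - m2)^2 * exp(-l2) - 1)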
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not 
implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 
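# run the depth model without gradients, then bicubically upsample the prediction back to the
# input's spatial size so the output shape is (B, 1, H, W)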
160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = 
self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
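        Runs the ResNeXt101-WSL encoder, the four feature-fusion refinement stages and the
        output head; the channel dimension is squeezed away before the depth map is returned.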
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
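        Same structure as MidasNet.forward, but built on the configured backbone
        (EfficientNet-Lite3 by default).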
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 
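    The array is flipped vertically before writing (PFM stores rows bottom-up) and the sign of
    the scale header encodes endianness, following the PFM format.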
60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 
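    The depth map is min-max normalized to the full bit range before the png is written:
    bits=1 yields a uint8 image, bits=2 a uint16 image.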
167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('fonts/Arial_Unicode.ttf', size=size) 20 | nc = int(32 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Can't encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions.
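    For example, a tensor of shape (B, C, H, W) is reduced to a tensor of shape (B,).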
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config, **kwargs): 73 | if "target" not in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 
124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /models_yaml/anytext_sd15.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "img" 10 | cond_stage_key: "caption" 11 | control_key: "hint" 12 | glyph_key: "glyphs" 13 | position_key: "positions" 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: true # need be true when embedding_manager is valid 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.18215 20 | use_ema: False 21 | only_mid_control: False 22 | loss_alpha: 0 # perceptual loss, 0.003 23 | loss_beta: 0 # ctc loss 24 | latin_weight: 1.0 # latin text line may need smaller weigth 25 | with_step_weight: true 26 | use_vae_upsample: true 27 | embedding_manager_config: 28 | target: cldm.embedding_manager.EmbeddingManager 29 | params: 30 | valid: true # v6 31 | emb_type: ocr # ocr, vit, conv 32 | glyph_channels: 1 33 | position_channels: 1 34 | add_pos: 
false 35 | placeholder_string: '*' 36 | 37 | control_stage_config: 38 | target: cldm.cldm.ControlNet 39 | params: 40 | image_size: 32 # unused 41 | in_channels: 4 42 | model_channels: 320 43 | glyph_channels: 1 44 | position_channels: 1 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | unet_config: 56 | target: cldm.cldm.ControlledUnetModel 57 | params: 58 | image_size: 32 # unused 59 | in_channels: 4 60 | out_channels: 4 61 | model_channels: 320 62 | attention_resolutions: [ 4, 2, 1 ] 63 | num_res_blocks: 2 64 | channel_mult: [ 1, 2, 4, 4 ] 65 | num_heads: 8 66 | use_spatial_transformer: True 67 | transformer_depth: 1 68 | context_dim: 768 69 | use_checkpoint: True 70 | legacy: False 71 | 72 | first_stage_config: 73 | target: ldm.models.autoencoder.AutoencoderKL 74 | params: 75 | embed_dim: 4 76 | monitor: val/rec_loss 77 | ddconfig: 78 | double_z: true 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: 85 | - 1 86 | - 2 87 | - 4 88 | - 4 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | 95 | cond_stage_config: 96 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3 97 | params: 98 | # version: /home/yuxiang.tyx/.cache/modelscope/hub/damo/cv_anytext_text_generation_editing/clip-vit-large-patch14 99 | use_vision: false # v6 100 | -------------------------------------------------------------------------------- /models_yaml/anytext_sd15_conv.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "img" 10 | cond_stage_key: "caption" 11 | control_key: "hint" 12 | glyph_key: "glyphs" 13 | position_key: "positions" 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: true # need be true when embedding_manager is valid 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.18215 20 | use_ema: False 21 | only_mid_control: False 22 | loss_alpha: 0 # perceptual loss, 0.003 23 | loss_beta: 0 # ctc loss 24 | latin_weight: 1.0 # latin text line may need smaller weigth 25 | with_step_weight: true 26 | use_vae_upsample: true 27 | embedding_manager_config: 28 | target: cldm.embedding_manager.EmbeddingManager 29 | params: 30 | valid: true # v6 31 | emb_type: conv # ocr, vit, conv 32 | glyph_channels: 1 33 | position_channels: 1 34 | add_pos: false 35 | placeholder_string: '*' 36 | 37 | control_stage_config: 38 | target: cldm.cldm.ControlNet 39 | params: 40 | image_size: 32 # unused 41 | in_channels: 4 42 | model_channels: 320 43 | glyph_channels: 1 44 | position_channels: 1 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | unet_config: 56 | target: cldm.cldm.ControlledUnetModel 57 | params: 58 | image_size: 32 # unused 59 | in_channels: 4 60 | out_channels: 4 61 | model_channels: 320 62 | attention_resolutions: [ 4, 2, 1 ] 63 | num_res_blocks: 2 64 | channel_mult: [ 1, 2, 4, 4 ] 65 | num_heads: 8 66 | 
use_spatial_transformer: True 67 | transformer_depth: 1 68 | context_dim: 768 69 | use_checkpoint: True 70 | legacy: False 71 | 72 | first_stage_config: 73 | target: ldm.models.autoencoder.AutoencoderKL 74 | params: 75 | embed_dim: 4 76 | monitor: val/rec_loss 77 | ddconfig: 78 | double_z: true 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: 85 | - 1 86 | - 2 87 | - 4 88 | - 4 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | 95 | cond_stage_config: 96 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3 97 | params: 98 | version: ./models/clip-vit-large-patch14 99 | use_vision: false # v6 100 | -------------------------------------------------------------------------------- /models_yaml/anytext_sd15_perloss.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "img" 10 | cond_stage_key: "caption" 11 | control_key: "hint" 12 | glyph_key: "glyphs" 13 | position_key: "positions" 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: true # need be true when embedding_manager is valid 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.18215 20 | use_ema: False 21 | only_mid_control: False 22 | loss_alpha: 0.003 # perceptual loss, 0.003 23 | loss_beta: 0 # ctc loss 24 | latin_weight: 1.0 # latin text line may need smaller weigth 25 | with_step_weight: true 26 | use_vae_upsample: true 27 | embedding_manager_config: 28 | target: cldm.embedding_manager.EmbeddingManager 29 | params: 30 | valid: true # v6 31 | emb_type: ocr # ocr, vit, conv 32 | glyph_channels: 1 33 | position_channels: 1 34 | add_pos: false 35 | placeholder_string: '*' 36 | 37 | control_stage_config: 38 | target: cldm.cldm.ControlNet 39 | params: 40 | image_size: 32 # unused 41 | in_channels: 4 42 | model_channels: 320 43 | glyph_channels: 1 44 | position_channels: 1 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | unet_config: 56 | target: cldm.cldm.ControlledUnetModel 57 | params: 58 | image_size: 32 # unused 59 | in_channels: 4 60 | out_channels: 4 61 | model_channels: 320 62 | attention_resolutions: [ 4, 2, 1 ] 63 | num_res_blocks: 2 64 | channel_mult: [ 1, 2, 4, 4 ] 65 | num_heads: 8 66 | use_spatial_transformer: True 67 | transformer_depth: 1 68 | context_dim: 768 69 | use_checkpoint: True 70 | legacy: False 71 | 72 | first_stage_config: 73 | target: ldm.models.autoencoder.AutoencoderKL 74 | params: 75 | embed_dim: 4 76 | monitor: val/rec_loss 77 | ddconfig: 78 | double_z: true 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: 85 | - 1 86 | - 2 87 | - 4 88 | - 4 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | 95 | cond_stage_config: 96 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3 97 | params: 98 | # version: ./models/clip-vit-large-patch14 99 | use_vision: false # v6 100 | -------------------------------------------------------------------------------- /models_yaml/anytext_sd15_vit.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | target: cldm.cldm.ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "img" 10 | cond_stage_key: "caption" 11 | control_key: "hint" 12 | glyph_key: "glyphs" 13 | position_key: "positions" 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: true # need be true when embedding_manager is valid 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.18215 20 | use_ema: False 21 | only_mid_control: False 22 | loss_alpha: 0 # perceptual loss, 0.003 23 | loss_beta: 0 # ctc loss 24 | latin_weight: 1.0 # latin text line may need smaller weigth 25 | with_step_weight: true 26 | use_vae_upsample: true 27 | embedding_manager_config: 28 | target: cldm.embedding_manager.EmbeddingManager 29 | params: 30 | valid: true # v6 31 | emb_type: vit # ocr, vit, conv 32 | glyph_channels: 1 33 | position_channels: 1 34 | add_pos: false 35 | placeholder_string: '*' 36 | 37 | control_stage_config: 38 | target: cldm.cldm.ControlNet 39 | params: 40 | image_size: 32 # unused 41 | in_channels: 4 42 | model_channels: 320 43 | glyph_channels: 1 44 | position_channels: 1 45 | attention_resolutions: [ 4, 2, 1 ] 46 | num_res_blocks: 2 47 | channel_mult: [ 1, 2, 4, 4 ] 48 | num_heads: 8 49 | use_spatial_transformer: True 50 | transformer_depth: 1 51 | context_dim: 768 52 | use_checkpoint: True 53 | legacy: False 54 | 55 | unet_config: 56 | target: cldm.cldm.ControlledUnetModel 57 | params: 58 | image_size: 32 # unused 59 | in_channels: 4 60 | out_channels: 4 61 | model_channels: 320 62 | attention_resolutions: [ 4, 2, 1 ] 63 | num_res_blocks: 2 64 | channel_mult: [ 1, 2, 4, 4 ] 65 | num_heads: 8 66 | use_spatial_transformer: True 67 | transformer_depth: 1 68 | context_dim: 768 69 | use_checkpoint: True 70 | legacy: False 71 | 72 | first_stage_config: 73 | target: ldm.models.autoencoder.AutoencoderKL 74 | params: 75 | embed_dim: 4 76 | monitor: val/rec_loss 77 | ddconfig: 78 | double_z: true 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: 85 | - 1 86 | - 2 87 | - 4 88 | - 4 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | 95 | cond_stage_config: 96 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3 97 | params: 98 | version: ./models/clip-vit-large-patch14 99 | use_vision: true # v6 100 | -------------------------------------------------------------------------------- /ocr_recog/RNN.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from .RecSVTR import Block 4 | 5 | class Swish(nn.Module): 6 | def __int__(self): 7 | super(Swish, self).__int__() 8 | 9 | def forward(self,x): 10 | return x*torch.sigmoid(x) 11 | 12 | class Im2Im(nn.Module): 13 | def __init__(self, in_channels, **kwargs): 14 | super().__init__() 15 | self.out_channels = in_channels 16 | 17 | def forward(self, x): 18 | return x 19 | 20 | class Im2Seq(nn.Module): 21 | def __init__(self, in_channels, **kwargs): 22 | super().__init__() 23 | self.out_channels = in_channels 24 | 25 | def forward(self, x): 26 | B, C, H, W = x.shape 27 | # assert H == 1 28 | x = x.reshape(B, C, H * W) 29 | x = x.permute((0, 2, 1)) 30 | return x 31 | 32 | class EncoderWithRNN(nn.Module): 33 | def __init__(self, in_channels,**kwargs): 34 | 
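        """Two-layer bidirectional LSTM; out_channels is 2 * hidden_size (hidden_size is read
        from kwargs and defaults to 256)."""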
super(EncoderWithRNN, self).__init__() 35 | hidden_size = kwargs.get('hidden_size', 256) 36 | self.out_channels = hidden_size * 2 37 | self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True) 38 | 39 | def forward(self, x): 40 | self.lstm.flatten_parameters() 41 | x, _ = self.lstm(x) 42 | return x 43 | 44 | class SequenceEncoder(nn.Module): 45 | def __init__(self, in_channels, encoder_type='rnn', **kwargs): 46 | super(SequenceEncoder, self).__init__() 47 | self.encoder_reshape = Im2Seq(in_channels) 48 | self.out_channels = self.encoder_reshape.out_channels 49 | self.encoder_type = encoder_type 50 | if encoder_type == 'reshape': 51 | self.only_reshape = True 52 | else: 53 | support_encoder_dict = { 54 | 'reshape': Im2Seq, 55 | 'rnn': EncoderWithRNN, 56 | 'svtr': EncoderWithSVTR 57 | } 58 | assert encoder_type in support_encoder_dict, '{} must in {}'.format( 59 | encoder_type, support_encoder_dict.keys()) 60 | 61 | self.encoder = support_encoder_dict[encoder_type]( 62 | self.encoder_reshape.out_channels,**kwargs) 63 | self.out_channels = self.encoder.out_channels 64 | self.only_reshape = False 65 | 66 | def forward(self, x): 67 | if self.encoder_type != 'svtr': 68 | x = self.encoder_reshape(x) 69 | if not self.only_reshape: 70 | x = self.encoder(x) 71 | return x 72 | else: 73 | x = self.encoder(x) 74 | x = self.encoder_reshape(x) 75 | return x 76 | 77 | class ConvBNLayer(nn.Module): 78 | def __init__(self, 79 | in_channels, 80 | out_channels, 81 | kernel_size=3, 82 | stride=1, 83 | padding=0, 84 | bias_attr=False, 85 | groups=1, 86 | act=nn.GELU): 87 | super().__init__() 88 | self.conv = nn.Conv2d( 89 | in_channels=in_channels, 90 | out_channels=out_channels, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 | padding=padding, 94 | groups=groups, 95 | # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), 96 | bias=bias_attr) 97 | self.norm = nn.BatchNorm2d(out_channels) 98 | self.act = Swish() 99 | 100 | def forward(self, inputs): 101 | out = self.conv(inputs) 102 | out = self.norm(out) 103 | out = self.act(out) 104 | return out 105 | 106 | 107 | class EncoderWithSVTR(nn.Module): 108 | def __init__( 109 | self, 110 | in_channels, 111 | dims=64, # XS 112 | depth=2, 113 | hidden_dims=120, 114 | use_guide=False, 115 | num_heads=8, 116 | qkv_bias=True, 117 | mlp_ratio=2.0, 118 | drop_rate=0.1, 119 | attn_drop_rate=0.1, 120 | drop_path=0., 121 | qk_scale=None): 122 | super(EncoderWithSVTR, self).__init__() 123 | self.depth = depth 124 | self.use_guide = use_guide 125 | self.conv1 = ConvBNLayer( 126 | in_channels, in_channels // 8, padding=1, act='swish') 127 | self.conv2 = ConvBNLayer( 128 | in_channels // 8, hidden_dims, kernel_size=1, act='swish') 129 | 130 | self.svtr_block = nn.ModuleList([ 131 | Block( 132 | dim=hidden_dims, 133 | num_heads=num_heads, 134 | mixer='Global', 135 | HW=None, 136 | mlp_ratio=mlp_ratio, 137 | qkv_bias=qkv_bias, 138 | qk_scale=qk_scale, 139 | drop=drop_rate, 140 | act_layer='swish', 141 | attn_drop=attn_drop_rate, 142 | drop_path=drop_path, 143 | norm_layer='nn.LayerNorm', 144 | epsilon=1e-05, 145 | prenorm=False) for i in range(depth) 146 | ]) 147 | self.norm = nn.LayerNorm(hidden_dims, eps=1e-6) 148 | self.conv3 = ConvBNLayer( 149 | hidden_dims, in_channels, kernel_size=1, act='swish') 150 | # last conv-nxn, the input is concat of input tensor and conv3 output tensor 151 | self.conv4 = ConvBNLayer( 152 | 2 * in_channels, in_channels // 8, padding=1, act='swish') 153 | 154 | self.conv1x1 = 
ConvBNLayer( 155 | in_channels // 8, dims, kernel_size=1, act='swish') 156 | self.out_channels = dims 157 | self.apply(self._init_weights) 158 | 159 | def _init_weights(self, m): 160 | # weight initialization 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 163 | if m.bias is not None: 164 | nn.init.zeros_(m.bias) 165 | elif isinstance(m, nn.BatchNorm2d): 166 | nn.init.ones_(m.weight) 167 | nn.init.zeros_(m.bias) 168 | elif isinstance(m, nn.Linear): 169 | nn.init.normal_(m.weight, 0, 0.01) 170 | if m.bias is not None: 171 | nn.init.zeros_(m.bias) 172 | elif isinstance(m, nn.ConvTranspose2d): 173 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 174 | if m.bias is not None: 175 | nn.init.zeros_(m.bias) 176 | elif isinstance(m, nn.LayerNorm): 177 | nn.init.ones_(m.weight) 178 | nn.init.zeros_(m.bias) 179 | 180 | def forward(self, x): 181 | # for use guide 182 | if self.use_guide: 183 | z = x.clone() 184 | z.stop_gradient = True 185 | else: 186 | z = x 187 | # for short cut 188 | h = z 189 | # reduce dim 190 | z = self.conv1(z) 191 | z = self.conv2(z) 192 | # SVTR global block 193 | B, C, H, W = z.shape 194 | z = z.flatten(2).permute(0, 2, 1) 195 | 196 | for blk in self.svtr_block: 197 | z = blk(z) 198 | 199 | z = self.norm(z) 200 | # last stage 201 | z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2) 202 | z = self.conv3(z) 203 | z = torch.cat((h, z), dim=1) 204 | z = self.conv1x1(self.conv4(z)) 205 | 206 | return z 207 | 208 | if __name__=="__main__": 209 | svtrRNN = EncoderWithSVTR(56) 210 | print(svtrRNN) -------------------------------------------------------------------------------- /ocr_recog/RecCTCHead.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class CTCHead(nn.Module): 5 | def __init__(self, 6 | in_channels, 7 | out_channels=6625, 8 | fc_decay=0.0004, 9 | mid_channels=None, 10 | return_feats=False, 11 | **kwargs): 12 | super(CTCHead, self).__init__() 13 | if mid_channels is None: 14 | self.fc = nn.Linear( 15 | in_channels, 16 | out_channels, 17 | bias=True,) 18 | else: 19 | self.fc1 = nn.Linear( 20 | in_channels, 21 | mid_channels, 22 | bias=True, 23 | ) 24 | self.fc2 = nn.Linear( 25 | mid_channels, 26 | out_channels, 27 | bias=True, 28 | ) 29 | 30 | self.out_channels = out_channels 31 | self.mid_channels = mid_channels 32 | self.return_feats = return_feats 33 | 34 | def forward(self, x, labels=None): 35 | if self.mid_channels is None: 36 | predicts = self.fc(x) 37 | else: 38 | x = self.fc1(x) 39 | predicts = self.fc2(x) 40 | 41 | if self.return_feats: 42 | result = dict() 43 | result['ctc'] = predicts 44 | result['ctc_neck'] = x 45 | else: 46 | result = predicts 47 | 48 | return result 49 | -------------------------------------------------------------------------------- /ocr_recog/RecModel.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .RNN import SequenceEncoder, Im2Seq, Im2Im 3 | from .RecMv1_enhance import MobileNetV1Enhance 4 | 5 | from .RecCTCHead import CTCHead 6 | 7 | backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance} 8 | neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im} 9 | head_dict = {'CTCHead':CTCHead} 10 | 11 | 12 | class RecModel(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | assert 'in_channels' in config, 'in_channels must in model config' 16 | backbone_type = config.backbone.pop('type') 17 | assert backbone_type in 
backbone_dict, f'backbone.type must in {backbone_dict}' 18 | self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) 19 | 20 | neck_type = config.neck.pop('type') 21 | assert neck_type in neck_dict, f'neck.type must in {neck_dict}' 22 | self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) 23 | 24 | head_type = config.head.pop('type') 25 | assert head_type in head_dict, f'head.type must in {head_dict}' 26 | self.head = head_dict[head_type](self.neck.out_channels, **config.head) 27 | 28 | self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}' 29 | 30 | def load_3rd_state_dict(self, _3rd_name, _state): 31 | self.backbone.load_3rd_state_dict(_3rd_name, _state) 32 | self.neck.load_3rd_state_dict(_3rd_name, _state) 33 | self.head.load_3rd_state_dict(_3rd_name, _state) 34 | 35 | def forward(self, x): 36 | x = self.backbone(x) 37 | x = self.neck(x) 38 | x = self.head(x) 39 | return x 40 | 41 | def encode(self, x): 42 | x = self.backbone(x) 43 | x = self.neck(x) 44 | x = self.head.ctc_encoder(x) 45 | return x 46 | -------------------------------------------------------------------------------- /ocr_recog/RecMv1_enhance.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .common import Activation 6 | 7 | 8 | class ConvBNLayer(nn.Module): 9 | def __init__(self, 10 | num_channels, 11 | filter_size, 12 | num_filters, 13 | stride, 14 | padding, 15 | channels=None, 16 | num_groups=1, 17 | act='hard_swish'): 18 | super(ConvBNLayer, self).__init__() 19 | self.act = act 20 | self._conv = nn.Conv2d( 21 | in_channels=num_channels, 22 | out_channels=num_filters, 23 | kernel_size=filter_size, 24 | stride=stride, 25 | padding=padding, 26 | groups=num_groups, 27 | bias=False) 28 | 29 | self._batch_norm = nn.BatchNorm2d( 30 | num_filters, 31 | ) 32 | if self.act is not None: 33 | self._act = Activation(act_type=act, inplace=True) 34 | 35 | def forward(self, inputs): 36 | y = self._conv(inputs) 37 | y = self._batch_norm(y) 38 | if self.act is not None: 39 | y = self._act(y) 40 | return y 41 | 42 | 43 | class DepthwiseSeparable(nn.Module): 44 | def __init__(self, 45 | num_channels, 46 | num_filters1, 47 | num_filters2, 48 | num_groups, 49 | stride, 50 | scale, 51 | dw_size=3, 52 | padding=1, 53 | use_se=False): 54 | super(DepthwiseSeparable, self).__init__() 55 | self.use_se = use_se 56 | self._depthwise_conv = ConvBNLayer( 57 | num_channels=num_channels, 58 | num_filters=int(num_filters1 * scale), 59 | filter_size=dw_size, 60 | stride=stride, 61 | padding=padding, 62 | num_groups=int(num_groups * scale)) 63 | if use_se: 64 | self._se = SEModule(int(num_filters1 * scale)) 65 | self._pointwise_conv = ConvBNLayer( 66 | num_channels=int(num_filters1 * scale), 67 | filter_size=1, 68 | num_filters=int(num_filters2 * scale), 69 | stride=1, 70 | padding=0) 71 | 72 | def forward(self, inputs): 73 | y = self._depthwise_conv(inputs) 74 | if self.use_se: 75 | y = self._se(y) 76 | y = self._pointwise_conv(y) 77 | return y 78 | 79 | 80 | class MobileNetV1Enhance(nn.Module): 81 | def __init__(self, 82 | in_channels=3, 83 | scale=0.5, 84 | last_conv_stride=1, 85 | last_pool_type='max', 86 | **kwargs): 87 | super().__init__() 88 | self.scale = scale 89 | self.block_list = [] 90 | 91 | self.conv1 = ConvBNLayer( 92 | num_channels=in_channels, 93 | filter_size=3, 94 | channels=3, 95 | num_filters=int(32 * scale), 96 | stride=2, 97 | 
padding=1) 98 | 99 | conv2_1 = DepthwiseSeparable( 100 | num_channels=int(32 * scale), 101 | num_filters1=32, 102 | num_filters2=64, 103 | num_groups=32, 104 | stride=1, 105 | scale=scale) 106 | self.block_list.append(conv2_1) 107 | 108 | conv2_2 = DepthwiseSeparable( 109 | num_channels=int(64 * scale), 110 | num_filters1=64, 111 | num_filters2=128, 112 | num_groups=64, 113 | stride=1, 114 | scale=scale) 115 | self.block_list.append(conv2_2) 116 | 117 | conv3_1 = DepthwiseSeparable( 118 | num_channels=int(128 * scale), 119 | num_filters1=128, 120 | num_filters2=128, 121 | num_groups=128, 122 | stride=1, 123 | scale=scale) 124 | self.block_list.append(conv3_1) 125 | 126 | conv3_2 = DepthwiseSeparable( 127 | num_channels=int(128 * scale), 128 | num_filters1=128, 129 | num_filters2=256, 130 | num_groups=128, 131 | stride=(2, 1), 132 | scale=scale) 133 | self.block_list.append(conv3_2) 134 | 135 | conv4_1 = DepthwiseSeparable( 136 | num_channels=int(256 * scale), 137 | num_filters1=256, 138 | num_filters2=256, 139 | num_groups=256, 140 | stride=1, 141 | scale=scale) 142 | self.block_list.append(conv4_1) 143 | 144 | conv4_2 = DepthwiseSeparable( 145 | num_channels=int(256 * scale), 146 | num_filters1=256, 147 | num_filters2=512, 148 | num_groups=256, 149 | stride=(2, 1), 150 | scale=scale) 151 | self.block_list.append(conv4_2) 152 | 153 | for _ in range(5): 154 | conv5 = DepthwiseSeparable( 155 | num_channels=int(512 * scale), 156 | num_filters1=512, 157 | num_filters2=512, 158 | num_groups=512, 159 | stride=1, 160 | dw_size=5, 161 | padding=2, 162 | scale=scale, 163 | use_se=False) 164 | self.block_list.append(conv5) 165 | 166 | conv5_6 = DepthwiseSeparable( 167 | num_channels=int(512 * scale), 168 | num_filters1=512, 169 | num_filters2=1024, 170 | num_groups=512, 171 | stride=(2, 1), 172 | dw_size=5, 173 | padding=2, 174 | scale=scale, 175 | use_se=True) 176 | self.block_list.append(conv5_6) 177 | 178 | conv6 = DepthwiseSeparable( 179 | num_channels=int(1024 * scale), 180 | num_filters1=1024, 181 | num_filters2=1024, 182 | num_groups=1024, 183 | stride=last_conv_stride, 184 | dw_size=5, 185 | padding=2, 186 | use_se=True, 187 | scale=scale) 188 | self.block_list.append(conv6) 189 | 190 | self.block_list = nn.Sequential(*self.block_list) 191 | if last_pool_type == 'avg': 192 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 193 | else: 194 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 195 | self.out_channels = int(1024 * scale) 196 | 197 | def forward(self, inputs): 198 | y = self.conv1(inputs) 199 | y = self.block_list(y) 200 | y = self.pool(y) 201 | return y 202 | 203 | def hardsigmoid(x): 204 | return F.relu6(x + 3., inplace=True) / 6. 
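# Squeeze-and-Excitation gate used by DepthwiseSeparable: global average pooling, a 1x1 conv
# bottleneck (channel // reduction) with ReLU, a 1x1 expansion conv and the hard sigmoid above;
# the resulting per-channel weights rescale the input via torch.mul.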
205 | 206 | class SEModule(nn.Module): 207 | def __init__(self, channel, reduction=4): 208 | super(SEModule, self).__init__() 209 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 210 | self.conv1 = nn.Conv2d( 211 | in_channels=channel, 212 | out_channels=channel // reduction, 213 | kernel_size=1, 214 | stride=1, 215 | padding=0, 216 | bias=True) 217 | self.conv2 = nn.Conv2d( 218 | in_channels=channel // reduction, 219 | out_channels=channel, 220 | kernel_size=1, 221 | stride=1, 222 | padding=0, 223 | bias=True) 224 | 225 | def forward(self, inputs): 226 | outputs = self.avg_pool(inputs) 227 | outputs = self.conv1(outputs) 228 | outputs = F.relu(outputs) 229 | outputs = self.conv2(outputs) 230 | outputs = hardsigmoid(outputs) 231 | x = torch.mul(inputs, outputs) 232 | 233 | return x 234 | -------------------------------------------------------------------------------- /ocr_recog/common.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Hswish(nn.Module): 9 | def __init__(self, inplace=True): 10 | super(Hswish, self).__init__() 11 | self.inplace = inplace 12 | 13 | def forward(self, x): 14 | return x * F.relu6(x + 3., inplace=self.inplace) / 6. 15 | 16 | # out = max(0, min(1, slop*x+offset)) 17 | # paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None) 18 | class Hsigmoid(nn.Module): 19 | def __init__(self, inplace=True): 20 | super(Hsigmoid, self).__init__() 21 | self.inplace = inplace 22 | 23 | def forward(self, x): 24 | # torch: F.relu6(x + 3., inplace=self.inplace) / 6. 25 | # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. 26 | return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. 27 | 28 | class GELU(nn.Module): 29 | def __init__(self, inplace=True): 30 | super(GELU, self).__init__() 31 | self.inplace = inplace 32 | 33 | def forward(self, x): 34 | return torch.nn.functional.gelu(x) 35 | 36 | 37 | class Swish(nn.Module): 38 | def __init__(self, inplace=True): 39 | super(Swish, self).__init__() 40 | self.inplace = inplace 41 | 42 | def forward(self, x): 43 | if self.inplace: 44 | x.mul_(torch.sigmoid(x)) 45 | return x 46 | else: 47 | return x*torch.sigmoid(x) 48 | 49 | 50 | class Activation(nn.Module): 51 | def __init__(self, act_type, inplace=True): 52 | super(Activation, self).__init__() 53 | act_type = act_type.lower() 54 | if act_type == 'relu': 55 | self.act = nn.ReLU(inplace=inplace) 56 | elif act_type == 'relu6': 57 | self.act = nn.ReLU6(inplace=inplace) 58 | elif act_type == 'sigmoid': 59 | raise NotImplementedError 60 | elif act_type == 'hard_sigmoid': 61 | self.act = Hsigmoid(inplace) 62 | elif act_type == 'hard_swish': 63 | self.act = Hswish(inplace=inplace) 64 | elif act_type == 'leakyrelu': 65 | self.act = nn.LeakyReLU(inplace=inplace) 66 | elif act_type == 'gelu': 67 | self.act = GELU(inplace=inplace) 68 | elif act_type == 'swish': 69 | self.act = Swish(inplace=inplace) 70 | else: 71 | raise NotImplementedError 72 | 73 | def forward(self, inputs): 74 | return self.act(inputs) -------------------------------------------------------------------------------- /ocr_recog/en_dict.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 11 | : 12 | ; 13 | < 14 | = 15 | > 16 | ? 
17 | @ 18 | A 19 | B 20 | C 21 | D 22 | E 23 | F 24 | G 25 | H 26 | I 27 | J 28 | K 29 | L 30 | M 31 | N 32 | O 33 | P 34 | Q 35 | R 36 | S 37 | T 38 | U 39 | V 40 | W 41 | X 42 | Y 43 | Z 44 | [ 45 | \ 46 | ] 47 | ^ 48 | _ 49 | ` 50 | a 51 | b 52 | c 53 | d 54 | e 55 | f 56 | g 57 | h 58 | i 59 | j 60 | k 61 | l 62 | m 63 | n 64 | o 65 | p 66 | q 67 | r 68 | s 69 | t 70 | u 71 | v 72 | w 73 | x 74 | y 75 | z 76 | { 77 | | 78 | } 79 | ~ 80 | ! 81 | " 82 | # 83 | $ 84 | % 85 | & 86 | ' 87 | ( 88 | ) 89 | * 90 | + 91 | , 92 | - 93 | . 94 | / 95 | 96 | -------------------------------------------------------------------------------- /ocr_weights/en_dict.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 11 | : 12 | ; 13 | < 14 | = 15 | > 16 | ? 17 | @ 18 | A 19 | B 20 | C 21 | D 22 | E 23 | F 24 | G 25 | H 26 | I 27 | J 28 | K 29 | L 30 | M 31 | N 32 | O 33 | P 34 | Q 35 | R 36 | S 37 | T 38 | U 39 | V 40 | W 41 | X 42 | Y 43 | Z 44 | [ 45 | \ 46 | ] 47 | ^ 48 | _ 49 | ` 50 | a 51 | b 52 | c 53 | d 54 | e 55 | f 56 | g 57 | h 58 | i 59 | j 60 | k 61 | l 62 | m 63 | n 64 | o 65 | p 66 | q 67 | r 68 | s 69 | t 70 | u 71 | v 72 | w 73 | x 74 | y 75 | z 76 | { 77 | | 78 | } 79 | ~ 80 | ! 81 | " 82 | # 83 | $ 84 | % 85 | & 86 | ' 87 | ( 88 | ) 89 | * 90 | + 91 | , 92 | - 93 | . 94 | / 95 | 96 | -------------------------------------------------------------------------------- /ocr_weights/ppv3_rec.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ocr_weights/ppv3_rec.pth -------------------------------------------------------------------------------- /ocr_weights/ppv3_rec_en.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowen-upenn/ControlText/29a05d78a1c5476ce356a2cae3954d1441629a85/ocr_weights/ppv3_rec_en.pth -------------------------------------------------------------------------------- /proj_3d_surface.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | # !pip install timm 7 | 8 | model_type = "DPT_Large" # MiDaS v3 - Large (highest accuracy, slowest inference speed) 9 | #model_type = "DPT_Hybrid" # MiDaS v3 - Hybrid (medium accuracy, medium inference speed) 10 | #model_type = "MiDaS_small" # MiDaS v2.1 - Small (lowest accuracy, highest inference speed) 11 | 12 | midas = torch.hub.load("intel-isl/MiDaS", model_type) 13 | 14 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 15 | midas.to(device) 16 | midas.eval() 17 | 18 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 19 | 20 | if model_type == "DPT_Large" or model_type == "DPT_Hybrid": 21 | transform = midas_transforms.dpt_transform 22 | else: 23 | transform = midas_transforms.small_transform 24 | 25 | # Load an image 26 | bg_img = cv2.imread('hat.png') 27 | bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB) 28 | 29 | input_batch = transform(bg_img).to(device) 30 | 31 | with torch.no_grad(): 32 | prediction = midas(input_batch) 33 | 34 | prediction = torch.nn.functional.interpolate( 35 | prediction.unsqueeze(1), 36 | size=bg_img.shape[:2], 37 | mode="bicubic", 38 | align_corners=False, 39 | ).squeeze() 40 | 41 | depth = prediction.cpu().numpy() 42 | depth = 
(depth - np.min(depth)) / (np.max(depth) - np.min(depth)) 43 | depth = 2 - depth 44 | 45 | plt.imshow(depth) 46 | plt.colorbar(label='Depth') 47 | 48 | txt_img = cv2.imread('texts.png')[:, :, 0] 49 | plt.imshow(txt_img) 50 | 51 | # Create 3D points 52 | h, w = depth.shape 53 | x, y = np.meshgrid(np.arange(w), np.arange(h)) 54 | z = depth 55 | 56 | # Stack into a (N, 3) array of 3D points 57 | points_3d = np.stack((x, y, z), axis=-1).reshape(-1, 3) 58 | 59 | # Define camera parameters 60 | focal_length = 1 # Adjust based on the desired perspective 61 | camera_matrix = np.array([[focal_length, 0, 0.0], 62 | [0, focal_length, 0.0], 63 | [0, 0, 1]]) 64 | 65 | # Project 3D points to 2D 66 | points_2d, _ = cv2.projectPoints(points_3d, (0, 0, 0), (0, 0, 0), camera_matrix, None) 67 | 68 | # Reshape to image dimensions 69 | points_2d = points_2d.reshape(h, w, 2) 70 | 71 | # Create a distorted image 72 | distorted_image = np.zeros_like(txt_img) 73 | 74 | # # Map original image pixels to new positions 75 | # for i in range(h): 76 | # for j in range(w): 77 | # x_new, y_new = points_2d[i, j] 78 | # x_new = int(x_new) 79 | # y_new = int(y_new) 80 | # if 0 <= x_new < w and 0 <= y_new < h: 81 | # distorted_image[y_new, x_new] = txt_img[i, j] 82 | 83 | # Create an inverse map 84 | inverse_map = np.full((h, w, 2), -1.0) 85 | 86 | # Populate the inverse map 87 | for i in range(h): 88 | for j in range(w): 89 | x_new, y_new = points_2d[i, j] 90 | x_new = int(x_new) 91 | y_new = int(y_new) 92 | if 0 <= x_new < w and 0 <= y_new < h: 93 | inverse_map[y_new, x_new] = [i, j] 94 | 95 | # Fill the distorted image using the inverse map 96 | for i in range(h): 97 | for j in range(w): 98 | src_y, src_x = inverse_map[i, j] 99 | if src_y >= 0 and src_x >= 0: 100 | distorted_image[i, j] = txt_img[int(src_y), int(src_x)] 101 | 102 | 103 | print('distorted_image', distorted_image.shape, np.min(distorted_image), np.max(distorted_image)) 104 | 105 | # Display the result 106 | plt.imshow(distorted_image) 107 | 108 | 109 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | addict 3 | aiohappyeyeballs 4 | aiohttp 5 | aiosignal 6 | albumentations 7 | aliyun-python-sdk-core 8 | aliyun-python-sdk-kms 9 | altair 10 | antlr4-python3-runtime 11 | anyio 12 | astor 13 | astunparse 14 | async-timeout 15 | attrs 16 | basicsr 17 | beautifulsoup4 18 | blinker 19 | braceexpand 20 | cffi 21 | cmake 22 | crcmod 23 | cryptography 24 | datasets 25 | diffusers 26 | dill 27 | easydict 28 | einops 29 | entrypoints 30 | fastapi 31 | fire 32 | flatbuffers 33 | frozenlist 34 | fsspec 35 | ftfy 36 | future 37 | gast 38 | gitdb 39 | gitpython 40 | google-auth-oauthlib 41 | google-pasta 42 | gradio 43 | gradio-client 44 | h5py 45 | imageio 46 | imageio-ffmpeg 47 | imgaug 48 | importlib-metadata 49 | jieba 50 | jmespath 51 | jsonschema 52 | jsonschema-specifications 53 | keras 54 | kornia 55 | langid 56 | lazy-loader 57 | libclang 58 | lit 59 | lxml 60 | markdown 61 | mock 62 | modelscope 63 | multidict 64 | multiprocess 65 | mypy-extensions 66 | numpy 67 | oauthlib 68 | omegaconf 69 | open-clip-torch 70 | opencv-contrib-python 71 | opencv-python 72 | opt-einsum 73 | oss2 74 | pandas 75 | pillow 76 | protobuf 77 | pycparser 78 | pycryptodome 79 | pydantic 80 | pydeck 81 | pydeprecate 82 | pympler 83 | pyre-extensions 84 | python-docx 85 | pytorch-lightning 86 | pywavelets 87 | rapidfuzz 88 | referencing 89 | 
requests 90 | requests-oauthlib 91 | rpds-py 92 | sacremoses 93 | safetensors 94 | scikit-image 95 | semver 96 | simplejson 97 | smmap 98 | sortedcontainers 99 | soupsieve 100 | starlette 101 | subword-nmt 102 | tensorboard 103 | tensorboard-data-server 104 | tensorflow 105 | tensorflow-estimator 106 | tensorflow-io-gcs-filesystem 107 | termcolor 108 | test-tube 109 | tifffile 110 | timm 111 | tokenizers 112 | toml 113 | torch 114 | torchmetrics 115 | torchvision 116 | tqdm 117 | transformers 118 | triton 119 | typing-extensions 120 | typing-inspect 121 | tzlocal 122 | ujson 123 | validators 124 | watchdog 125 | webdataset 126 | werkzeug 127 | wrapt 128 | xformers 129 | xxhash 130 | yapf 131 | yarl 132 | zipp -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | #banner { 2 | max-width: 400px; 3 | margin: auto; 4 | box-shadow: 0 2px 20px rgba(0, 0, 0, 0.5) !important; 5 | border-radius: 20px; 6 | } 7 | 8 | .run { 9 | background-color: #624AFF !important; 10 | color: #FFFFFF !important; 11 | border-radius: 2px !important; 12 | box-shadow: 0 3px 5px rgba(0, 0, 0, 0.5) !important; 13 | } 14 | .run:active { 15 | background-color: #d96565 !important; 16 | } 17 | .run:hover { 18 | background-color: #a079f5 !important; 19 | } 20 | /* tab button style */ 21 | button.svelte-kqij2n { 22 | margin-bottom: -1px; 23 | border: 1px solid transparent; 24 | border-color: transparent; 25 | border-bottom: none; 26 | color: #9CA3AF !important; 27 | font-size: 16px; 28 | } 29 | button.selected.svelte-kqij2n { 30 | background: #ddd8f9 !important; 31 | color: rgb(62, 7, 240) !important; 32 | } -------------------------------------------------------------------------------- /synthetic_dataset/generate_prompt_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import base64 4 | import requests 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | def query_openai_gpt_4v(image_path, api_key): 10 | # we have to crop the image before converting it to base64 11 | image = cv2.imread(image_path) 12 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 13 | 14 | _, buffer = cv2.imencode('.jpg', image) 15 | image_bytes = np.array(buffer).tobytes() 16 | base64_image = base64.b64encode(image_bytes).decode('utf-8') 17 | 18 | messages = "Your task is to generate a detailed prompt for ControlNet to replicate the specific font style seen in an image. " \ 19 | "Focus on describing the unique visual characteristics of the font that make it different from other fonts, " \ 20 | "such as letter shape, line weight, glyph width, and any distinct features and styles of the typeface that stand out. " \ 21 | "Avoid mentioning general attributes and ignore any shape distortion, perspective transformation, or rotation in the text. " \ 22 | "Provide the description in two sentences. If there are multiple lines, describe each line separately. 
" \ 23 | "Here is the prompt:" 24 | max_tokens = 400 25 | 26 | prompt = { 27 | "model": "gpt-4-vision-preview", 28 | "messages": [ 29 | { 30 | "role": "user", 31 | "content": [ 32 | {"type": "text", "text": messages}, 33 | {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} 34 | ] 35 | } 36 | ], 37 | "max_tokens": max_tokens 38 | } 39 | 40 | # Send request to OpenAI API 41 | headers = { 42 | "Content-Type": "application/json", 43 | "Authorization": f"Bearer {api_key}" 44 | } 45 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=prompt) 46 | response_json = response.json() 47 | 48 | # Process the response 49 | # Check if the response is valid and contains the expected data 50 | if 'choices' in response_json and len(response_json['choices']) > 0: 51 | completion_text = response_json['choices'][0].get('message', {}).get('content', '') 52 | else: 53 | completion_text = "" 54 | 55 | return completion_text 56 | 57 | 58 | if __name__ == "__main__": 59 | # Load the API key 60 | with open("openai_key.txt", "r") as api_key_file: 61 | api_key = api_key_file.read() 62 | 63 | # Generate the data 64 | data = [] 65 | error_count = 0 66 | for i in tqdm(range(300000)): 67 | text_idx = i // 3 68 | font_idx = i % 3 69 | image_idx = str(text_idx) + '_' + str(font_idx) 70 | 71 | try: 72 | with open("test_dataset/texts/" + image_idx + ".txt", 'r') as file: 73 | texts = file.read() 74 | except: 75 | error_count += 1 76 | continue 77 | 78 | source_image_path = f"test_dataset/target_curved/{image_idx}.png" 79 | target_image_path = f"test_dataset/target/{image_idx}.png" 80 | # print('image_idx', image_idx, 'source_image_path', source_image_path, 'target_image_path', target_image_path, 'texts', texts) 81 | 82 | data.append({"source": source_image_path, "target": target_image_path, 83 | "prompt": f'A black background with the texts "{texts}" that have no shape distortion, curvature, or rotation. Follow the same fonts in the condition image.'}) 84 | 85 | # font_description = query_openai_gpt_4v(source_image_path, api_key) 86 | # print('Font description:', font_description) 87 | 88 | # data.append({"source": source_image_path, "target": target_image_path, 89 | # "prompt": f'A black background with the texts "{texts}" that have no shape distortion, curvature, or rotation.' 
+ font_description}) 90 | 91 | print('Error count:', error_count, 'number of data:', len(data)) 92 | 93 | # Write to a JSON file 94 | with open('prompt.json', 'w') as file: 95 | for entry in data: 96 | json.dump(entry, file) 97 | file.write('\n') 98 | -------------------------------------------------------------------------------- /synthetic_dataset/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==23.2.0 6 | cachetools==5.3.3 7 | certifi==2024.2.2 8 | charset-normalizer==3.3.2 9 | contourpy==1.1.1 10 | cycler==0.12.1 11 | datasets==2.19.0 12 | dill==0.3.8 13 | filelock==3.13.4 14 | fonttools==4.51.0 15 | frozenlist==1.4.1 16 | fsspec==2024.3.1 17 | google-auth==2.29.0 18 | google-auth-oauthlib==1.0.0 19 | grpcio==1.62.2 20 | huggingface-hub==0.22.2 21 | idna==3.7 22 | importlib_metadata==7.1.0 23 | importlib_resources==6.4.0 24 | Jinja2==3.1.3 25 | kiwisolver==1.4.5 26 | lightning-utilities==0.11.2 27 | Markdown==3.6 28 | MarkupSafe==2.1.5 29 | matplotlib==3.7.5 30 | mpmath==1.3.0 31 | multidict==6.0.5 32 | multiprocess==0.70.16 33 | networkx==3.1 34 | numpy==1.24.4 35 | nvidia-cublas-cu12==12.1.3.1 36 | nvidia-cuda-cupti-cu12==12.1.105 37 | nvidia-cuda-nvrtc-cu12==12.1.105 38 | nvidia-cuda-runtime-cu12==12.1.105 39 | nvidia-cudnn-cu12==8.9.2.26 40 | nvidia-cufft-cu12==11.0.2.54 41 | nvidia-curand-cu12==10.3.2.106 42 | nvidia-cusolver-cu12==11.4.5.107 43 | nvidia-cusparse-cu12==12.1.0.106 44 | nvidia-nccl-cu12==2.19.3 45 | nvidia-nvjitlink-cu12==12.4.127 46 | nvidia-nvtx-cu12==12.1.105 47 | oauthlib==3.2.2 48 | opencv-python==4.9.0.80 49 | packaging==24.0 50 | pandas==2.0.3 51 | pillow==10.3.0 52 | protobuf==5.26.1 53 | pyarrow==16.0.0 54 | pyarrow-hotfix==0.6 55 | pyasn1==0.6.0 56 | pyasn1_modules==0.4.0 57 | pyparsing==3.1.2 58 | python-dateutil==2.9.0.post0 59 | pytorch-lightning==2.2.3 60 | pytz==2024.1 61 | PyYAML==6.0.1 62 | regex==2024.4.16 63 | requests==2.31.0 64 | requests-oauthlib==2.0.0 65 | rsa==4.9 66 | safetensors==0.4.3 67 | scipy==1.13.0 68 | six==1.16.0 69 | sympy==1.12 70 | tensorboard==2.14.0 71 | tensorboard-data-server==0.7.2 72 | tokenizers==0.19.1 73 | torch==2.2.2 74 | torchaudio==2.2.2 75 | torchmetrics==1.3.2 76 | torchvision==0.17.2 77 | tqdm==4.66.2 78 | transformers==4.40.1 79 | triton==2.2.0 80 | typing_extensions==4.11.0 81 | tzdata==2024.1 82 | urllib3==2.2.1 83 | Werkzeug==3.0.2 84 | xxhash==3.4.1 85 | yarl==1.9.4 86 | zipp==3.18.1 87 | -------------------------------------------------------------------------------- /synthetic_dataset/test_rectify.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def order_points(pts): 6 | # Initial sorting based on the x-coordinate 7 | xSorted = pts[np.argsort(pts[:, 0]), :] 8 | 9 | # Get the left-most and right-most points 10 | leftMost = xSorted[:2, :] 11 | rightMost = xSorted[2:, :] 12 | 13 | # Sort left-most coordinates according to their y-coordinates 14 | leftMost = leftMost[np.argsort(leftMost[:, 1]), :] 15 | (tl, bl) = leftMost 16 | 17 | # Sort the right-most coordinates according to their y-coordinates 18 | rightMost = rightMost[np.argsort(rightMost[:, 1]), :] 19 | (tr, br) = rightMost 20 | 21 | # Return the coordinates in top-left, top-right, bottom-right, bottom-left order 22 | return np.array([tl, tr, br, bl], dtype="float32") 23 | 24 | 25 | # Load image 26 | image = 
cv2.imread('toy_examples/target_curved/0.png') 27 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 28 | 29 | # Denoise and threshold 30 | _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV) 31 | 32 | # Detect edges and find contours 33 | edges = cv2.Canny(thresh, 100, 200) 34 | contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 35 | 36 | text_contour = contours[0] 37 | 38 | # Approximate the contour to a quadrilateral 39 | epsilon = 0.1 * cv2.arcLength(text_contour, True) 40 | approx = cv2.approxPolyDP(text_contour, epsilon, True) 41 | 42 | # Draw the contour (for visualization) 43 | cv2.drawContours(image, [approx], -1, (0, 255, 0), 3) 44 | 45 | # Save or display the image 46 | cv2.imwrite('image_with_quadrilateral.png', image) 47 | 48 | # # Calculate the bounding box for all contours and draw it 49 | # x_min = min([cv2.boundingRect(contour)[0] for contour in contours]) 50 | # y_min = min([cv2.boundingRect(contour)[1] for contour in contours]) 51 | # x_max = max([cv2.boundingRect(contour)[0] + cv2.boundingRect(contour)[2] for contour in contours]) 52 | # y_max = max([cv2.boundingRect(contour)[1] + cv2.boundingRect(contour)[3] for contour in contours]) 53 | # 54 | # # Draw bounding rectangle around the whole text 55 | # cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2) 56 | # 57 | # # Save or display the image 58 | # cv2.imwrite('image_with_bounding_box.png', image) 59 | # 60 | # # Perspective correction (assuming approx is a quadrilateral) 61 | # src_pts = np.array([ 62 | # [x_min, y_min], 63 | # [x_max, y_min], 64 | # [x_max, y_max], 65 | # [x_min, y_max] 66 | # ], dtype='float32') 67 | # dst_pts = np.array([[0, 0], [512, 0], [512, 512], [0, 512]], dtype='float32') # Adjust size as needed 68 | # print('Source points:', src_pts, '\nDestination points:', dst_pts) 69 | # matrix = cv2.getPerspectiveTransform(src_pts, dst_pts) 70 | # warped = cv2.warpPerspective(gray, matrix, (512, 512)) 71 | # cv2.imwrite('warped.png', warped) 72 | # 73 | # # Inverting colors and changing to white text on black background 74 | # _, final_binary = cv2.threshold(warped, 127, 255, cv2.THRESH_BINARY_INV) 75 | # 76 | # # Save the result to a file 77 | # cv2.imwrite('rectified_text.png', final_binary) 78 | -------------------------------------------------------------------------------- /synthetic_dataset/unet_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision import transforms 5 | from PIL import Image, ImageDraw 6 | import numpy as np 7 | import math 8 | import cv2 9 | 10 | from generate_text_transformation_pairs import gaussian 11 | from restore_from_transformations import find_min_max_coordinates 12 | 13 | 14 | class SyntheticDataset(Dataset): 15 | def __init__(self, images_dir, targets_curved_dir, target_corners_dir, target_midlines_dir, image_size, step): 16 | if step == 'extract': 17 | self.sources_dir = images_dir 18 | self.targets_dir = targets_curved_dir 19 | elif step == 'rectify': 20 | self.sources_dir = targets_curved_dir 21 | self.target_corners_dir = target_corners_dir 22 | self.target_midlines_dir = target_midlines_dir 23 | else: 24 | raise ValueError('Invalid step. 
Please choose between "extract" and "rectify"') 25 | 26 | self.images_dir = images_dir 27 | self.images = os.listdir(images_dir) 28 | self.step = step 29 | self.transform = transforms.Compose([transforms.ToTensor(), 30 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 31 | transforms.Resize((image_size, image_size), antialias=True)]) 32 | self.transform_grayscale = transforms.Compose([transforms.ToTensor(), 33 | transforms.Resize((image_size, image_size), antialias=True)]) 34 | 35 | def __len__(self): 36 | return len(self.images) 37 | 38 | def __getitem__(self, idx): 39 | img_name = self.images[idx] 40 | 41 | if self.step == 'extract': 42 | source_path = os.path.join(self.sources_dir, img_name) 43 | target_path = os.path.join(self.targets_dir, img_name) 44 | 45 | source = Image.open(source_path).convert('RGB') 46 | target = Image.open(target_path).convert('L') 47 | 48 | if self.transform is not None: 49 | source = self.transform(source) 50 | target = self.transform_grayscale(target) 51 | 52 | # Process the target image to extract the binary mask of the texts 53 | target = (target != 0).squeeze(0).float() 54 | return source, target 55 | else: 56 | source_path = os.path.join(self.sources_dir, img_name) 57 | target_corners_path = os.path.join(self.target_corners_dir, img_name) 58 | target_midlines_path = os.path.join(self.target_midlines_dir, img_name) 59 | 60 | source = Image.open(source_path).convert('RGB') 61 | target_corners = Image.open(target_corners_path).convert('L') 62 | target_midlines = Image.open(target_midlines_path) 63 | target_midline_endpoints = self.draw_end_points(source.size, target_midlines) 64 | target_midlines = self.convert_line_to_gaussian(target_midlines) 65 | target_midlines = target_midlines.convert('L') 66 | 67 | if self.transform is not None: 68 | source = self.transform(source) 69 | target_corners = self.transform_grayscale(target_corners) 70 | target_midlines = self.transform_grayscale(target_midlines) 71 | target_midline_endpoints = self.transform_grayscale(target_midline_endpoints) 72 | 73 | # Process the target image to extract the binary mask of the texts 74 | source = (source != 0).any(axis=0).float().unsqueeze(0).repeat(3, 1, 1) 75 | target_corners = target_corners.squeeze(0).float() 76 | target_midlines = target_midlines.squeeze(0).float() 77 | target_midline_endpoints = target_midline_endpoints.squeeze(0).float() 78 | 79 | return source, target_corners, target_midlines, target_midline_endpoints 80 | 81 | 82 | def draw_end_points(self, image_size, target_midlines): 83 | midline_start, midline_end = find_min_max_coordinates(target_midlines) 84 | 85 | # Create an image for the rectangle 86 | end_points_img = Image.new('L', image_size, 'black') 87 | rect_draw = ImageDraw.Draw(end_points_img) 88 | circle_diameter = 20 # You can adjust the size of the circles here 89 | sigma = circle_diameter / 3 # Standard deviation for Gaussian blur 90 | 91 | ends = [midline_start, midline_end] 92 | 93 | # Draw Gaussian dots at each end 94 | for end in ends: 95 | for dx in range(-circle_diameter, circle_diameter): 96 | for dy in range(-circle_diameter, circle_diameter): 97 | dist = math.sqrt(dx ** 2 + dy ** 2) 98 | if dist <= circle_diameter: 99 | intensity = int(255 * gaussian(dist, 0, sigma)) 100 | x = end[0] + dx 101 | y = end[1] + dy 102 | if 0 <= x < image_size[0] and 0 <= y < image_size[1]: # Check bounds 103 | rect_draw.point((x, y), fill=intensity) 104 | 105 | return np.array(end_points_img) 106 | 107 | 108 | def convert_line_to_gaussian(self, 
target_midline): 109 | target_midline = np.array(target_midline) 110 | blurred_midline = cv2.GaussianBlur(target_midline, (15, 15), 0) 111 | # blurred_midline = blurred_midline + target_midline 112 | # blurred_midline[blurred_midline > 255] = 255 113 | blurred_midline = Image.fromarray(blurred_midline) 114 | # blurred_midline = Image.fromarray(blurred_midline.astype(np.uint8)) 115 | return blurred_midline -------------------------------------------------------------------------------- /synthetic_dataset/unet_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from PIL import Image 5 | import yaml 6 | 7 | from unet_models import ModifiedUNet 8 | 9 | 10 | def load_checkpoint(model, checkpoint_path, device): 11 | if device.type == 'cuda': 12 | device_index = device.index if device.index is not None else 0 13 | map_location = lambda storage, loc: storage.cuda(device_index) 14 | else: 15 | map_location = 'cpu' 16 | 17 | # Load the checkpoint onto the target device 18 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 19 | 20 | # Adjust the keys: remove 'module.' prefix 21 | adjusted_checkpoint = {} 22 | for k, v in checkpoint.items(): 23 | name = k[7:] if k.startswith('module.') else k # remove `module.` when not using DDP in model inference 24 | adjusted_checkpoint[name] = v 25 | 26 | # Load the weights into the model 27 | model.load_state_dict(adjusted_checkpoint) 28 | model.to(device) 29 | return model 30 | 31 | 32 | def inference(args, model_path, image_path, output_path, device): 33 | # Load the model 34 | unet_model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=args['model']['out_channels'], init_features=32, pretrained=False) 35 | unet_model = ModifiedUNet(unet_model, base_model_final_channels=args['model']['out_channels'], step=args['training']['step']) 36 | # print(unet_model) 37 | 38 | unet_model = load_checkpoint(unet_model, model_path, device) 39 | unet_model.eval() 40 | 41 | # Load and transform the image 42 | transform = transforms.Compose([ 43 | transforms.Resize((args['dataset']['image_size'], args['dataset']['image_size'])), 44 | transforms.ToTensor() 45 | ]) 46 | image = Image.open(image_path).convert('RGB') 47 | image = transform(image).unsqueeze(0).to(device) 48 | 49 | # Perform inference 50 | with torch.no_grad(): 51 | outputs = unet_model(image) # ModifiedUNet returns a single prediction tensor 52 | binary_mask = torch.sigmoid(outputs).squeeze(0).cpu() 53 | source = image.squeeze(0).cpu() 54 | 55 | # Process output 56 | binary_mask = (binary_mask > 0.5).float() # Thresholding the binary mask 57 | colored_text = source * binary_mask.expand_as(source) # Keep the input colors inside the predicted mask (the separate color-map head is commented out in ModifiedUNet) 58 | 59 | # Convert to PIL Image and save 60 | binary_mask = transforms.ToPILImage()(binary_mask) 61 | binary_mask.save(output_path + '_binary_mask.png') 62 | colored_text = transforms.ToPILImage()(colored_text) 63 | colored_text.save(output_path + '_colored_text.png') 64 | 65 | 66 | if __name__ == "__main__": 67 | # This script performs single-image inference 68 | print('Torch', torch.__version__, 'Torchvision', torchvision.__version__) 69 | # load hyperparameters 70 | try: 71 | with open('unet_train_config.yaml', 'r') as file: 72 | args = yaml.safe_load(file) 73 | except Exception as e: 74 | print('Error reading the config file', e) 75 | 76 | model_path = 'unet_ckpts/unet_model_' + str(args['training']['test_epoch']) + '.pth' 77 | image_path = 
'toy_examples/noisy_wrapped_text_2.png' 78 | output_path = 'toy_examples/recovered_text_2' 79 | 80 | torch.manual_seed(0) 81 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 82 | print('device', device) 83 | print('torch.distributed.is_available', torch.distributed.is_available()) 84 | 85 | print(args) 86 | inference(args, model_path, image_path, output_path, device) 87 | -------------------------------------------------------------------------------- /synthetic_dataset/unet_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.ops as ops 4 | 5 | 6 | class DeformConvSequential(nn.Module): 7 | def __init__(self, offset_conv, deform_conv): 8 | super(DeformConvSequential, self).__init__() 9 | self.offset_conv = offset_conv 10 | self.deform_conv = deform_conv 11 | 12 | def forward(self, x): 13 | # Move the entire module to the device of x 14 | self.to(x.device) 15 | 16 | # the output from the offset_conv layer is passed as the offset parameter to the deform_conv layer. 17 | offset = self.offset_conv(x) 18 | return self.deform_conv(input=x, offset=offset) 19 | 20 | 21 | class ModifiedUNet(nn.Module): 22 | """The U-Net model is a fully convolutional neural network and the modified network replaces all the Conv2d layers with DeformConv2d layers. 23 | In addition, the final layer of the U-Net needs to be modified to produce three outputs: 24 | - Binary Mask: A single-channel output with a sigmoid activation for the binary mask. 25 | - Color Map: A three-channel output (assuming RGB) for the color map. 26 | - Feature Map for Text Perceptual Loss: Depending on the design, this could be a feature map from one of the intermediate layers. 27 | """ 28 | def __init__(self, base_model, base_model_final_channels, step, deformable=False): 29 | super(ModifiedUNet, self).__init__() 30 | self.step = step 31 | self.base_model = base_model 32 | self.base_model_final_channels = base_model_final_channels 33 | 34 | if step == 'extract': 35 | self.prediction_head = nn.Conv2d(in_channels=base_model_final_channels, out_channels=1, kernel_size=1) # sigmoid is in nn.BCEWithLogitsLoss 36 | elif step == 'rectify': 37 | self.prediction_head = nn.Conv2d(in_channels=base_model_final_channels, out_channels=3, kernel_size=1) 38 | else: 39 | raise ValueError('step must be either "extract" or "rectify"') 40 | # self.color_map_head = nn.Conv2d(in_channels=self.base_model_final_channels, out_channels=3, kernel_size=1) 41 | 42 | if deformable: 43 | # Replace Conv2d with DeformConv2d 44 | self.replace_conv2d_with_deformconv2d(self.base_model) 45 | 46 | def replace_conv2d_in_sequential(self, sequential_module): 47 | """Replace all the Conv2d layers with DeformConv2d layers in a Sequential module. 
48 | Example: (encoder1): Sequential( 49 | (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 50 | (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 51 | (enc1relu1): ReLU(inplace=True) 52 | (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 53 | (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 54 | (enc1relu2): ReLU(inplace=True) 55 | ) 56 | becomes 57 | (encoder1): Sequential( 58 | (enc1conv1): DeformConvSequential( 59 | (offset_conv): Conv2d(3, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 60 | (deform_conv): DeformConv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 61 | ) 62 | (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 63 | (enc1relu1): ReLU(inplace=True) 64 | (enc1conv2): DeformConvSequential( 65 | (offset_conv): Conv2d(32, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 66 | (deform_conv): DeformConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 67 | ) 68 | (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 69 | (enc1relu2): ReLU(inplace=True) 70 | ) 71 | """ 72 | new_sequential = nn.Sequential() 73 | for layer_name, layer in sequential_module.named_children(): 74 | if isinstance(layer, nn.Conv2d): 75 | # Define the new DeformConv2d layer 76 | deform_conv = ops.DeformConv2d(layer.in_channels, layer.out_channels, 77 | kernel_size=layer.kernel_size, stride=layer.stride, 78 | padding=layer.padding, dilation=layer.dilation) 79 | 80 | # Define the offset convolution layer 81 | offset_channels = 2 * layer.kernel_size[0] * layer.kernel_size[1] 82 | offset_conv = nn.Conv2d(layer.in_channels, offset_channels, kernel_size=layer.kernel_size, 83 | stride=layer.stride, padding=layer.padding) 84 | 85 | # Replace the Conv2d layer with a sequential container of offset_conv and deform_conv 86 | new_layer = DeformConvSequential(offset_conv, deform_conv) 87 | new_sequential.add_module(layer_name, new_layer) 88 | else: 89 | new_sequential.add_module(layer_name, layer) 90 | 91 | return new_sequential 92 | 93 | def replace_conv2d_with_deformconv2d(self, model): 94 | for name, module in model.named_children(): 95 | if isinstance(module, nn.Sequential): 96 | # Replace Conv2d layers in the Sequential module 97 | new_sequential = self.replace_conv2d_in_sequential(module) 98 | setattr(model, name, new_sequential) 99 | 100 | 101 | def forward(self, x): 102 | # Forward pass through the base model 103 | base_output = self.base_model(x) 104 | output = self.prediction_head(base_output) 105 | # color_map = self.color_map_head(base_output) 106 | 107 | # Extract the feature map for perceptual loss 108 | # perceptual_feature_map = base_output # or some intermediate layer 109 | 110 | return output #{'binary_mask': binary_mask, 'color_map': color_map} 111 | 112 | 113 | # # Example usage: 114 | # # Load the U-Net model from https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/ 115 | # # set the out_channels from 1 to 16 and attach it to our ModifiedUNet's output layers 116 | # unet_model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', 117 | # in_channels=3, out_channels=16, init_features=32, pretrained=False) 118 | # 119 | # unet_model = ModifiedUNet(unet_model, base_model_final_channels=16) # the out_channels of the U-Net model 120 | # print(unet_model) 121 | 
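The commented-out example at the end of unet_models.py omits the now-required step argument, so below is a minimal usage sketch (not part of the repository) written against the constructor signature shown above; the 512x512 input size and step='extract' are assumptions borrowed from unet_train_config.yaml.

# Hypothetical sketch: wrap the torch.hub U-Net the way unet_models.py expects.
# step='extract' and the 512x512 input are assumptions, not repository code.
import torch
from unet_models import ModifiedUNet

base = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                      in_channels=3, out_channels=16, init_features=32, pretrained=False)
model = ModifiedUNet(base, base_model_final_channels=16, step='extract', deformable=True)

x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    logits = model(x)      # single-channel mask logits from the 'extract' prediction head
print(logits.shape)        # torch.Size([1, 1, 512, 512])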
-------------------------------------------------------------------------------- /synthetic_dataset/unet_ocr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ( 3 | TrOCRConfig, 4 | TrOCRProcessor, 5 | TrOCRForCausalLM, 6 | ViTConfig, 7 | ViTModel, 8 | VisionEncoderDecoderModel, 9 | ) 10 | import requests 11 | from PIL import Image 12 | 13 | 14 | # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel 15 | # init vision2text model with random weights 16 | encoder = ViTModel(ViTConfig()) 17 | decoder = TrOCRForCausalLM(TrOCRConfig()) 18 | model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder) 19 | 20 | # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel` 21 | processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") 22 | model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten") 23 | 24 | # load image from the IAM dataset 25 | url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" 26 | image = Image.open(requests.get(url, stream=True).raw).convert("RGB") 27 | pixel_values = processor(image, return_tensors="pt").pixel_values 28 | 29 | # training 30 | model.config.decoder_start_token_id = processor.tokenizer.cls_token_id 31 | model.config.pad_token_id = processor.tokenizer.pad_token_id 32 | model.config.vocab_size = model.config.decoder.vocab_size 33 | 34 | outputs = model(pixel_values, output_hidden_states=True, return_dict=True) 35 | print('outputs', [key for key in outputs.__dict__]) 36 | 37 | hidden_states = outputs.decoder_hidden_states[-1] 38 | print('hidden_states', hidden_states.shape) 39 | 40 | # inference 41 | generated_ids = model.generate(pixel_values) 42 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] 43 | print('generated_text', generated_text) -------------------------------------------------------------------------------- /synthetic_dataset/unet_train_config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | images_dir: 'test_dataset/source' 3 | targets_dir: 'test_dataset/target' 4 | targets_curved_dir: 'test_dataset/target_curved' 5 | target_corners_dir: 'test_dataset/corners' 6 | target_midlines_dir: 'test_dataset/midline' 7 | image_size: 512 8 | train_val_split: 0.95 9 | percent_train: 1 #0.001 10 | percent_valid: 1 #0.01 11 | model: 12 | out_channels: 16 13 | bottleneck_out_dim: 524288 14 | training: 15 | run_mode: 'train' 16 | step: 'rectify' # 'extract', 'rectify' 17 | checkpoint_path: 'unet_ckpts' 18 | continue_train: False 19 | batch_size: 32 20 | learning_rate: 0.0001 21 | patience: 10 22 | start_epoch: 0 23 | test_epoch: 10 24 | ckpt_epoch_extract: 19 25 | ckpt_epoch_rectify: 29 26 | num_epochs: 20 27 | print_every: 100 28 | val_every: 500 29 | lambda_loss_color_map: 1 30 | bce_pos_weight: 1 31 | bce_weight: 0.3 32 | dice_weight: 0.7 33 | lambda_reg: 10 -------------------------------------------------------------------------------- /tool_add_anytext.py: -------------------------------------------------------------------------------- 1 | ''' 2 | AnyText: Multilingual Visual Text Generation And Editing 3 | Paper: https://arxiv.org/abs/2311.03054 4 | Code: https://github.com/tyxsspa/AnyText 5 | Copyright (c) Alibaba, Inc. and its affiliates. 
6 | ''' 7 | import sys 8 | import os 9 | import torch 10 | from cldm.model import create_model 11 | 12 | add_ocr = True # merge OCR model 13 | ocr_path = './ocr_weights/ppv3_rec.pth' 14 | 15 | 16 | if len(sys.argv) == 3: 17 | input_path = sys.argv[1] 18 | output_path = sys.argv[2] 19 | else: 20 | print('Args are wrong, using default input and output path!') 21 | input_path = './models/v1-5-pruned.ckpt' # sd1.5 22 | output_path = './models/anytext_sd15_scratch.ckpt' 23 | 24 | assert os.path.exists(input_path), 'Input model does not exist.' 25 | assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.' 26 | 27 | 28 | def get_node_name(name, parent_name): 29 | if len(name) <= len(parent_name): 30 | return False, '' 31 | p = name[:len(parent_name)] 32 | if p != parent_name: 33 | return False, '' 34 | return True, name[len(parent_name):] 35 | 36 | 37 | model = create_model(config_path='./models_yaml/anytext_sd15.yaml') 38 | 39 | pretrained_weights = torch.load(input_path) 40 | if 'state_dict' in pretrained_weights: 41 | pretrained_weights = pretrained_weights['state_dict'] 42 | 43 | scratch_dict = model.state_dict() 44 | 45 | target_dict = {} 46 | for k in scratch_dict.keys(): 47 | is_control, name = get_node_name(k, 'control_') 48 | if is_control: 49 | copy_k = 'model.diffusion_' + name 50 | else: 51 | copy_k = k 52 | if copy_k in pretrained_weights: 53 | target_dict[k] = pretrained_weights[copy_k].clone() 54 | else: 55 | target_dict[k] = scratch_dict[k].clone() 56 | print(f'These weights are newly added: {k}') 57 | 58 | if add_ocr: 59 | ocr_weights = torch.load(ocr_path) 60 | if 'state_dict' in ocr_weights: 61 | ocr_weights = ocr_weights['state_dict'] 62 | for key in ocr_weights: 63 | new_key = 'text_predictor.' + key 64 | target_dict[new_key] = ocr_weights[key] 65 | print('ocr weights are added!') 66 | 67 | model.load_state_dict(target_dict, strict=True) 68 | torch.save(model.state_dict(), output_path) 69 | print('Done.') 70 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | from t3_dataset import T3DataSet 6 | from cldm.logger import ImageLogger 7 | from cldm.model import create_model, load_state_dict 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | import shutil 10 | import torch.multiprocessing as mp 11 | from torch.utils.data.distributed import DistributedSampler 12 | 13 | NUM_NODES = 1 14 | # Configs 15 | batch_size = 2 # default 6 16 | grad_accum = 3 # enable perceptual loss may cost a lot of VRAM, you can set a smaller batch_size and make sure grad_accum * batch_size = 6 17 | ckpt_path = None # if not None, load ckpt_path and continue training task, will not load "resume_path" 18 | resume_path = './models/anytext_v1.1.ckpt' # './models/anytext_sd15_scratch.ckpt' # finetune from scratch 19 | model_config = './models_yaml/anytext_sd15.yaml' # use anytext_sd15_perloss.yaml to enable perceptual loss 20 | invalid_json_path = './Rethinking-Text-Segmentation/log/images/ocr_verified/invalid_gly_lines.json' 21 | logger_freq = 5000 22 | learning_rate = 2e-5 # default 2e-5 23 | mask_ratio = 1 # default 0.5, ratio of mask for inpainting(text editing task), set 0 to disable 24 | wm_thresh = 0.5 # set 0.5 to skip watermark imgs from training(ch:~25%, en:~8%, @Precision93.67%+Recall88.80%), 1.0 not skip 25 | root_dir = './models' # path for 
save checkpoints 26 | dataset_percent = 1.0 # 1.0 use full datasets, 0.0566 use ~200k images for ablation study 27 | save_steps = 5000 # step frequency of saving checkpoints 28 | save_epochs = None # epoch frequency of saving checkpoints 29 | max_epochs = 10 # default 60 30 | assert (save_steps is None) != (save_epochs is None) 31 | 32 | 33 | if __name__ == '__main__': 34 | # mp.set_start_method('spawn', force=True) 35 | log_img = os.path.join(root_dir, 'image_log/train') 36 | if os.path.exists(log_img): 37 | try: 38 | shutil.rmtree(log_img) 39 | except OSError: 40 | pass 41 | # model = create_model(model_config).cpu() 42 | model = create_model(model_config) 43 | if ckpt_path is None: 44 | model.load_state_dict(load_state_dict(resume_path, location='cpu')) 45 | model.learning_rate = learning_rate 46 | model.sd_locked = True 47 | model.only_mid_control = False 48 | model.unlockKV = False 49 | 50 | checkpoint_callback = ModelCheckpoint( 51 | every_n_train_steps=save_steps, 52 | every_n_epochs=save_epochs, 53 | save_top_k=-1, 54 | save_last=True, 55 | monitor="global_step", 56 | mode="max", 57 | ) 58 | 59 | total_params = sum(p.numel() for p in model.parameters()) 60 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 61 | 62 | print(f"Total parameters: {total_params:,}") 63 | print(f"Trainable parameters: {trainable_params:,}") 64 | 65 | json_paths = [ 66 | r'/tmp/datasets/AnyWord-3M/link_download/laion/data_v1.1.json', 67 | r'/tmp/datasets/AnyWord-3M/link_download/wukong_1of5/data_v1.1.json', 68 | r'/tmp/datasets/AnyWord-3M/link_download/wukong_2of5/data_v1.1.json', 69 | r'/tmp/datasets/AnyWord-3M/link_download/wukong_3of5/data_v1.1.json', 70 | r'/tmp/datasets/AnyWord-3M/link_download/wukong_4of5/data_v1.1.json', 71 | r'/tmp/datasets/AnyWord-3M/link_download/wukong_5of5/data_v1.1.json', 72 | r'/tmp/datasets/AnyWord-3M/link_download/ocr_data/Art/data.json', 73 | r'/tmp/datasets/AnyWord-3M/link_download/ocr_data/COCO_Text/data.json', 74 | r'/tmp/datasets/AnyWord-3M/link_download/ocr_data/icdar2017rctw/data.json', 75 | r'/tmp/datasets/AnyWord-3M/link_download/ocr_data/LSVT/data.json', 76 | r'/tmp/datasets/AnyWord-3M/link_download/ocr_data/mlt2019/data.json', 77 | r'/tmp/datasets/AnyWord-3M/link_download/ocr_data/MTWI2018/data.json', 78 | r'/tmp/datasets/AnyWord-3M/link_download/ocr_data/ReCTS/data.json' 79 | ] 80 | glyph_paths = [ 81 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/laion', 82 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/wukong_1of5', 83 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/wukong_2of5', 84 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/wukong_3of5', 85 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/wukong_4of5', 86 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/wukong_5of5', 87 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/Art', 88 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/COCO_Text', 89 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/icdar2017rctw', 90 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/LSVT', 91 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/mlt2019', 92 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/MTWI2018', 93 | r'./Rethinking-Text-Segmentation/log/images/ocr_verified/ReCTS' 94 | ] 95 | 96 | dataset = T3DataSet(json_paths, glyph_paths, max_lines=5, max_chars=20, caption_pos_prob=0.0, mask_pos_prob=1.0, mask_img_prob=mask_ratio, glyph_scale=2, percent=dataset_percent, 
debug=False, using_dlc=False, wm_thresh=wm_thresh, invalid_json_path=invalid_json_path) 97 | # sampler = DistributedSampler(dataset) 98 | # dataloader = DataLoader(dataset, num_workers=8, batch_size=batch_size, sampler=sampler) 99 | dataloader = DataLoader(dataset, num_workers=8, persistent_workers=True, batch_size=batch_size, shuffle=True) 100 | logger = ImageLogger(batch_frequency=logger_freq) 101 | trainer = pl.Trainer(gpus=-1, precision=32, max_epochs=max_epochs, num_nodes=NUM_NODES, accumulate_grad_batches=grad_accum, callbacks=[logger, checkpoint_callback], default_root_dir=root_dir, strategy='ddp') 102 | print('Start training...') 103 | trainer.fit(model, dataloader, ckpt_path=ckpt_path) 104 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import cv2 4 | 5 | 6 | def save_images(img_list, folder): 7 | if not os.path.exists(folder): 8 | os.makedirs(folder) 9 | now = datetime.datetime.now() 10 | date_str = now.strftime("%Y-%m-%d") 11 | folder_path = os.path.join(folder, date_str) 12 | if not os.path.exists(folder_path): 13 | os.makedirs(folder_path) 14 | time_str = now.strftime("%H_%M_%S") 15 | for idx, img in enumerate(img_list): 16 | image_number = idx + 1 17 | filename = f"{time_str}_{image_number}.jpg" 18 | save_path = os.path.join(folder_path, filename) 19 | cv2.imwrite(save_path, img[..., ::-1]) 20 | 21 | 22 | def check_channels(image): 23 | channels = image.shape[2] if len(image.shape) == 3 else 1 24 | if channels == 1: 25 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 26 | elif channels > 3: 27 | image = image[:, :, :3] 28 | return image 29 | 30 | 31 | def resize_image(img, max_length=768): 32 | height, width = img.shape[:2] 33 | max_dimension = max(height, width) 34 | 35 | if max_dimension > max_length: 36 | scale_factor = max_length / max_dimension 37 | new_width = int(round(width * scale_factor)) 38 | new_height = int(round(height * scale_factor)) 39 | new_size = (new_width, new_height) 40 | img = cv2.resize(img, new_size) 41 | height, width = img.shape[:2] 42 | img = cv2.resize(img, (width-(width % 64), height-(height % 64))) 43 | return img 44 | 45 | 46 | def clamp(value, min_val, max_val): 47 | return max(min_val, min(value, max_val)) 48 | 49 | 50 | def update_font_filename(): 51 | directory = "./fonts" 52 | for filename in os.listdir(directory): 53 | # Check if there are spaces in the filename 54 | if " " in filename: 55 | # Construct the new filename by replacing spaces with underscores 56 | new_filename = filename.replace(" ", "_") 57 | 58 | # Get the full path of the old and new filenames 59 | old_file = os.path.join(directory, filename) 60 | new_file = os.path.join(directory, new_filename) 61 | 62 | # Rename the file 63 | os.rename(old_file, new_file) 64 | print(f"Renamed: '{filename}' to '{new_filename}'") 65 | --------------------------------------------------------------------------------
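A short usage sketch (not part of the repository) for the helpers in util.py above; 'example.jpg' and './output' are placeholder paths.

# Hypothetical sketch: typical image preparation with util.py.
import cv2
from util import check_channels, resize_image, save_images

img = cv2.imread('example.jpg')            # placeholder path, BGR image
img = check_channels(img)                  # grayscale -> 3 channels, extra channels dropped
img = resize_image(img, max_length=768)    # longest side <= 768, both sides trimmed to multiples of 64
print(img.shape)                           # height and width are now divisible by 64
save_images([img[..., ::-1]], './output')  # save_images flips channels back to BGR on write, so pass RGB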