├── .gitignore ├── LICENSE ├── README.md ├── assets ├── dualencoder.png ├── edit.png ├── font-extro.png ├── font-intro.png ├── naturallang.png └── stylecond.png ├── checkpoints └── chartokenizer │ ├── char_vocab.json │ └── tokenizer_config.json ├── configs └── config_charinpaint.yaml ├── diffusers ├── __init__.py ├── commands │ ├── __init__.py │ ├── diffusers_cli.py │ └── env.py ├── configuration_utils.py ├── dependency_versions_check.py ├── dependency_versions_table.py ├── dynamic_modules_utils.py ├── hub_utils.py ├── modeling_flax_pytorch_utils.py ├── modeling_flax_utils.py ├── modeling_utils.py ├── models │ ├── README.md │ ├── __init__.py │ ├── attention.py │ ├── attention_flax.py │ ├── embeddings.py │ ├── embeddings_flax.py │ ├── resnet.py │ ├── resnet_flax.py │ ├── unet_1d.py │ ├── unet_1d_blocks.py │ ├── unet_2d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_blocks_flax.py │ ├── unet_2d_condition.py │ ├── unet_2d_condition_flax.py │ ├── vae.py │ └── vae_flax.py ├── onnx_utils.py ├── optimization.py ├── pipeline_flax_utils.py ├── pipeline_utils.py ├── pipelines │ ├── README.md │ ├── __init__.py │ ├── dance_diffusion │ │ ├── __init__.py │ │ └── pipeline_dance_diffusion.py │ ├── ddim │ │ ├── __init__.py │ │ └── pipeline_ddim.py │ ├── ddpm │ │ ├── __init__.py │ │ └── pipeline_ddpm.py │ ├── latent_diffusion │ │ ├── __init__.py │ │ ├── pipeline_latent_diffusion.py │ │ └── pipeline_latent_diffusion_superresolution.py │ ├── latent_diffusion_uncond │ │ ├── __init__.py │ │ └── pipeline_latent_diffusion_uncond.py │ ├── pndm │ │ ├── __init__.py │ │ └── pipeline_pndm.py │ ├── repaint │ │ ├── __init__.py │ │ └── pipeline_repaint.py │ ├── score_sde_ve │ │ ├── __init__.py │ │ └── pipeline_score_sde_ve.py │ ├── stable_diffusion │ │ ├── README.md │ │ ├── __init__.py │ │ ├── pipeline_cycle_diffusion.py │ │ ├── pipeline_flax_stable_diffusion.py │ │ ├── pipeline_onnx_stable_diffusion.py │ │ ├── pipeline_onnx_stable_diffusion_img2img.py │ │ ├── pipeline_onnx_stable_diffusion_inpaint.py │ │ ├── pipeline_stable_diffusion.py │ │ ├── pipeline_stable_diffusion_img2img.py │ │ ├── pipeline_stable_diffusion_inpaint.py │ │ ├── pipeline_stable_diffusion_inpaint_legacy.py │ │ ├── safety_checker.py │ │ └── safety_checker_flax.py │ ├── stochastic_karras_ve │ │ ├── __init__.py │ │ └── pipeline_stochastic_karras_ve.py │ └── vq_diffusion │ │ ├── __init__.py │ │ └── pipeline_vq_diffusion.py ├── schedulers │ ├── README.md │ ├── __init__.py │ ├── scheduling_ddim.py │ ├── scheduling_ddim_flax.py │ ├── scheduling_ddpm.py │ ├── scheduling_ddpm_flax.py │ ├── scheduling_dpmsolver_multistep.py │ ├── scheduling_dpmsolver_multistep_flax.py │ ├── scheduling_euler_ancestral_discrete.py │ ├── scheduling_euler_discrete.py │ ├── scheduling_ipndm.py │ ├── scheduling_karras_ve.py │ ├── scheduling_karras_ve_flax.py │ ├── scheduling_lms_discrete.py │ ├── scheduling_lms_discrete_flax.py │ ├── scheduling_pndm.py │ ├── scheduling_pndm_flax.py │ ├── scheduling_repaint.py │ ├── scheduling_sde_ve.py │ ├── scheduling_sde_ve_flax.py │ ├── scheduling_sde_vp.py │ ├── scheduling_utils.py │ ├── scheduling_utils_flax.py │ └── scheduling_vq_diffusion.py ├── training_utils.py └── utils │ ├── __init__.py │ ├── deprecation_utils.py │ ├── dummy_flax_and_transformers_objects.py │ ├── dummy_flax_objects.py │ ├── dummy_pt_objects.py │ ├── dummy_torch_and_scipy_objects.py │ ├── dummy_torch_and_transformers_and_onnx_objects.py │ ├── dummy_torch_and_transformers_objects.py │ ├── import_utils.py │ ├── logging.py │ ├── model_card_template.md │ ├── outputs.py │ └── 
testing_utils.py ├── examples ├── mask0.png ├── mask1.png ├── mask2.png ├── mask3.png ├── mask4.png ├── sample0.png ├── sample1.png ├── sample2.png ├── sample3.png └── sample4.png ├── generate.py ├── requirements.txt ├── scripts ├── down_data.sh └── gen_synth.sh ├── src ├── __init__.py ├── abinet │ ├── __init__.py │ ├── abinet_base.py │ ├── configs │ │ ├── pretrain_language_model.yaml │ │ ├── pretrain_vision_model.yaml │ │ ├── pretrain_vision_model_sv.yaml │ │ ├── template.yaml │ │ ├── train_abinet.yaml │ │ ├── train_abinet_sv.yaml │ │ ├── train_abinet_wo_iter.yaml │ │ └── train_contrast_abinet.yaml │ ├── data │ │ ├── charset_36.txt │ │ ├── charset_62.txt │ │ └── charset_vn.txt │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── backbone.py │ │ ├── backbone_v2.py │ │ ├── losses.py │ │ ├── model.py │ │ ├── model_abinet.py │ │ ├── model_abinet_iter.py │ │ ├── model_alignment.py │ │ ├── model_language.py │ │ ├── model_vision.py │ │ ├── module_util.py │ │ ├── resnet.py │ │ └── transformer.py │ ├── transforms.py │ └── utils.py ├── dataset │ ├── __init__.py │ ├── sceneocr.py │ ├── synthocr.py │ ├── utils.py │ └── vocab.json ├── model │ ├── __init__.py │ ├── charencoder.py │ ├── unet_2d_blocks.py │ ├── unet_2d_multicondition.py │ └── utils.py └── trainers │ ├── __init__.py │ ├── callbacks.py │ ├── datawrapper.py │ ├── inpaint_trainer.py │ ├── utils.py │ └── vae_trainer.py ├── synthgenerator ├── README.md ├── __init__.py ├── generate_synth.py ├── multistyle_template.py ├── new_components.py ├── requirements.txt ├── resources │ ├── 100fonts │ │ └── chosen_fonts.json │ ├── charset │ │ ├── alphanum.txt │ │ ├── alphanum_lower.txt │ │ └── alphanum_special.txt │ ├── colormap │ │ └── iiit5k_gray.txt │ └── corpus │ │ └── wikicorpus.json ├── synthgen_config.yaml └── utils.py ├── tools └── create_mask │ ├── README.md │ ├── main.py │ └── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.ckpt 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Jiabao Ji, Guanhua Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
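The `checkpoints/chartokenizer/` directory listed in the tree above stores the model's character vocabulary as a flat JSON array (its contents are dumped further below), and `configs/config_charinpaint.yaml` reuses the space character at index 0 as both the pad and unk token with a fixed `model_max_length` of 20. A minimal, illustrative sketch of loading that list into an id map — not the repository's own tokenizer class, just a plain reading of the shipped files:

import json

# checkpoints/chartokenizer/char_vocab.json is a JSON array of single characters;
# a character's position in the array is used here as its token id (space sits at index 0).
with open("checkpoints/chartokenizer/char_vocab.json") as f:
    charset = json.load(f)

char2id = {ch: i for i, ch in enumerate(charset)}
PAD_ID = char2id[" "]  # the training config reuses the space token as pad/unk

def encode(text, max_length=20):
    """Map a string to fixed-length character ids, truncating/padding to max_length."""
    ids = [char2id.get(ch, PAD_ID) for ch in text[:max_length]]
    return ids + [PAD_ID] * (max_length - len(ids))

print(encode("Hello!"))  # ids for H, e, l, l, o, ! followed by space padding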
-------------------------------------------------------------------------------- /assets/dualencoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/assets/dualencoder.png -------------------------------------------------------------------------------- /assets/edit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/assets/edit.png -------------------------------------------------------------------------------- /assets/font-extro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/assets/font-extro.png -------------------------------------------------------------------------------- /assets/font-intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/assets/font-intro.png -------------------------------------------------------------------------------- /assets/naturallang.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/assets/naturallang.png -------------------------------------------------------------------------------- /assets/stylecond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/assets/stylecond.png -------------------------------------------------------------------------------- /checkpoints/chartokenizer/char_vocab.json: -------------------------------------------------------------------------------- 1 | [" ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "!", "\"", "#", "$", "%", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"] -------------------------------------------------------------------------------- /checkpoints/chartokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | } -------------------------------------------------------------------------------- /configs/config_charinpaint.yaml: -------------------------------------------------------------------------------- 1 | # train on combination of all datas 2 | data: 3 | target: "src.trainers.WrappedDataModule" 4 | batch_size: 4 5 | scene_data: ocr-dataset/ 6 | synth_data: ocr-dataset/synth/ 7 | train: 8 | size: 256 9 | max_num: 2000000 # diffly choose this number 10 | augconf: 11 | synth: 12 | center: 0.1 13 | pad: false 14 | scene: 15 | expand_mask: 16 | center_mask: 0.6 17 | additional_mask: 0.4 18 | crop: 19 | mask_image_ratio: 15 20 | rotate: 21 | cat_prob: [1, 0, 0] 22 | angle_list: [-15, -30, -45, -60, -90, 15, 30, 45, 60, 90] 23 | rotate_range: 90 24 | 25 | dataconfs: 26 | ArT: 27 | type: scene 28 | label_path: 
${data.scene_data}/ArT/train_labels.json 29 | image_dir: ${data.scene_data}/ArT/train_images/ 30 | 31 | COCO: 32 | type: scene 33 | label_path: ${data.scene_data}/COCO/cocotext.v2.json 34 | image_dir: ${data.scene_data}/COCO/train2014/ 35 | 36 | TextOCR: 37 | type: scene 38 | label_path: ${data.scene_data}/TextOCR/TextOCR_0.1_train.json 39 | image_dir: ${data.scene_data}/TextOCR/train_images/ 40 | 41 | Synthtiger: 42 | type: synth 43 | label_path: ${data.synth_data}/train_data.csv 44 | image_dir: ${data.synth_data}/ 45 | style_mode: same-same 46 | use_textbbox: false 47 | style_dropout: [0.5, 0.5] 48 | rand_mask_text: true 49 | 50 | validation: 51 | size: 256 52 | # max_num: 6400 # diffly choose this number 53 | augconf: 54 | synth: 55 | center: 1. 56 | pad: false 57 | scene: 58 | expand_mask: 59 | center_mask: 0. 60 | additional_mask: 0. 61 | crop: 62 | mask_image_ratio: 30 63 | rotate: 64 | cat_prob: [1, 0, 0] 65 | angle_list: [-15, -30, -45, -60, -90, 15, 30, 45, 60, 90] 66 | rotate_range: 90 67 | 68 | dataconfs: 69 | ArT: 70 | type: scene 71 | label_path: ${data.scene_data}/ArT/val_split.json 72 | image_dir: ${data.scene_data}/ArT/train_images/ 73 | 74 | COCO: 75 | type: scene 76 | label_path: ${data.scene_data}/COCO/cocotext.v2.val.json 77 | image_dir: ${data.scene_data}/COCO/train2014/ 78 | 79 | TextOCR: 80 | type: scene 81 | label_path: ${data.scene_data}/TextOCR/TextOCR_0.1_val.json 82 | image_dir: ${data.scene_data}/TextOCR/train_images/ 83 | 84 | model: 85 | source: raw 86 | target: "src.trainers.CharInpaintModelWrapper" 87 | pretrained_model_path: runwayml/stable-diffusion-inpainting 88 | loss_type: MaskMSELoss 89 | loss_alpha: 5 90 | base_learning_rate: 5.0e-5 91 | precision: 16 92 | weight_decay: 0.0 93 | adam_epsilon: 1.0e-8 94 | freeze_char_embedder: false 95 | optimize_vae: false 96 | vae: 97 | 98 | tokenizer: 99 | model_max_length: 20 100 | char_tokenizer: 101 | pretrained_path: checkpoints/chartokenizer 102 | pad_token: " " 103 | unk_token: " " 104 | model_max_length: 20 105 | char_embedder: 106 | vocab_size: 95 # by default 107 | embedding_dim: 32 108 | max_length: 20 109 | padding_idx: 0 110 | attention_head_dim: 2 111 | unet: 112 | attention_head_dim: { "text": 8, "char": 2 } 113 | cross_attention_dim: { "text": 768, "char": 32 } 114 | noise_scheduler: diffusers.DDIMScheduler 115 | 116 | lightning: 117 | logger: 118 | callbacks: 119 | checkpoint_callback: 120 | params: 121 | save_top_k: -1 122 | image_logger: 123 | target: "src.trainers.CharInpaintImageLogger" 124 | params: 125 | # train_batch_frequency: 2400 126 | # valid_batch_frequency: 500 127 | train_batch_frequency: 2 128 | valid_batch_frequency: 2 129 | disable_wandb: true 130 | generation_kwargs: 131 | num_inference_steps: 30 132 | num_sample_per_image: 3 133 | guidance_scale: 7.5 134 | seed: 42 135 | 136 | # NOTE: Download pretrained ABINet model from https://github.com/FangShancheng/ABINet.git and 137 | # put model checkpoints in checkpoints/abinet to use this callback 138 | # ocracc_logger: 139 | # target: "src.trainers.OCRAccLogger" 140 | # params: 141 | # train_eval_conf: 142 | # size: 256 143 | # augconf: ${data.validation.augconf} 144 | # max_num: 5 145 | # dataconfs: 146 | # TextOCR: 147 | # type: scene 148 | # label_path: ${data.scene_data}/TextOCR/TextOCR_0.1_train.json 149 | # image_dir: ${data.scene_data}/TextOCR/train_images/ 150 | # len_counter: 151 | # eachnum: 10 152 | 153 | # val_eval_conf: 154 | # size: 256 155 | # augconf: ${data.validation.augconf} 156 | # max_num: 5 157 | # dataconfs: 158 | 
# TextOCR: 159 | # type: scene 160 | # label_path: ${data.scene_data}/TextOCR/TextOCR_0.1_val.json 161 | # image_dir: ${data.scene_data}/TextOCR/train_images/ 162 | # max_num: 1000 163 | # base_log_dir: ${base_log_dir}/ocrlogs # will be set in code 164 | 165 | trainer: 166 | accelerator: gpu 167 | devices: [0, 1, 2, 3, 4, 5, 6, 7] 168 | strategy: ddp 169 | amp_backend: native 170 | log_every_n_steps: 16 # this is global step 171 | precision: 16 172 | max_epochs: 15 173 | check_val_every_n_epoch: 1 174 | accumulate_grad_batches: 8 175 | gradient_clip_val: 3. 176 | gradient_clip_algorithm: norm 177 | benchmark: true 178 | -------------------------------------------------------------------------------- /diffusers/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | is_flax_available, 3 | is_inflect_available, 4 | is_onnx_available, 5 | is_scipy_available, 6 | is_torch_available, 7 | is_transformers_available, 8 | is_unidecode_available, 9 | ) 10 | 11 | 12 | __version__ = "0.8.0.dev0" 13 | 14 | from .configuration_utils import ConfigMixin 15 | from .onnx_utils import OnnxRuntimeModel 16 | from .utils import logging 17 | 18 | 19 | if is_torch_available(): 20 | from .modeling_utils import ModelMixin 21 | from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel 22 | from .optimization import ( 23 | get_constant_schedule, 24 | get_constant_schedule_with_warmup, 25 | get_cosine_schedule_with_warmup, 26 | get_cosine_with_hard_restarts_schedule_with_warmup, 27 | get_linear_schedule_with_warmup, 28 | get_polynomial_decay_schedule_with_warmup, 29 | get_scheduler, 30 | ) 31 | from .pipeline_utils import DiffusionPipeline 32 | from .pipelines import ( 33 | DanceDiffusionPipeline, 34 | DDIMPipeline, 35 | DDPMPipeline, 36 | KarrasVePipeline, 37 | LDMPipeline, 38 | LDMSuperResolutionPipeline, 39 | PNDMPipeline, 40 | RePaintPipeline, 41 | ScoreSdeVePipeline, 42 | ) 43 | from .schedulers import ( 44 | DDIMScheduler, 45 | DDPMScheduler, 46 | DPMSolverMultistepScheduler, 47 | EulerAncestralDiscreteScheduler, 48 | EulerDiscreteScheduler, 49 | IPNDMScheduler, 50 | KarrasVeScheduler, 51 | PNDMScheduler, 52 | RePaintScheduler, 53 | SchedulerMixin, 54 | ScoreSdeVeScheduler, 55 | VQDiffusionScheduler, 56 | ) 57 | from .training_utils import EMAModel 58 | else: 59 | from .utils.dummy_pt_objects import * # noqa F403 60 | 61 | if is_torch_available() and is_scipy_available(): 62 | from .schedulers import LMSDiscreteScheduler 63 | else: 64 | from .utils.dummy_torch_and_scipy_objects import * # noqa F403 65 | 66 | if is_torch_available() and is_transformers_available(): 67 | from .pipelines import ( 68 | CycleDiffusionPipeline, 69 | LDMTextToImagePipeline, 70 | StableDiffusionImg2ImgPipeline, 71 | StableDiffusionInpaintPipeline, 72 | StableDiffusionInpaintPipelineLegacy, 73 | StableDiffusionPipeline, 74 | VQDiffusionPipeline, 75 | ) 76 | else: 77 | from .utils.dummy_torch_and_transformers_objects import * # noqa F403 78 | 79 | if is_torch_available() and is_transformers_available() and is_onnx_available(): 80 | from .pipelines import ( 81 | OnnxStableDiffusionImg2ImgPipeline, 82 | OnnxStableDiffusionInpaintPipeline, 83 | OnnxStableDiffusionPipeline, 84 | StableDiffusionOnnxPipeline, 85 | ) 86 | else: 87 | from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 88 | 89 | if is_flax_available(): 90 | from .modeling_flax_utils import FlaxModelMixin 91 | from 
.models.unet_2d_condition_flax import FlaxUNet2DConditionModel 92 | from .models.vae_flax import FlaxAutoencoderKL 93 | from .pipeline_flax_utils import FlaxDiffusionPipeline 94 | from .schedulers import ( 95 | FlaxDDIMScheduler, 96 | FlaxDDPMScheduler, 97 | FlaxDPMSolverMultistepScheduler, 98 | FlaxKarrasVeScheduler, 99 | FlaxLMSDiscreteScheduler, 100 | FlaxPNDMScheduler, 101 | FlaxSchedulerMixin, 102 | FlaxScoreSdeVeScheduler, 103 | ) 104 | else: 105 | from .utils.dummy_flax_objects import * # noqa F403 106 | 107 | if is_flax_available() and is_transformers_available(): 108 | from .pipelines import FlaxStableDiffusionPipeline 109 | else: 110 | from .utils.dummy_flax_and_transformers_objects import * # noqa F403 111 | -------------------------------------------------------------------------------- /diffusers/commands/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from argparse import ArgumentParser 17 | 18 | 19 | class BaseDiffusersCLICommand(ABC): 20 | @staticmethod 21 | @abstractmethod 22 | def register_subcommand(parser: ArgumentParser): 23 | raise NotImplementedError() 24 | 25 | @abstractmethod 26 | def run(self): 27 | raise NotImplementedError() 28 | -------------------------------------------------------------------------------- /diffusers/commands/diffusers_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2022 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from argparse import ArgumentParser 17 | 18 | from .env import EnvironmentCommand 19 | 20 | 21 | def main(): 22 | parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") 23 | commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") 24 | 25 | # Register commands 26 | EnvironmentCommand.register_subcommand(commands_parser) 27 | 28 | # Let's go 29 | args = parser.parse_args() 30 | 31 | if not hasattr(args, "func"): 32 | parser.print_help() 33 | exit(1) 34 | 35 | # Run 36 | service = args.func(args) 37 | service.run() 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /diffusers/commands/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import platform 16 | from argparse import ArgumentParser 17 | 18 | import huggingface_hub 19 | 20 | from .. import __version__ as version 21 | from ..utils import is_torch_available, is_transformers_available 22 | from . import BaseDiffusersCLICommand 23 | 24 | 25 | def info_command_factory(_): 26 | return EnvironmentCommand() 27 | 28 | 29 | class EnvironmentCommand(BaseDiffusersCLICommand): 30 | @staticmethod 31 | def register_subcommand(parser: ArgumentParser): 32 | download_parser = parser.add_parser("env") 33 | download_parser.set_defaults(func=info_command_factory) 34 | 35 | def run(self): 36 | hub_version = huggingface_hub.__version__ 37 | 38 | pt_version = "not installed" 39 | pt_cuda_available = "NA" 40 | if is_torch_available(): 41 | import torch 42 | 43 | pt_version = torch.__version__ 44 | pt_cuda_available = torch.cuda.is_available() 45 | 46 | transformers_version = "not installed" 47 | if is_transformers_available: 48 | import transformers 49 | 50 | transformers_version = transformers.__version__ 51 | 52 | info = { 53 | "`diffusers` version": version, 54 | "Platform": platform.platform(), 55 | "Python version": platform.python_version(), 56 | "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", 57 | "Huggingface_hub version": hub_version, 58 | "Transformers version": transformers_version, 59 | "Using GPU in script?": "", 60 | "Using distributed or parallel set-up in script?": "", 61 | } 62 | 63 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") 64 | print(self.format_dict(info)) 65 | 66 | return info 67 | 68 | @staticmethod 69 | def format_dict(d): 70 | return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" 71 | -------------------------------------------------------------------------------- /diffusers/dependency_versions_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import sys 15 | 16 | from .dependency_versions_table import deps 17 | from .utils.versions import require_version, require_version_core 18 | 19 | 20 | # define which module versions we always want to check at run time 21 | # (usually the ones defined in `install_requires` in setup.py) 22 | # 23 | # order specific notes: 24 | # - tqdm must be checked before tokenizers 25 | 26 | pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() 27 | if sys.version_info < (3, 7): 28 | pkgs_to_check_at_runtime.append("dataclasses") 29 | if sys.version_info < (3, 8): 30 | pkgs_to_check_at_runtime.append("importlib_metadata") 31 | 32 | for pkg in pkgs_to_check_at_runtime: 33 | if pkg in deps: 34 | if pkg == "tokenizers": 35 | # must be loaded here, or else tqdm check may fail 36 | from .utils import is_tokenizers_available 37 | 38 | if not is_tokenizers_available(): 39 | continue # not required, check version only if installed 40 | 41 | require_version_core(deps[pkg]) 42 | else: 43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 44 | 45 | 46 | def dep_version_check(pkg, hint=None): 47 | require_version(deps[pkg], hint) 48 | -------------------------------------------------------------------------------- /diffusers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. run `make deps_table_update`` 4 | deps = { 5 | "Pillow": "Pillow<10.0", 6 | "accelerate": "accelerate>=0.11.0", 7 | "black": "black==22.8", 8 | "datasets": "datasets", 9 | "filelock": "filelock", 10 | "flake8": "flake8>=3.8.3", 11 | "flax": "flax>=0.4.1", 12 | "hf-doc-builder": "hf-doc-builder>=0.3.0", 13 | "huggingface-hub": "huggingface-hub>=0.10.0", 14 | "importlib_metadata": "importlib_metadata", 15 | "isort": "isort>=5.5.4", 16 | "jax": "jax>=0.2.8,!=0.3.2", 17 | "jaxlib": "jaxlib>=0.1.65", 18 | "modelcards": "modelcards>=0.1.4", 19 | "numpy": "numpy", 20 | "parameterized": "parameterized", 21 | "pytest": "pytest", 22 | "pytest-timeout": "pytest-timeout", 23 | "pytest-xdist": "pytest-xdist", 24 | "scipy": "scipy", 25 | "regex": "regex!=2019.12.17", 26 | "requests": "requests", 27 | "tensorboard": "tensorboard", 28 | "torch": "torch>=1.4", 29 | "torchvision": "torchvision", 30 | "transformers": "transformers>=4.21.0", 31 | } 32 | -------------------------------------------------------------------------------- /diffusers/modeling_flax_pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch - Flax general utilities.""" 16 | import re 17 | 18 | import jax.numpy as jnp 19 | from flax.traverse_util import flatten_dict, unflatten_dict 20 | from jax.random import PRNGKey 21 | 22 | from .utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | def rename_key(key): 29 | regex = r"\w+[.]\d+" 30 | pats = re.findall(regex, key) 31 | for pat in pats: 32 | key = key.replace(pat, "_".join(pat.split("."))) 33 | return key 34 | 35 | 36 | ##################### 37 | # PyTorch => Flax # 38 | ##################### 39 | 40 | # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 41 | # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py 42 | def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): 43 | """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" 44 | 45 | # conv norm or layer norm 46 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 47 | if ( 48 | any("norm" in str_ for str_ in pt_tuple_key) 49 | and (pt_tuple_key[-1] == "bias") 50 | and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) 51 | and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) 52 | ): 53 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 54 | return renamed_pt_tuple_key, pt_tensor 55 | elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: 56 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 57 | return renamed_pt_tuple_key, pt_tensor 58 | 59 | # embedding 60 | if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: 61 | pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) 62 | return renamed_pt_tuple_key, pt_tensor 63 | 64 | # conv layer 65 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 66 | if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: 67 | pt_tensor = pt_tensor.transpose(2, 3, 1, 0) 68 | return renamed_pt_tuple_key, pt_tensor 69 | 70 | # linear layer 71 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 72 | if pt_tuple_key[-1] == "weight": 73 | pt_tensor = pt_tensor.T 74 | return renamed_pt_tuple_key, pt_tensor 75 | 76 | # old PyTorch layer norm weight 77 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) 78 | if pt_tuple_key[-1] == "gamma": 79 | return renamed_pt_tuple_key, pt_tensor 80 | 81 | # old PyTorch layer norm bias 82 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) 83 | if pt_tuple_key[-1] == "beta": 84 | return renamed_pt_tuple_key, pt_tensor 85 | 86 | return pt_tuple_key, pt_tensor 87 | 88 | 89 | def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): 90 | # Step 1: Convert pytorch tensor to numpy 91 | pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} 92 | 93 | # Step 2: Since the model is stateless, get random Flax params 94 | random_flax_params = 
flax_model.init_weights(PRNGKey(init_key)) 95 | 96 | random_flax_state_dict = flatten_dict(random_flax_params) 97 | flax_state_dict = {} 98 | 99 | # Need to change some parameters name to match Flax names 100 | for pt_key, pt_tensor in pt_state_dict.items(): 101 | renamed_pt_key = rename_key(pt_key) 102 | pt_tuple_key = tuple(renamed_pt_key.split(".")) 103 | 104 | # Correctly rename weight parameters 105 | flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) 106 | 107 | if flax_key in random_flax_state_dict: 108 | if flax_tensor.shape != random_flax_state_dict[flax_key].shape: 109 | raise ValueError( 110 | f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " 111 | f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." 112 | ) 113 | 114 | # also add unexpected weight so that warning is thrown 115 | flax_state_dict[flax_key] = jnp.asarray(flax_tensor) 116 | 117 | return unflatten_dict(flax_state_dict) 118 | -------------------------------------------------------------------------------- /diffusers/models/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models). -------------------------------------------------------------------------------- /diffusers/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from ..utils import is_flax_available, is_torch_available 16 | 17 | 18 | if is_torch_available(): 19 | from .attention import Transformer2DModel 20 | from .unet_1d import UNet1DModel 21 | from .unet_2d import UNet2DModel 22 | from .unet_2d_condition import UNet2DConditionModel 23 | from .vae import AutoencoderKL, VQModel 24 | 25 | if is_flax_available(): 26 | from .unet_2d_condition_flax import FlaxUNet2DConditionModel 27 | from .vae_flax import FlaxAutoencoderKL 28 | -------------------------------------------------------------------------------- /diffusers/models/embeddings_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import math 15 | 16 | import flax.linen as nn 17 | import jax.numpy as jnp 18 | 19 | 20 | def get_sinusoidal_embeddings( 21 | timesteps: jnp.ndarray, 22 | embedding_dim: int, 23 | freq_shift: float = 1, 24 | min_timescale: float = 1, 25 | max_timescale: float = 1.0e4, 26 | flip_sin_to_cos: bool = False, 27 | scale: float = 1.0, 28 | ) -> jnp.ndarray: 29 | """Returns the positional encoding (same as Tensor2Tensor). 30 | Args: 31 | timesteps: a 1-D Tensor of N indices, one per batch element. 32 | These may be fractional. 33 | embedding_dim: The number of output channels. 34 | min_timescale: The smallest time unit (should probably be 0.0). 35 | max_timescale: The largest time unit. 36 | Returns: 37 | a Tensor of timing signals [N, num_channels] 38 | """ 39 | assert timesteps.ndim == 1, "Timesteps should be a 1d-array" 40 | assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" 41 | num_timescales = float(embedding_dim // 2) 42 | log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) 43 | inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) 44 | emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) 45 | 46 | # scale embeddings 47 | scaled_time = scale * emb 48 | 49 | if flip_sin_to_cos: 50 | signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) 51 | else: 52 | signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) 53 | signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) 54 | return signal 55 | 56 | 57 | class FlaxTimestepEmbedding(nn.Module): 58 | r""" 59 | Time step Embedding Module. Learns embeddings for input time steps. 60 | 61 | Args: 62 | time_embed_dim (`int`, *optional*, defaults to `32`): 63 | Time step embedding dimension 64 | dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): 65 | Parameters `dtype` 66 | """ 67 | time_embed_dim: int = 32 68 | dtype: jnp.dtype = jnp.float32 69 | 70 | @nn.compact 71 | def __call__(self, temb): 72 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) 73 | temb = nn.silu(temb) 74 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) 75 | return temb 76 | 77 | 78 | class FlaxTimesteps(nn.Module): 79 | r""" 80 | Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 81 | 82 | Args: 83 | dim (`int`, *optional*, defaults to `32`): 84 | Time step embedding dimension 85 | """ 86 | dim: int = 32 87 | freq_shift: float = 1 88 | 89 | @nn.compact 90 | def __call__(self, timesteps): 91 | return get_sinusoidal_embeddings( 92 | timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True 93 | ) 94 | -------------------------------------------------------------------------------- /diffusers/models/resnet_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import flax.linen as nn 15 | import jax 16 | import jax.numpy as jnp 17 | 18 | 19 | class FlaxUpsample2D(nn.Module): 20 | out_channels: int 21 | dtype: jnp.dtype = jnp.float32 22 | 23 | def setup(self): 24 | self.conv = nn.Conv( 25 | self.out_channels, 26 | kernel_size=(3, 3), 27 | strides=(1, 1), 28 | padding=((1, 1), (1, 1)), 29 | dtype=self.dtype, 30 | ) 31 | 32 | def __call__(self, hidden_states): 33 | batch, height, width, channels = hidden_states.shape 34 | hidden_states = jax.image.resize( 35 | hidden_states, 36 | shape=(batch, height * 2, width * 2, channels), 37 | method="nearest", 38 | ) 39 | hidden_states = self.conv(hidden_states) 40 | return hidden_states 41 | 42 | 43 | class FlaxDownsample2D(nn.Module): 44 | out_channels: int 45 | dtype: jnp.dtype = jnp.float32 46 | 47 | def setup(self): 48 | self.conv = nn.Conv( 49 | self.out_channels, 50 | kernel_size=(3, 3), 51 | strides=(2, 2), 52 | padding=((1, 1), (1, 1)), # padding="VALID", 53 | dtype=self.dtype, 54 | ) 55 | 56 | def __call__(self, hidden_states): 57 | # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim 58 | # hidden_states = jnp.pad(hidden_states, pad_width=pad) 59 | hidden_states = self.conv(hidden_states) 60 | return hidden_states 61 | 62 | 63 | class FlaxResnetBlock2D(nn.Module): 64 | in_channels: int 65 | out_channels: int = None 66 | dropout_prob: float = 0.0 67 | use_nin_shortcut: bool = None 68 | dtype: jnp.dtype = jnp.float32 69 | 70 | def setup(self): 71 | out_channels = self.in_channels if self.out_channels is None else self.out_channels 72 | 73 | self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 74 | self.conv1 = nn.Conv( 75 | out_channels, 76 | kernel_size=(3, 3), 77 | strides=(1, 1), 78 | padding=((1, 1), (1, 1)), 79 | dtype=self.dtype, 80 | ) 81 | 82 | self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) 83 | 84 | self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 85 | self.dropout = nn.Dropout(self.dropout_prob) 86 | self.conv2 = nn.Conv( 87 | out_channels, 88 | kernel_size=(3, 3), 89 | strides=(1, 1), 90 | padding=((1, 1), (1, 1)), 91 | dtype=self.dtype, 92 | ) 93 | 94 | use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut 95 | 96 | self.conv_shortcut = None 97 | if use_nin_shortcut: 98 | self.conv_shortcut = nn.Conv( 99 | out_channels, 100 | kernel_size=(1, 1), 101 | strides=(1, 1), 102 | padding="VALID", 103 | dtype=self.dtype, 104 | ) 105 | 106 | def __call__(self, hidden_states, temb, deterministic=True): 107 | residual = hidden_states 108 | hidden_states = self.norm1(hidden_states) 109 | hidden_states = nn.swish(hidden_states) 110 | hidden_states = self.conv1(hidden_states) 111 | 112 | temb = self.time_emb_proj(nn.swish(temb)) 113 | temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) 114 | hidden_states = hidden_states + temb 115 | 116 | hidden_states = self.norm2(hidden_states) 117 | hidden_states = nn.swish(hidden_states) 118 | hidden_states = self.dropout(hidden_states, deterministic) 119 | hidden_states = self.conv2(hidden_states) 120 | 121 | if 
self.conv_shortcut is not None: 122 | residual = self.conv_shortcut(residual) 123 | 124 | return hidden_states + residual 125 | -------------------------------------------------------------------------------- /diffusers/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available 2 | 3 | 4 | if is_torch_available(): 5 | from .dance_diffusion import DanceDiffusionPipeline 6 | from .ddim import DDIMPipeline 7 | from .ddpm import DDPMPipeline 8 | from .latent_diffusion import LDMSuperResolutionPipeline 9 | from .latent_diffusion_uncond import LDMPipeline 10 | from .pndm import PNDMPipeline 11 | from .repaint import RePaintPipeline 12 | from .score_sde_ve import ScoreSdeVePipeline 13 | from .stochastic_karras_ve import KarrasVePipeline 14 | else: 15 | from ..utils.dummy_pt_objects import * # noqa F403 16 | 17 | if is_torch_available() and is_transformers_available(): 18 | from .latent_diffusion import LDMTextToImagePipeline 19 | from .stable_diffusion import ( 20 | CycleDiffusionPipeline, 21 | StableDiffusionImg2ImgPipeline, 22 | StableDiffusionInpaintPipeline, 23 | StableDiffusionInpaintPipelineLegacy, 24 | StableDiffusionPipeline, 25 | ) 26 | from .vq_diffusion import VQDiffusionPipeline 27 | 28 | if is_transformers_available() and is_onnx_available(): 29 | from .stable_diffusion import ( 30 | OnnxStableDiffusionImg2ImgPipeline, 31 | OnnxStableDiffusionInpaintPipeline, 32 | OnnxStableDiffusionPipeline, 33 | StableDiffusionOnnxPipeline, 34 | ) 35 | 36 | if is_transformers_available() and is_flax_available(): 37 | from .stable_diffusion import FlaxStableDiffusionPipeline 38 | -------------------------------------------------------------------------------- /diffusers/pipelines/dance_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .pipeline_dance_diffusion import DanceDiffusionPipeline 3 | -------------------------------------------------------------------------------- /diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline 21 | from ...utils import logging 22 | 23 | 24 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 25 | 26 | 27 | class DanceDiffusionPipeline(DiffusionPipeline): 28 | r""" 29 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 30 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
31 | 32 | Parameters: 33 | unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image. 34 | scheduler ([`SchedulerMixin`]): 35 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 36 | [`IPNDMScheduler`]. 37 | """ 38 | 39 | def __init__(self, unet, scheduler): 40 | super().__init__() 41 | self.register_modules(unet=unet, scheduler=scheduler) 42 | 43 | @torch.no_grad() 44 | def __call__( 45 | self, 46 | batch_size: int = 1, 47 | num_inference_steps: int = 100, 48 | generator: Optional[torch.Generator] = None, 49 | audio_length_in_s: Optional[float] = None, 50 | return_dict: bool = True, 51 | ) -> Union[AudioPipelineOutput, Tuple]: 52 | r""" 53 | Args: 54 | batch_size (`int`, *optional*, defaults to 1): 55 | The number of audio samples to generate. 56 | num_inference_steps (`int`, *optional*, defaults to 50): 57 | The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at 58 | the expense of slower inference. 59 | generator (`torch.Generator`, *optional*): 60 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 61 | deterministic. 62 | audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): 63 | The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* 64 | `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`. 65 | return_dict (`bool`, *optional*, defaults to `True`): 66 | Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple. 67 | 68 | Returns: 69 | [`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if 70 | `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the 71 | generated images. 72 | """ 73 | 74 | if audio_length_in_s is None: 75 | audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate 76 | 77 | sample_size = audio_length_in_s * self.unet.sample_rate 78 | 79 | down_scale_factor = 2 ** len(self.unet.up_blocks) 80 | if sample_size < 3 * down_scale_factor: 81 | raise ValueError( 82 | f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" 83 | f" {3 * down_scale_factor / self.unet.sample_rate}." 84 | ) 85 | 86 | original_sample_size = int(sample_size) 87 | if sample_size % down_scale_factor != 0: 88 | sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor 89 | logger.info( 90 | f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled" 91 | f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising" 92 | " process." 93 | ) 94 | sample_size = int(sample_size) 95 | 96 | dtype = next(iter(self.unet.parameters())).dtype 97 | audio = torch.randn( 98 | (batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype 99 | ) 100 | 101 | # set step values 102 | self.scheduler.set_timesteps(num_inference_steps, device=audio.device) 103 | self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) 104 | 105 | for t in self.progress_bar(self.scheduler.timesteps): 106 | # 1. predict noise model_output 107 | model_output = self.unet(audio, t).sample 108 | 109 | # 2. 
compute previous image: x_t -> t_t-1 110 | audio = self.scheduler.step(model_output, t, audio).prev_sample 111 | 112 | audio = audio.clamp(-1, 1).float().cpu().numpy() 113 | 114 | audio = audio[:, :, :original_sample_size] 115 | 116 | if not return_dict: 117 | return (audio,) 118 | 119 | return AudioPipelineOutput(audios=audio) 120 | -------------------------------------------------------------------------------- /diffusers/pipelines/ddim/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .pipeline_ddim import DDIMPipeline 3 | -------------------------------------------------------------------------------- /diffusers/pipelines/ddim/pipeline_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Tuple, Union 16 | 17 | import torch 18 | 19 | from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput 20 | from ...utils import deprecate 21 | 22 | 23 | class DDIMPipeline(DiffusionPipeline): 24 | r""" 25 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 26 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 27 | 28 | Parameters: 29 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 30 | scheduler ([`SchedulerMixin`]): 31 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 32 | [`DDPMScheduler`], or [`DDIMScheduler`]. 33 | """ 34 | 35 | def __init__(self, unet, scheduler): 36 | super().__init__() 37 | self.register_modules(unet=unet, scheduler=scheduler) 38 | 39 | @torch.no_grad() 40 | def __call__( 41 | self, 42 | batch_size: int = 1, 43 | generator: Optional[torch.Generator] = None, 44 | eta: float = 0.0, 45 | num_inference_steps: int = 50, 46 | use_clipped_model_output: Optional[bool] = None, 47 | output_type: Optional[str] = "pil", 48 | return_dict: bool = True, 49 | **kwargs, 50 | ) -> Union[ImagePipelineOutput, Tuple]: 51 | r""" 52 | Args: 53 | batch_size (`int`, *optional*, defaults to 1): 54 | The number of images to generate. 55 | generator (`torch.Generator`, *optional*): 56 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 57 | deterministic. 58 | eta (`float`, *optional*, defaults to 0.0): 59 | The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). 60 | num_inference_steps (`int`, *optional*, defaults to 50): 61 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 62 | expense of slower inference. 63 | use_clipped_model_output (`bool`, *optional*, defaults to `None`): 64 | if `True` or `False`, see documentation for `DDIMScheduler.step`. 
If `None`, nothing is passed 65 | downstream to the scheduler. So use `None` for schedulers which don't support this argument. 66 | output_type (`str`, *optional*, defaults to `"pil"`): 67 | The output format of the generate image. Choose between 68 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 69 | return_dict (`bool`, *optional*, defaults to `True`): 70 | Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. 71 | 72 | Returns: 73 | [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if 74 | `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the 75 | generated images. 76 | """ 77 | 78 | if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": 79 | message = ( 80 | f"The `generator` device is `{generator.device}` and does not match the pipeline " 81 | f"device `{self.device}`, so the `generator` will be ignored. " 82 | f'Please use `generator=torch.Generator(device="{self.device}")` instead.' 83 | ) 84 | deprecate( 85 | "generator.device == 'cpu'", 86 | "0.11.0", 87 | message, 88 | ) 89 | generator = None 90 | 91 | # Sample gaussian noise to begin loop 92 | image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) 93 | if self.device.type == "mps": 94 | # randn does not work reproducibly on mps 95 | image = torch.randn(image_shape, generator=generator) 96 | image = image.to(self.device) 97 | else: 98 | image = torch.randn(image_shape, generator=generator, device=self.device) 99 | 100 | # set step values 101 | self.scheduler.set_timesteps(num_inference_steps) 102 | 103 | for t in self.progress_bar(self.scheduler.timesteps): 104 | # 1. predict noise model_output 105 | model_output = self.unet(image, t).sample 106 | 107 | # 2. predict previous mean of image x_t-1 and add variance depending on eta 108 | # eta corresponds to η in paper and should be between [0, 1] 109 | # do x_t -> x_t-1 110 | image = self.scheduler.step( 111 | model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator 112 | ).prev_sample 113 | 114 | image = (image / 2 + 0.5).clamp(0, 1) 115 | image = image.cpu().permute(0, 2, 3, 1).numpy() 116 | if output_type == "pil": 117 | image = self.numpy_to_pil(image) 118 | 119 | if not return_dict: 120 | return (image,) 121 | 122 | return ImagePipelineOutput(images=image) 123 | -------------------------------------------------------------------------------- /diffusers/pipelines/ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .pipeline_ddpm import DDPMPipeline 3 | -------------------------------------------------------------------------------- /diffusers/pipelines/ddpm/pipeline_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...configuration_utils import FrozenDict 21 | from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput 22 | from ...utils import deprecate 23 | 24 | 25 | class DDPMPipeline(DiffusionPipeline): 26 | r""" 27 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 28 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 29 | 30 | Parameters: 31 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 32 | scheduler ([`SchedulerMixin`]): 33 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 34 | [`DDPMScheduler`], or [`DDIMScheduler`]. 35 | """ 36 | 37 | def __init__(self, unet, scheduler): 38 | super().__init__() 39 | self.register_modules(unet=unet, scheduler=scheduler) 40 | 41 | @torch.no_grad() 42 | def __call__( 43 | self, 44 | batch_size: int = 1, 45 | generator: Optional[torch.Generator] = None, 46 | num_inference_steps: int = 1000, 47 | output_type: Optional[str] = "pil", 48 | return_dict: bool = True, 49 | **kwargs, 50 | ) -> Union[ImagePipelineOutput, Tuple]: 51 | r""" 52 | Args: 53 | batch_size (`int`, *optional*, defaults to 1): 54 | The number of images to generate. 55 | generator (`torch.Generator`, *optional*): 56 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 57 | deterministic. 58 | num_inference_steps (`int`, *optional*, defaults to 1000): 59 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 60 | expense of slower inference. 61 | output_type (`str`, *optional*, defaults to `"pil"`): 62 | The output format of the generate image. Choose between 63 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 64 | return_dict (`bool`, *optional*, defaults to `True`): 65 | Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. 66 | 67 | Returns: 68 | [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if 69 | `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the 70 | generated images. 71 | """ 72 | message = ( 73 | "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" 74 | " DDPMScheduler.from_config(, predict_epsilon=True)`." 75 | ) 76 | predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) 77 | 78 | if predict_epsilon is not None: 79 | new_config = dict(self.scheduler.config) 80 | new_config["predict_epsilon"] = predict_epsilon 81 | self.scheduler._internal_dict = FrozenDict(new_config) 82 | 83 | if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": 84 | message = ( 85 | f"The `generator` device is `{generator.device}` and does not match the pipeline " 86 | f"device `{self.device}`, so the `generator` will be ignored. " 87 | f'Please use `torch.Generator(device="{self.device}")` instead.' 
88 | ) 89 | deprecate( 90 | "generator.device == 'cpu'", 91 | "0.11.0", 92 | message, 93 | ) 94 | generator = None 95 | 96 | # Sample gaussian noise to begin loop 97 | image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) 98 | if self.device.type == "mps": 99 | # randn does not work reproducibly on mps 100 | image = torch.randn(image_shape, generator=generator) 101 | image = image.to(self.device) 102 | else: 103 | image = torch.randn(image_shape, generator=generator, device=self.device) 104 | 105 | # set step values 106 | self.scheduler.set_timesteps(num_inference_steps) 107 | 108 | for t in self.progress_bar(self.scheduler.timesteps): 109 | # 1. predict noise model_output 110 | model_output = self.unet(image, t).sample 111 | 112 | # 2. compute previous image: x_t -> x_t-1 113 | image = self.scheduler.step( 114 | model_output, t, image, generator=generator, predict_epsilon=predict_epsilon 115 | ).prev_sample 116 | 117 | image = (image / 2 + 0.5).clamp(0, 1) 118 | image = image.cpu().permute(0, 2, 3, 1).numpy() 119 | if output_type == "pil": 120 | image = self.numpy_to_pil(image) 121 | 122 | if not return_dict: 123 | return (image,) 124 | 125 | return ImagePipelineOutput(images=image) 126 | -------------------------------------------------------------------------------- /diffusers/pipelines/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from ...utils import is_transformers_available 3 | from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline 4 | 5 | 6 | if is_transformers_available(): 7 | from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline 8 | -------------------------------------------------------------------------------- /diffusers/pipelines/latent_diffusion_uncond/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .pipeline_latent_diffusion_uncond import LDMPipeline 3 | -------------------------------------------------------------------------------- /diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...models import UNet2DModel, VQModel 21 | from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput 22 | from ...schedulers import DDIMScheduler 23 | 24 | 25 | class LDMPipeline(DiffusionPipeline): 26 | r""" 27 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 28 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
29 | 30 | Parameters: 31 | vqvae ([`VQModel`]): 32 | Vector-quantized (VQ) Model to encode and decode images to and from latent representations. 33 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. 34 | scheduler ([`SchedulerMixin`]): 35 | [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents. 36 | """ 37 | 38 | def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): 39 | super().__init__() 40 | self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) 41 | 42 | @torch.no_grad() 43 | def __call__( 44 | self, 45 | batch_size: int = 1, 46 | generator: Optional[torch.Generator] = None, 47 | eta: float = 0.0, 48 | num_inference_steps: int = 50, 49 | output_type: Optional[str] = "pil", 50 | return_dict: bool = True, 51 | **kwargs, 52 | ) -> Union[Tuple, ImagePipelineOutput]: 53 | r""" 54 | Args: 55 | batch_size (`int`, *optional*, defaults to 1): 56 | Number of images to generate. 57 | generator (`torch.Generator`, *optional*): 58 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 59 | deterministic. 60 | num_inference_steps (`int`, *optional*, defaults to 50): 61 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 62 | expense of slower inference. 63 | output_type (`str`, *optional*, defaults to `"pil"`): 64 | The output format of the generate image. Choose between 65 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 66 | return_dict (`bool`, *optional*, defaults to `True`): 67 | Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. 68 | 69 | Returns: 70 | [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if 71 | `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the 72 | generated images. 
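                Example (illustrative sketch; assumes the vendored `diffusers` package is importable as a top-level
                module and that the `CompVis/ldm-celebahq-256` checkpoint is available locally or on the Hub):

                    >>> from diffusers import LDMPipeline
                    >>> pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
                    >>> image = pipe(num_inference_steps=50).images[0]  # ImagePipelineOutput -> first PIL image
                    >>> image.save("ldm_sample.png")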
73 | """ 74 | 75 | latents = torch.randn( 76 | (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), 77 | generator=generator, 78 | ) 79 | latents = latents.to(self.device) 80 | 81 | # scale the initial noise by the standard deviation required by the scheduler 82 | latents = latents * self.scheduler.init_noise_sigma 83 | 84 | self.scheduler.set_timesteps(num_inference_steps) 85 | 86 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 87 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 88 | 89 | extra_kwargs = {} 90 | if accepts_eta: 91 | extra_kwargs["eta"] = eta 92 | 93 | for t in self.progress_bar(self.scheduler.timesteps): 94 | latent_model_input = self.scheduler.scale_model_input(latents, t) 95 | # predict the noise residual 96 | noise_prediction = self.unet(latent_model_input, t).sample 97 | # compute the previous noisy sample x_t -> x_t-1 98 | latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample 99 | 100 | # decode the image latents with the VAE 101 | image = self.vqvae.decode(latents).sample 102 | 103 | image = (image / 2 + 0.5).clamp(0, 1) 104 | image = image.cpu().permute(0, 2, 3, 1).numpy() 105 | if output_type == "pil": 106 | image = self.numpy_to_pil(image) 107 | 108 | if not return_dict: 109 | return (image,) 110 | 111 | return ImagePipelineOutput(images=image) 112 | -------------------------------------------------------------------------------- /diffusers/pipelines/pndm/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .pipeline_pndm import PNDMPipeline 3 | -------------------------------------------------------------------------------- /diffusers/pipelines/pndm/pipeline_pndm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple, Union 17 | 18 | import torch 19 | 20 | from ...models import UNet2DModel 21 | from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput 22 | from ...schedulers import PNDMScheduler 23 | 24 | 25 | class PNDMPipeline(DiffusionPipeline): 26 | r""" 27 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 28 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 29 | 30 | Parameters: 31 | unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. 32 | scheduler ([`SchedulerMixin`]): 33 | The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. 
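    Example (illustrative sketch; `unet` is assumed to be an already instantiated, compatible `UNet2DModel` --
    no particular checkpoint is implied):

        >>> from diffusers import PNDMPipeline, PNDMScheduler
        >>> pipe = PNDMPipeline(unet=unet, scheduler=PNDMScheduler())
        >>> image = pipe(num_inference_steps=50).images[0]  # first PIL image of the generated batch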
34 | """ 35 | 36 | unet: UNet2DModel 37 | scheduler: PNDMScheduler 38 | 39 | def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): 40 | super().__init__() 41 | self.register_modules(unet=unet, scheduler=scheduler) 42 | 43 | @torch.no_grad() 44 | def __call__( 45 | self, 46 | batch_size: int = 1, 47 | num_inference_steps: int = 50, 48 | generator: Optional[torch.Generator] = None, 49 | output_type: Optional[str] = "pil", 50 | return_dict: bool = True, 51 | **kwargs, 52 | ) -> Union[ImagePipelineOutput, Tuple]: 53 | r""" 54 | Args: 55 | batch_size (`int`, `optional`, defaults to 1): The number of images to generate. 56 | num_inference_steps (`int`, `optional`, defaults to 50): 57 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 58 | expense of slower inference. 59 | generator (`torch.Generator`, `optional`): A [torch 60 | generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 61 | deterministic. 62 | output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose 63 | between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 64 | return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a 65 | [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. 66 | 67 | Returns: 68 | [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if 69 | `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the 70 | generated images. 71 | """ 72 | # For more information on the sampling method you can take a look at Algorithm 2 of 73 | # the official paper: https://arxiv.org/pdf/2202.09778.pdf 74 | 75 | # Sample gaussian noise to begin loop 76 | image = torch.randn( 77 | (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), 78 | generator=generator, 79 | ) 80 | image = image.to(self.device) 81 | 82 | self.scheduler.set_timesteps(num_inference_steps) 83 | for t in self.progress_bar(self.scheduler.timesteps): 84 | model_output = self.unet(image, t).sample 85 | 86 | image = self.scheduler.step(model_output, t, image).prev_sample 87 | 88 | image = (image / 2 + 0.5).clamp(0, 1) 89 | image = image.cpu().permute(0, 2, 3, 1).numpy() 90 | if output_type == "pil": 91 | image = self.numpy_to_pil(image) 92 | 93 | if not return_dict: 94 | return (image,) 95 | 96 | return ImagePipelineOutput(images=image) 97 | -------------------------------------------------------------------------------- /diffusers/pipelines/repaint/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_repaint import RePaintPipeline 2 | -------------------------------------------------------------------------------- /diffusers/pipelines/score_sde_ve/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .pipeline_score_sde_ve import ScoreSdeVePipeline 3 | -------------------------------------------------------------------------------- /diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Tuple, Union 16 | 17 | import torch 18 | 19 | from ...models import UNet2DModel 20 | from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput 21 | from ...schedulers import ScoreSdeVeScheduler 22 | 23 | 24 | class ScoreSdeVePipeline(DiffusionPipeline): 25 | r""" 26 | Parameters: 27 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 28 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 29 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): 30 | The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. 31 | """ 32 | unet: UNet2DModel 33 | scheduler: ScoreSdeVeScheduler 34 | 35 | def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): 36 | super().__init__() 37 | self.register_modules(unet=unet, scheduler=scheduler) 38 | 39 | @torch.no_grad() 40 | def __call__( 41 | self, 42 | batch_size: int = 1, 43 | num_inference_steps: int = 2000, 44 | generator: Optional[torch.Generator] = None, 45 | output_type: Optional[str] = "pil", 46 | return_dict: bool = True, 47 | **kwargs, 48 | ) -> Union[ImagePipelineOutput, Tuple]: 49 | r""" 50 | Args: 51 | batch_size (`int`, *optional*, defaults to 1): 52 | The number of images to generate. 53 | generator (`torch.Generator`, *optional*): 54 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 55 | deterministic. 56 | output_type (`str`, *optional*, defaults to `"pil"`): 57 | The output format of the generate image. Choose between 58 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 59 | return_dict (`bool`, *optional*, defaults to `True`): 60 | Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. 61 | 62 | Returns: 63 | [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if 64 | `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the 65 | generated images. 
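        Example (illustrative sketch; the checkpoint id `google/ncsnpp-celebahq-256` is assumed to be a
        score-SDE (VE) checkpoint compatible with this pipeline):

            >>> from diffusers import ScoreSdeVePipeline
            >>> pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-celebahq-256")
            >>> image = pipe(num_inference_steps=2000).images[0]  # sampling with the default number of steps is slow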
66 | """ 67 | 68 | img_size = self.unet.config.sample_size 69 | shape = (batch_size, 3, img_size, img_size) 70 | 71 | model = self.unet 72 | 73 | sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma 74 | sample = sample.to(self.device) 75 | 76 | self.scheduler.set_timesteps(num_inference_steps) 77 | self.scheduler.set_sigmas(num_inference_steps) 78 | 79 | for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): 80 | sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) 81 | 82 | # correction step 83 | for _ in range(self.scheduler.config.correct_steps): 84 | model_output = self.unet(sample, sigma_t).sample 85 | sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample 86 | 87 | # prediction step 88 | model_output = model(sample, sigma_t).sample 89 | output = self.scheduler.step_pred(model_output, t, sample, generator=generator) 90 | 91 | sample, sample_mean = output.prev_sample, output.prev_sample_mean 92 | 93 | sample = sample_mean.clamp(0, 1) 94 | sample = sample.cpu().permute(0, 2, 3, 1).numpy() 95 | if output_type == "pil": 96 | sample = self.numpy_to_pil(sample) 97 | 98 | if not return_dict: 99 | return (sample,) 100 | 101 | return ImagePipelineOutput(images=sample) 102 | -------------------------------------------------------------------------------- /diffusers/pipelines/stable_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | 6 | import PIL 7 | from PIL import Image 8 | 9 | from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available 10 | 11 | 12 | @dataclass 13 | class StableDiffusionPipelineOutput(BaseOutput): 14 | """ 15 | Output class for Stable Diffusion pipelines. 16 | 17 | Args: 18 | images (`List[PIL.Image.Image]` or `np.ndarray`) 19 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 20 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 21 | nsfw_content_detected (`List[bool]`) 22 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 23 | (nsfw) content, or `None` if safety checking could not be performed. 
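    Example (illustrative sketch; `pipe` is assumed to be an already loaded `StableDiffusionPipeline`):

        >>> output = pipe("a photo of an astronaut riding a horse")
        >>> flags = output.nsfw_content_detected or [False] * len(output.images)
        >>> kept_images = [img for img, nsfw in zip(output.images, flags) if not nsfw]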
24 | """ 25 | 26 | images: Union[List[PIL.Image.Image], np.ndarray] 27 | nsfw_content_detected: Optional[List[bool]] 28 | 29 | 30 | if is_transformers_available() and is_torch_available(): 31 | from .pipeline_cycle_diffusion import CycleDiffusionPipeline 32 | from .pipeline_stable_diffusion import StableDiffusionPipeline 33 | from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline 34 | from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline 35 | from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy 36 | from .safety_checker import StableDiffusionSafetyChecker 37 | 38 | if is_transformers_available() and is_onnx_available(): 39 | from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline 40 | from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline 41 | from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline 42 | 43 | if is_transformers_available() and is_flax_available(): 44 | import flax 45 | 46 | @flax.struct.dataclass 47 | class FlaxStableDiffusionPipelineOutput(BaseOutput): 48 | """ 49 | Output class for Stable Diffusion pipelines. 50 | 51 | Args: 52 | images (`List[PIL.Image.Image]` or `np.ndarray`) 53 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 54 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 55 | nsfw_content_detected (`List[bool]`) 56 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 57 | (nsfw) content. 58 | """ 59 | 60 | images: Union[List[PIL.Image.Image], np.ndarray] 61 | nsfw_content_detected: List[bool] 62 | 63 | from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState 64 | from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline 65 | from .safety_checker_flax import FlaxStableDiffusionSafetyChecker 66 | -------------------------------------------------------------------------------- /diffusers/pipelines/stable_diffusion/safety_checker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
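# Illustrative sketch (hypothetical variable names) of how a Stable Diffusion pipeline is expected to
# invoke this checker: `feature_extractor` is assumed to be a CLIP feature extractor, `pil_images` the
# decoded PIL images and `images` the matching numpy batch; the checker returns the (possibly
# blacked-out) images together with per-image nsfw flags.
#
#     safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
#     clip_input = feature_extractor(pil_images, return_tensors="pt").pixel_values
#     images, has_nsfw = safety_checker(images=images, clip_input=clip_input)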
14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | 19 | from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel 20 | 21 | from ...utils import logging 22 | 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | 27 | def cosine_distance(image_embeds, text_embeds): 28 | normalized_image_embeds = nn.functional.normalize(image_embeds) 29 | normalized_text_embeds = nn.functional.normalize(text_embeds) 30 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) 31 | 32 | 33 | class StableDiffusionSafetyChecker(PreTrainedModel): 34 | config_class = CLIPConfig 35 | 36 | _no_split_modules = ["CLIPEncoderLayer"] 37 | 38 | def __init__(self, config: CLIPConfig): 39 | super().__init__(config) 40 | 41 | self.vision_model = CLIPVisionModel(config.vision_config) 42 | self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) 43 | 44 | self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) 45 | self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) 46 | 47 | self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) 48 | self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) 49 | 50 | @torch.no_grad() 51 | def forward(self, clip_input, images): 52 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 53 | image_embeds = self.visual_projection(pooled_output) 54 | 55 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 56 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() 57 | cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() 58 | 59 | result = [] 60 | batch_size = image_embeds.shape[0] 61 | for i in range(batch_size): 62 | result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} 63 | 64 | # increase this value to create a stronger `nfsw` filter 65 | # at the cost of increasing the possibility of filtering benign images 66 | adjustment = 0.0 67 | 68 | for concept_idx in range(len(special_cos_dist[0])): 69 | concept_cos = special_cos_dist[i][concept_idx] 70 | concept_threshold = self.special_care_embeds_weights[concept_idx].item() 71 | result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 72 | if result_img["special_scores"][concept_idx] > 0: 73 | result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) 74 | adjustment = 0.01 75 | 76 | for concept_idx in range(len(cos_dist[0])): 77 | concept_cos = cos_dist[i][concept_idx] 78 | concept_threshold = self.concept_embeds_weights[concept_idx].item() 79 | result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) 80 | if result_img["concept_scores"][concept_idx] > 0: 81 | result_img["bad_concepts"].append(concept_idx) 82 | 83 | result.append(result_img) 84 | 85 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] 86 | 87 | for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): 88 | if has_nsfw_concept: 89 | images[idx] = np.zeros(images[idx].shape) # black image 90 | 91 | if any(has_nsfw_concepts): 92 | logger.warning( 93 | "Potential NSFW content was detected in one or more images. A black image will be returned instead." 94 | " Try again with a different prompt and/or seed." 
95 | ) 96 | 97 | return images, has_nsfw_concepts 98 | 99 | @torch.no_grad() 100 | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): 101 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 102 | image_embeds = self.visual_projection(pooled_output) 103 | 104 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) 105 | cos_dist = cosine_distance(image_embeds, self.concept_embeds) 106 | 107 | # increase this value to create a stronger `nsfw` filter 108 | # at the cost of increasing the possibility of filtering benign images 109 | adjustment = 0.0 110 | 111 | special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment 112 | # special_scores = special_scores.round(decimals=3) 113 | special_care = torch.any(special_scores > 0, dim=1) 114 | special_adjustment = special_care * 0.01 115 | special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) 116 | 117 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment 118 | # concept_scores = concept_scores.round(decimals=3) 119 | has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) 120 | 121 | images[has_nsfw_concepts] = 0.0 # black image 122 | 123 | return images, has_nsfw_concepts 124 | -------------------------------------------------------------------------------- /diffusers/pipelines/stable_diffusion/safety_checker_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
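# Illustrative sketch (assumed names: `checker` is an already instantiated
# FlaxStableDiffusionSafetyChecker, `clip_input` an NCHW pixel batch): unlike the PyTorch checker
# above, the Flax module only returns the per-image nsfw flags and leaves blacking out flagged images
# to the caller.
#
#     has_nsfw = checker(clip_input, params=checker.params)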
14 | 15 | from typing import Optional, Tuple 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | from flax import linen as nn 20 | from flax.core.frozen_dict import FrozenDict 21 | from transformers import CLIPConfig, FlaxPreTrainedModel 22 | from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule 23 | 24 | 25 | def jax_cosine_distance(emb_1, emb_2, eps=1e-12): 26 | norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T 27 | norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T 28 | return jnp.matmul(norm_emb_1, norm_emb_2.T) 29 | 30 | 31 | class FlaxStableDiffusionSafetyCheckerModule(nn.Module): 32 | config: CLIPConfig 33 | dtype: jnp.dtype = jnp.float32 34 | 35 | def setup(self): 36 | self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) 37 | self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) 38 | 39 | self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) 40 | self.special_care_embeds = self.param( 41 | "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) 42 | ) 43 | 44 | self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) 45 | self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) 46 | 47 | def __call__(self, clip_input): 48 | pooled_output = self.vision_model(clip_input)[1] 49 | image_embeds = self.visual_projection(pooled_output) 50 | 51 | special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) 52 | cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) 53 | 54 | # increase this value to create a stronger `nfsw` filter 55 | # at the cost of increasing the possibility of filtering benign image inputs 56 | adjustment = 0.0 57 | 58 | special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment 59 | special_scores = jnp.round(special_scores, 3) 60 | is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) 61 | # Use a lower threshold if an image has any special care concept 62 | special_adjustment = is_special_care * 0.01 63 | 64 | concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment 65 | concept_scores = jnp.round(concept_scores, 3) 66 | has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) 67 | 68 | return has_nsfw_concepts 69 | 70 | 71 | class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): 72 | config_class = CLIPConfig 73 | main_input_name = "clip_input" 74 | module_class = FlaxStableDiffusionSafetyCheckerModule 75 | 76 | def __init__( 77 | self, 78 | config: CLIPConfig, 79 | input_shape: Optional[Tuple] = None, 80 | seed: int = 0, 81 | dtype: jnp.dtype = jnp.float32, 82 | _do_init: bool = True, 83 | **kwargs, 84 | ): 85 | if input_shape is None: 86 | input_shape = (1, 224, 224, 3) 87 | module = self.module_class(config=config, dtype=dtype, **kwargs) 88 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 89 | 90 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 91 | # init input tensor 92 | clip_input = jax.random.normal(rng, input_shape) 93 | 94 | params_rng, dropout_rng = jax.random.split(rng) 95 | rngs = {"params": params_rng, "dropout": dropout_rng} 96 | 97 | random_params = self.module.init(rngs, clip_input)["params"] 98 
| 99 | return random_params 100 | 101 | def __call__( 102 | self, 103 | clip_input, 104 | params: dict = None, 105 | ): 106 | clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) 107 | 108 | return self.module.apply( 109 | {"params": params or self.params}, 110 | jnp.array(clip_input, dtype=jnp.float32), 111 | rngs={}, 112 | ) 113 | -------------------------------------------------------------------------------- /diffusers/pipelines/stochastic_karras_ve/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .pipeline_stochastic_karras_ve import KarrasVePipeline 3 | -------------------------------------------------------------------------------- /diffusers/pipelines/vq_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_vq_diffusion import VQDiffusionPipeline 2 | -------------------------------------------------------------------------------- /diffusers/schedulers/README.md: -------------------------------------------------------------------------------- 1 | # Schedulers 2 | 3 | For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers). -------------------------------------------------------------------------------- /diffusers/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
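# Illustrative sketch (the checkpoint id is an assumption): the schedulers exported below share the
# `SchedulerMixin`/`ConfigMixin` interface, so one can be swapped for another on a loaded pipeline via
# `from_config`, e.g.:
#
#     from diffusers import DDPMPipeline, DDIMScheduler
#     pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
#     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)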
14 | 15 | 16 | from ..utils import is_flax_available, is_scipy_available, is_torch_available 17 | 18 | 19 | if is_torch_available(): 20 | from .scheduling_ddim import DDIMScheduler 21 | from .scheduling_ddpm import DDPMScheduler 22 | from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler 23 | from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler 24 | from .scheduling_euler_discrete import EulerDiscreteScheduler 25 | from .scheduling_ipndm import IPNDMScheduler 26 | from .scheduling_karras_ve import KarrasVeScheduler 27 | from .scheduling_pndm import PNDMScheduler 28 | from .scheduling_repaint import RePaintScheduler 29 | from .scheduling_sde_ve import ScoreSdeVeScheduler 30 | from .scheduling_sde_vp import ScoreSdeVpScheduler 31 | from .scheduling_utils import SchedulerMixin 32 | from .scheduling_vq_diffusion import VQDiffusionScheduler 33 | else: 34 | from ..utils.dummy_pt_objects import * # noqa F403 35 | 36 | if is_flax_available(): 37 | from .scheduling_ddim_flax import FlaxDDIMScheduler 38 | from .scheduling_ddpm_flax import FlaxDDPMScheduler 39 | from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler 40 | from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler 41 | from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler 42 | from .scheduling_pndm_flax import FlaxPNDMScheduler 43 | from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler 44 | from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left 45 | else: 46 | from ..utils.dummy_flax_objects import * # noqa F403 47 | 48 | 49 | if is_scipy_available() and is_torch_available(): 50 | from .scheduling_lms_discrete import LMSDiscreteScheduler 51 | else: 52 | from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 53 | -------------------------------------------------------------------------------- /diffusers/schedulers/scheduling_sde_vp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch 16 | 17 | import math 18 | from typing import Union 19 | 20 | import torch 21 | 22 | from ..configuration_utils import ConfigMixin, register_to_config 23 | from .scheduling_utils import SchedulerMixin 24 | 25 | 26 | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): 27 | """ 28 | The variance preserving stochastic differential equation (SDE) scheduler. 29 | 30 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 31 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
32 | [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and 33 | [`~ConfigMixin.from_config`] functions. 34 | 35 | For more information, see the original paper: https://arxiv.org/abs/2011.13456 36 | 37 | UNDER CONSTRUCTION 38 | 39 | """ 40 | 41 | @register_to_config 42 | def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): 43 | self.sigmas = None 44 | self.discrete_sigmas = None 45 | self.timesteps = None 46 | 47 | def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): 48 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) 49 | 50 | def step_pred(self, score, x, t, generator=None): 51 | if self.timesteps is None: 52 | raise ValueError( 53 | "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" 54 | ) 55 | 56 | # TODO(Patrick) better comments + non-PyTorch 57 | # postprocess model score 58 | log_mean_coeff = ( 59 | -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min 60 | ) 61 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) 62 | std = std.flatten() 63 | while len(std.shape) < len(score.shape): 64 | std = std.unsqueeze(-1) 65 | score = -score / std 66 | 67 | # compute 68 | dt = -1.0 / len(self.timesteps) 69 | 70 | beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) 71 | beta_t = beta_t.flatten() 72 | while len(beta_t.shape) < len(x.shape): 73 | beta_t = beta_t.unsqueeze(-1) 74 | drift = -0.5 * beta_t * x 75 | 76 | diffusion = torch.sqrt(beta_t) 77 | drift = drift - diffusion**2 * score 78 | x_mean = x + drift * dt 79 | 80 | # add noise 81 | noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device) 82 | x = x_mean + diffusion * math.sqrt(-dt) * noise 83 | 84 | return x, x_mean 85 | 86 | def __len__(self): 87 | return self.config.num_train_timesteps 88 | -------------------------------------------------------------------------------- /diffusers/schedulers/scheduling_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | 16 | import torch 17 | 18 | from ..utils import BaseOutput 19 | 20 | 21 | SCHEDULER_CONFIG_NAME = "scheduler_config.json" 22 | 23 | 24 | @dataclass 25 | class SchedulerOutput(BaseOutput): 26 | """ 27 | Base class for the scheduler's step function output. 28 | 29 | Args: 30 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 31 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 32 | denoising loop. 
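    Example (illustrative; `scheduler`, `model_output`, `t` and `sample` are assumed to come from a
    surrounding denoising loop, as in the pipelines above):

        >>> sample = scheduler.step(model_output, t, sample).prev_sample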
33 | """ 34 | 35 | prev_sample: torch.FloatTensor 36 | 37 | 38 | class SchedulerMixin: 39 | """ 40 | Mixin containing common functions for the schedulers. 41 | """ 42 | 43 | config_name = SCHEDULER_CONFIG_NAME 44 | -------------------------------------------------------------------------------- /diffusers/schedulers/scheduling_utils_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Tuple 16 | 17 | import jax.numpy as jnp 18 | 19 | from ..utils import BaseOutput 20 | 21 | 22 | SCHEDULER_CONFIG_NAME = "scheduler_config.json" 23 | 24 | 25 | @dataclass 26 | class FlaxSchedulerOutput(BaseOutput): 27 | """ 28 | Base class for the scheduler's step function output. 29 | 30 | Args: 31 | prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): 32 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 33 | denoising loop. 34 | """ 35 | 36 | prev_sample: jnp.ndarray 37 | 38 | 39 | class FlaxSchedulerMixin: 40 | """ 41 | Mixin containing common functions for the schedulers. 42 | """ 43 | 44 | config_name = SCHEDULER_CONFIG_NAME 45 | 46 | 47 | def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: 48 | assert len(shape) >= x.ndim 49 | return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape) 50 | -------------------------------------------------------------------------------- /diffusers/training_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def enable_full_determinism(seed: int): 10 | """ 11 | Helper function for reproducible behavior during distributed training. See 12 | - https://pytorch.org/docs/stable/notes/randomness.html for pytorch 13 | """ 14 | # set seed first 15 | set_seed(seed) 16 | 17 | # Enable PyTorch deterministic mode. This potentially requires either the environment 18 | # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, 19 | # depending on the CUDA version, so we set them both here 20 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 21 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" 22 | torch.use_deterministic_algorithms(True) 23 | 24 | # Enable CUDNN deterministic mode 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | 29 | def set_seed(seed: int): 30 | """ 31 | Args: 32 | Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. 33 | seed (`int`): The seed to set. 
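    Example:

        >>> set_seed(0)  # seeds `random`, `numpy` and `torch` (CPU and all CUDA devices)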
34 | """ 35 | random.seed(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | # ^^ safe to call this function even if cuda is not available 40 | 41 | 42 | class EMAModel: 43 | """ 44 | Exponential Moving Average of models weights 45 | """ 46 | 47 | def __init__( 48 | self, 49 | model, 50 | update_after_step=0, 51 | inv_gamma=1.0, 52 | power=2 / 3, 53 | min_value=0.0, 54 | max_value=0.9999, 55 | device=None, 56 | ): 57 | """ 58 | @crowsonkb's notes on EMA Warmup: 59 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 60 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 61 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 62 | at 215.4k steps). 63 | Args: 64 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 65 | power (float): Exponential factor of EMA warmup. Default: 2/3. 66 | min_value (float): The minimum EMA decay rate. Default: 0. 67 | """ 68 | 69 | self.averaged_model = copy.deepcopy(model).eval() 70 | self.averaged_model.requires_grad_(False) 71 | 72 | self.update_after_step = update_after_step 73 | self.inv_gamma = inv_gamma 74 | self.power = power 75 | self.min_value = min_value 76 | self.max_value = max_value 77 | 78 | if device is not None: 79 | self.averaged_model = self.averaged_model.to(device=device) 80 | 81 | self.decay = 0.0 82 | self.optimization_step = 0 83 | 84 | def get_decay(self, optimization_step): 85 | """ 86 | Compute the decay factor for the exponential moving average. 87 | """ 88 | step = max(0, optimization_step - self.update_after_step - 1) 89 | value = 1 - (1 + step / self.inv_gamma) ** -self.power 90 | 91 | if step <= 0: 92 | return 0.0 93 | 94 | return max(self.min_value, min(value, self.max_value)) 95 | 96 | @torch.no_grad() 97 | def step(self, new_model): 98 | ema_state_dict = {} 99 | ema_params = self.averaged_model.state_dict() 100 | 101 | self.decay = self.get_decay(self.optimization_step) 102 | 103 | for key, param in new_model.named_parameters(): 104 | if isinstance(param, dict): 105 | continue 106 | try: 107 | ema_param = ema_params[key] 108 | except KeyError: 109 | ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) 110 | ema_params[key] = ema_param 111 | 112 | if not param.requires_grad: 113 | ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) 114 | ema_param = ema_params[key] 115 | else: 116 | ema_param.mul_(self.decay) 117 | ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) 118 | 119 | ema_state_dict[key] = ema_param 120 | 121 | for key, param in new_model.named_buffers(): 122 | ema_state_dict[key] = param 123 | 124 | self.averaged_model.load_state_dict(ema_state_dict, strict=False) 125 | self.optimization_step += 1 126 | -------------------------------------------------------------------------------- /diffusers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | 18 | from .deprecation_utils import deprecate 19 | from .import_utils import ( 20 | ENV_VARS_TRUE_AND_AUTO_VALUES, 21 | ENV_VARS_TRUE_VALUES, 22 | USE_JAX, 23 | USE_TF, 24 | USE_TORCH, 25 | DummyObject, 26 | is_accelerate_available, 27 | is_flax_available, 28 | is_inflect_available, 29 | is_modelcards_available, 30 | is_onnx_available, 31 | is_scipy_available, 32 | is_tf_available, 33 | is_torch_available, 34 | is_torch_version, 35 | is_transformers_available, 36 | is_unidecode_available, 37 | requires_backends, 38 | ) 39 | from .logging import get_logger 40 | from .outputs import BaseOutput 41 | 42 | 43 | if is_torch_available(): 44 | from .testing_utils import ( 45 | floats_tensor, 46 | load_hf_numpy, 47 | load_image, 48 | load_numpy, 49 | parse_flag_from_env, 50 | require_torch_gpu, 51 | slow, 52 | torch_all_close, 53 | torch_device, 54 | ) 55 | 56 | 57 | logger = get_logger(__name__) 58 | 59 | 60 | hf_cache_home = os.path.expanduser( 61 | os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) 62 | ) 63 | default_cache_path = os.path.join(hf_cache_home, "diffusers") 64 | 65 | 66 | CONFIG_NAME = "config.json" 67 | WEIGHTS_NAME = "diffusion_pytorch_model.bin" 68 | FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" 69 | ONNX_WEIGHTS_NAME = "model.onnx" 70 | ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" 71 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" 72 | DIFFUSERS_CACHE = default_cache_path 73 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" 74 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) 75 | -------------------------------------------------------------------------------- /diffusers/utils/deprecation_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from typing import Any, Dict, Optional, Union 4 | 5 | from packaging import version 6 | 7 | 8 | def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): 9 | from .. import __version__ 10 | 11 | deprecated_kwargs = take_from 12 | values = () 13 | if not isinstance(args[0], tuple): 14 | args = (args,) 15 | 16 | for attribute, version_name, message in args: 17 | if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): 18 | raise ValueError( 19 | f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" 20 | f" version {__version__} is >= {version_name}" 21 | ) 22 | 23 | warning = None 24 | if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: 25 | values += (deprecated_kwargs.pop(attribute),) 26 | warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." 27 | elif hasattr(deprecated_kwargs, attribute): 28 | values += (getattr(deprecated_kwargs, attribute),) 29 | warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." 
30 | elif deprecated_kwargs is None: 31 | warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." 32 | 33 | if warning is not None: 34 | warning = warning + " " if standard_warn else "" 35 | warnings.warn(warning + message, DeprecationWarning) 36 | 37 | if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: 38 | call_frame = inspect.getouterframes(inspect.currentframe())[1] 39 | filename = call_frame.filename 40 | line_number = call_frame.lineno 41 | function = call_frame.function 42 | key, value = next(iter(deprecated_kwargs.items())) 43 | raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") 44 | 45 | if len(values) == 0: 46 | return 47 | elif len(values) == 1: 48 | return values[0] 49 | return values 50 | -------------------------------------------------------------------------------- /diffusers/utils/dummy_flax_and_transformers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | # flake8: noqa 3 | 4 | from ..utils import DummyObject, requires_backends 5 | 6 | 7 | class FlaxStableDiffusionPipeline(metaclass=DummyObject): 8 | _backends = ["flax", "transformers"] 9 | 10 | def __init__(self, *args, **kwargs): 11 | requires_backends(self, ["flax", "transformers"]) 12 | 13 | @classmethod 14 | def from_config(cls, *args, **kwargs): 15 | requires_backends(cls, ["flax", "transformers"]) 16 | 17 | @classmethod 18 | def from_pretrained(cls, *args, **kwargs): 19 | requires_backends(cls, ["flax", "transformers"]) 20 | -------------------------------------------------------------------------------- /diffusers/utils/dummy_flax_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | # flake8: noqa 3 | 4 | from ..utils import DummyObject, requires_backends 5 | 6 | 7 | class FlaxModelMixin(metaclass=DummyObject): 8 | _backends = ["flax"] 9 | 10 | def __init__(self, *args, **kwargs): 11 | requires_backends(self, ["flax"]) 12 | 13 | @classmethod 14 | def from_config(cls, *args, **kwargs): 15 | requires_backends(cls, ["flax"]) 16 | 17 | @classmethod 18 | def from_pretrained(cls, *args, **kwargs): 19 | requires_backends(cls, ["flax"]) 20 | 21 | 22 | class FlaxUNet2DConditionModel(metaclass=DummyObject): 23 | _backends = ["flax"] 24 | 25 | def __init__(self, *args, **kwargs): 26 | requires_backends(self, ["flax"]) 27 | 28 | @classmethod 29 | def from_config(cls, *args, **kwargs): 30 | requires_backends(cls, ["flax"]) 31 | 32 | @classmethod 33 | def from_pretrained(cls, *args, **kwargs): 34 | requires_backends(cls, ["flax"]) 35 | 36 | 37 | class FlaxAutoencoderKL(metaclass=DummyObject): 38 | _backends = ["flax"] 39 | 40 | def __init__(self, *args, **kwargs): 41 | requires_backends(self, ["flax"]) 42 | 43 | @classmethod 44 | def from_config(cls, *args, **kwargs): 45 | requires_backends(cls, ["flax"]) 46 | 47 | @classmethod 48 | def from_pretrained(cls, *args, **kwargs): 49 | requires_backends(cls, ["flax"]) 50 | 51 | 52 | class FlaxDiffusionPipeline(metaclass=DummyObject): 53 | _backends = ["flax"] 54 | 55 | def __init__(self, *args, **kwargs): 56 | requires_backends(self, ["flax"]) 57 | 58 | @classmethod 59 | def from_config(cls, *args, **kwargs): 60 | requires_backends(cls, ["flax"]) 61 | 62 | @classmethod 63 | def from_pretrained(cls, *args, **kwargs): 64 | requires_backends(cls, ["flax"]) 65 | 66 | 67 | class FlaxDDIMScheduler(metaclass=DummyObject): 68 | _backends = ["flax"] 69 | 70 | def __init__(self, *args, **kwargs): 71 | requires_backends(self, ["flax"]) 72 | 73 | @classmethod 74 | def from_config(cls, *args, **kwargs): 75 | requires_backends(cls, ["flax"]) 76 | 77 | @classmethod 78 | def from_pretrained(cls, *args, **kwargs): 79 | requires_backends(cls, ["flax"]) 80 | 81 | 82 | class FlaxDDPMScheduler(metaclass=DummyObject): 83 | _backends = ["flax"] 84 | 85 | def __init__(self, *args, **kwargs): 86 | requires_backends(self, ["flax"]) 87 | 88 | @classmethod 89 | def from_config(cls, *args, **kwargs): 90 | requires_backends(cls, ["flax"]) 91 | 92 | @classmethod 93 | def from_pretrained(cls, *args, **kwargs): 94 | requires_backends(cls, ["flax"]) 95 | 96 | 97 | class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): 98 | _backends = ["flax"] 99 | 100 | def __init__(self, *args, **kwargs): 101 | requires_backends(self, ["flax"]) 102 | 103 | @classmethod 104 | def from_config(cls, *args, **kwargs): 105 | requires_backends(cls, ["flax"]) 106 | 107 | @classmethod 108 | def from_pretrained(cls, *args, **kwargs): 109 | requires_backends(cls, ["flax"]) 110 | 111 | 112 | class FlaxKarrasVeScheduler(metaclass=DummyObject): 113 | _backends = ["flax"] 114 | 115 | def __init__(self, *args, **kwargs): 116 | requires_backends(self, ["flax"]) 117 | 118 | @classmethod 119 | def from_config(cls, *args, **kwargs): 120 | requires_backends(cls, ["flax"]) 121 | 122 | @classmethod 123 | def from_pretrained(cls, *args, **kwargs): 124 | requires_backends(cls, ["flax"]) 125 | 126 | 127 | class FlaxLMSDiscreteScheduler(metaclass=DummyObject): 128 | _backends = ["flax"] 129 | 130 | def __init__(self, *args, **kwargs): 131 | requires_backends(self, ["flax"]) 132 | 133 | @classmethod 134 | def from_config(cls, *args, **kwargs): 135 | requires_backends(cls, ["flax"]) 136 | 137 | 
@classmethod 138 | def from_pretrained(cls, *args, **kwargs): 139 | requires_backends(cls, ["flax"]) 140 | 141 | 142 | class FlaxPNDMScheduler(metaclass=DummyObject): 143 | _backends = ["flax"] 144 | 145 | def __init__(self, *args, **kwargs): 146 | requires_backends(self, ["flax"]) 147 | 148 | @classmethod 149 | def from_config(cls, *args, **kwargs): 150 | requires_backends(cls, ["flax"]) 151 | 152 | @classmethod 153 | def from_pretrained(cls, *args, **kwargs): 154 | requires_backends(cls, ["flax"]) 155 | 156 | 157 | class FlaxSchedulerMixin(metaclass=DummyObject): 158 | _backends = ["flax"] 159 | 160 | def __init__(self, *args, **kwargs): 161 | requires_backends(self, ["flax"]) 162 | 163 | @classmethod 164 | def from_config(cls, *args, **kwargs): 165 | requires_backends(cls, ["flax"]) 166 | 167 | @classmethod 168 | def from_pretrained(cls, *args, **kwargs): 169 | requires_backends(cls, ["flax"]) 170 | 171 | 172 | class FlaxScoreSdeVeScheduler(metaclass=DummyObject): 173 | _backends = ["flax"] 174 | 175 | def __init__(self, *args, **kwargs): 176 | requires_backends(self, ["flax"]) 177 | 178 | @classmethod 179 | def from_config(cls, *args, **kwargs): 180 | requires_backends(cls, ["flax"]) 181 | 182 | @classmethod 183 | def from_pretrained(cls, *args, **kwargs): 184 | requires_backends(cls, ["flax"]) 185 | -------------------------------------------------------------------------------- /diffusers/utils/dummy_torch_and_scipy_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | # flake8: noqa 3 | 4 | from ..utils import DummyObject, requires_backends 5 | 6 | 7 | class LMSDiscreteScheduler(metaclass=DummyObject): 8 | _backends = ["torch", "scipy"] 9 | 10 | def __init__(self, *args, **kwargs): 11 | requires_backends(self, ["torch", "scipy"]) 12 | 13 | @classmethod 14 | def from_config(cls, *args, **kwargs): 15 | requires_backends(cls, ["torch", "scipy"]) 16 | 17 | @classmethod 18 | def from_pretrained(cls, *args, **kwargs): 19 | requires_backends(cls, ["torch", "scipy"]) 20 | -------------------------------------------------------------------------------- /diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | # flake8: noqa 3 | 4 | from ..utils import DummyObject, requires_backends 5 | 6 | 7 | class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): 8 | _backends = ["torch", "transformers", "onnx"] 9 | 10 | def __init__(self, *args, **kwargs): 11 | requires_backends(self, ["torch", "transformers", "onnx"]) 12 | 13 | @classmethod 14 | def from_config(cls, *args, **kwargs): 15 | requires_backends(cls, ["torch", "transformers", "onnx"]) 16 | 17 | @classmethod 18 | def from_pretrained(cls, *args, **kwargs): 19 | requires_backends(cls, ["torch", "transformers", "onnx"]) 20 | 21 | 22 | class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): 23 | _backends = ["torch", "transformers", "onnx"] 24 | 25 | def __init__(self, *args, **kwargs): 26 | requires_backends(self, ["torch", "transformers", "onnx"]) 27 | 28 | @classmethod 29 | def from_config(cls, *args, **kwargs): 30 | requires_backends(cls, ["torch", "transformers", "onnx"]) 31 | 32 | @classmethod 33 | def from_pretrained(cls, *args, **kwargs): 34 | requires_backends(cls, ["torch", "transformers", "onnx"]) 35 | 36 | 37 | class OnnxStableDiffusionPipeline(metaclass=DummyObject): 38 | _backends = ["torch", "transformers", "onnx"] 39 | 40 | def __init__(self, *args, **kwargs): 41 | requires_backends(self, ["torch", "transformers", "onnx"]) 42 | 43 | @classmethod 44 | def from_config(cls, *args, **kwargs): 45 | requires_backends(cls, ["torch", "transformers", "onnx"]) 46 | 47 | @classmethod 48 | def from_pretrained(cls, *args, **kwargs): 49 | requires_backends(cls, ["torch", "transformers", "onnx"]) 50 | 51 | 52 | class StableDiffusionOnnxPipeline(metaclass=DummyObject): 53 | _backends = ["torch", "transformers", "onnx"] 54 | 55 | def __init__(self, *args, **kwargs): 56 | requires_backends(self, ["torch", "transformers", "onnx"]) 57 | 58 | @classmethod 59 | def from_config(cls, *args, **kwargs): 60 | requires_backends(cls, ["torch", "transformers", "onnx"]) 61 | 62 | @classmethod 63 | def from_pretrained(cls, *args, **kwargs): 64 | requires_backends(cls, ["torch", "transformers", "onnx"]) 65 | -------------------------------------------------------------------------------- /diffusers/utils/dummy_torch_and_transformers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | # flake8: noqa 3 | 4 | from ..utils import DummyObject, requires_backends 5 | 6 | 7 | class CycleDiffusionPipeline(metaclass=DummyObject): 8 | _backends = ["torch", "transformers"] 9 | 10 | def __init__(self, *args, **kwargs): 11 | requires_backends(self, ["torch", "transformers"]) 12 | 13 | @classmethod 14 | def from_config(cls, *args, **kwargs): 15 | requires_backends(cls, ["torch", "transformers"]) 16 | 17 | @classmethod 18 | def from_pretrained(cls, *args, **kwargs): 19 | requires_backends(cls, ["torch", "transformers"]) 20 | 21 | 22 | class LDMTextToImagePipeline(metaclass=DummyObject): 23 | _backends = ["torch", "transformers"] 24 | 25 | def __init__(self, *args, **kwargs): 26 | requires_backends(self, ["torch", "transformers"]) 27 | 28 | @classmethod 29 | def from_config(cls, *args, **kwargs): 30 | requires_backends(cls, ["torch", "transformers"]) 31 | 32 | @classmethod 33 | def from_pretrained(cls, *args, **kwargs): 34 | requires_backends(cls, ["torch", "transformers"]) 35 | 36 | 37 | class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): 38 | _backends = ["torch", "transformers"] 39 | 40 | def __init__(self, *args, **kwargs): 41 | requires_backends(self, ["torch", "transformers"]) 42 | 43 | @classmethod 44 | def from_config(cls, *args, **kwargs): 45 | requires_backends(cls, ["torch", "transformers"]) 46 | 47 | @classmethod 48 | def from_pretrained(cls, *args, **kwargs): 49 | requires_backends(cls, ["torch", "transformers"]) 50 | 51 | 52 | class StableDiffusionInpaintPipeline(metaclass=DummyObject): 53 | _backends = ["torch", "transformers"] 54 | 55 | def __init__(self, *args, **kwargs): 56 | requires_backends(self, ["torch", "transformers"]) 57 | 58 | @classmethod 59 | def from_config(cls, *args, **kwargs): 60 | requires_backends(cls, ["torch", "transformers"]) 61 | 62 | @classmethod 63 | def from_pretrained(cls, *args, **kwargs): 64 | requires_backends(cls, ["torch", "transformers"]) 65 | 66 | 67 | class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): 68 | _backends = ["torch", "transformers"] 69 | 70 | def __init__(self, *args, **kwargs): 71 | requires_backends(self, ["torch", "transformers"]) 72 | 73 | @classmethod 74 | def from_config(cls, *args, **kwargs): 75 | requires_backends(cls, ["torch", "transformers"]) 76 | 77 | @classmethod 78 | def from_pretrained(cls, *args, **kwargs): 79 | requires_backends(cls, ["torch", "transformers"]) 80 | 81 | 82 | class StableDiffusionPipeline(metaclass=DummyObject): 83 | _backends = ["torch", "transformers"] 84 | 85 | def __init__(self, *args, **kwargs): 86 | requires_backends(self, ["torch", "transformers"]) 87 | 88 | @classmethod 89 | def from_config(cls, *args, **kwargs): 90 | requires_backends(cls, ["torch", "transformers"]) 91 | 92 | @classmethod 93 | def from_pretrained(cls, *args, **kwargs): 94 | requires_backends(cls, ["torch", "transformers"]) 95 | 96 | 97 | class VQDiffusionPipeline(metaclass=DummyObject): 98 | _backends = ["torch", "transformers"] 99 | 100 | def __init__(self, *args, **kwargs): 101 | requires_backends(self, ["torch", "transformers"]) 102 | 103 | @classmethod 104 | def from_config(cls, *args, **kwargs): 105 | requires_backends(cls, ["torch", "transformers"]) 106 | 107 | @classmethod 108 | def from_pretrained(cls, *args, **kwargs): 109 | requires_backends(cls, ["torch", "transformers"]) 110 | -------------------------------------------------------------------------------- /diffusers/utils/model_card_template.md: -------------------------------------------------------------------------------- 
1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | 7 | 8 | # {{ model_name | default("Diffusion Model") }} 9 | 10 | ## Model description 11 | 12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library 13 | on the `{{ dataset_name }}` dataset. 14 | 15 | ## Intended uses & limitations 16 | 17 | #### How to use 18 | 19 | ```python 20 | # TODO: add an example code snippet for running this diffusion pipeline 21 | ``` 22 | 23 | #### Limitations and bias 24 | 25 | [TODO: provide examples of latent issues and potential remediations] 26 | 27 | ## Training data 28 | 29 | [TODO: describe the data used to train the model] 30 | 31 | ### Training hyperparameters 32 | 33 | The following hyperparameters were used during training: 34 | - learning_rate: {{ learning_rate }} 35 | - train_batch_size: {{ train_batch_size }} 36 | - eval_batch_size: {{ eval_batch_size }} 37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }} 38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} 39 | - lr_scheduler: {{ lr_scheduler }} 40 | - lr_warmup_steps: {{ lr_warmup_steps }} 41 | - ema_inv_gamma: {{ ema_inv_gamma }} 42 | - ema_power: {{ ema_power }} 43 | - ema_max_decay: {{ ema_max_decay }} 44 | - mixed_precision: {{ mixed_precision }} 45 | 46 | ### Training results 47 | 48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) 49 | 50 | 51 | -------------------------------------------------------------------------------- /diffusers/utils/outputs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Generic utilities 16 | """ 17 | 18 | from collections import OrderedDict 19 | from dataclasses import fields 20 | from typing import Any, Tuple 21 | 22 | import numpy as np 23 | 24 | from .import_utils import is_torch_available 25 | 26 | 27 | def is_tensor(x): 28 | """ 29 | Tests if `x` is a `torch.Tensor` or `np.ndarray`. 30 | """ 31 | if is_torch_available(): 32 | import torch 33 | 34 | if isinstance(x, torch.Tensor): 35 | return True 36 | 37 | return isinstance(x, np.ndarray) 38 | 39 | 40 | class BaseOutput(OrderedDict): 41 | """ 42 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a 43 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular 44 | python dictionary. 45 | 46 | 47 | 48 | You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple 49 | before.
50 | 51 | 52 | """ 53 | 54 | def __post_init__(self): 55 | class_fields = fields(self) 56 | 57 | # Safety and consistency checks 58 | if not len(class_fields): 59 | raise ValueError(f"{self.__class__.__name__} has no fields.") 60 | 61 | first_field = getattr(self, class_fields[0].name) 62 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) 63 | 64 | if other_fields_are_none and isinstance(first_field, dict): 65 | for key, value in first_field.items(): 66 | self[key] = value 67 | else: 68 | for field in class_fields: 69 | v = getattr(self, field.name) 70 | if v is not None: 71 | self[field.name] = v 72 | 73 | def __delitem__(self, *args, **kwargs): 74 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") 75 | 76 | def setdefault(self, *args, **kwargs): 77 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") 78 | 79 | def pop(self, *args, **kwargs): 80 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") 81 | 82 | def update(self, *args, **kwargs): 83 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") 84 | 85 | def __getitem__(self, k): 86 | if isinstance(k, str): 87 | inner_dict = {k: v for (k, v) in self.items()} 88 | return inner_dict[k] 89 | else: 90 | return self.to_tuple()[k] 91 | 92 | def __setattr__(self, name, value): 93 | if name in self.keys() and value is not None: 94 | # Don't call self.__setitem__ to avoid recursion errors 95 | super().__setitem__(name, value) 96 | super().__setattr__(name, value) 97 | 98 | def __setitem__(self, key, value): 99 | # Will raise a KeyException if needed 100 | super().__setitem__(key, value) 101 | # Don't call self.__setattr__ to avoid recursion errors 102 | super().__setattr__(key, value) 103 | 104 | def to_tuple(self) -> Tuple[Any]: 105 | """ 106 | Convert self to a tuple containing all the attributes/keys that are not `None`. 
107 | """ 108 | return tuple(self[k] for k in self.keys()) 109 | -------------------------------------------------------------------------------- /examples/mask0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/mask0.png -------------------------------------------------------------------------------- /examples/mask1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/mask1.png -------------------------------------------------------------------------------- /examples/mask2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/mask2.png -------------------------------------------------------------------------------- /examples/mask3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/mask3.png -------------------------------------------------------------------------------- /examples/mask4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/mask4.png -------------------------------------------------------------------------------- /examples/sample0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/sample0.png -------------------------------------------------------------------------------- /examples/sample1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/sample1.png -------------------------------------------------------------------------------- /examples/sample2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/sample2.png -------------------------------------------------------------------------------- /examples/sample3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/sample3.png -------------------------------------------------------------------------------- /examples/sample4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/examples/sample4.png -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from omegaconf import OmegaConf 5 | from argparse import ArgumentParser 6 | from torchvision.utils import make_grid 7 | from pytorch_lightning import seed_everything 8 | from torchvision.transforms import ToPILImage, ToTensor 9 | 
from src.trainers import CharInpaintTrainer 10 | from src.dataset import prepare_style_chars 11 | from src.dataset.utils import prepare_npy_image_mask, normalize_image 12 | 13 | 14 | def create_parser(): 15 | parser = ArgumentParser() 16 | parser.add_argument("--seed", type=int, default=13) 17 | parser.add_argument("--ckpt_path", type=str, required=True) 18 | parser.add_argument("--in_image", type=str, required=True) 19 | parser.add_argument("--in_mask", type=str, required=True) 20 | parser.add_argument("--out_dir", default="output") 21 | parser.add_argument("--text", type=str) 22 | parser.add_argument("--font", type=str, default="") 23 | parser.add_argument("--color", type=str, default="") 24 | parser.add_argument("--instruction", type=str) 25 | parser.add_argument("--num_inference_steps", default=30) 26 | parser.add_argument("--num_sample_per_image", default=3, type=int) 27 | parser.add_argument("--guidance_scale", default=7.5, type=float) 28 | parser.add_argument("--no_cuda", action="store_true") 29 | return parser 30 | 31 | 32 | def main(opt): 33 | model = CharInpaintTrainer.load_from_checkpoint(opt.ckpt_path) 34 | device = "cpu" if opt.no_cuda else "cuda" 35 | model = model.to(device) 36 | 37 | image = Image.open(opt.in_image) 38 | mask = Image.open(opt.in_mask).convert("1") 39 | raw_image, mask, masked_image, mask_coordinate = prepare_npy_image_mask( 40 | image, mask 41 | ) 42 | 43 | if opt.instruction is not None: 44 | style = opt.instruction 45 | char = opt.text 46 | else: 47 | char = opt.text 48 | color = opt.color 49 | font = opt.font 50 | style = prepare_style_chars(char, [font, color]) 51 | 52 | torch.manual_seed(opt.seed) 53 | batch = { 54 | "image": torch.from_numpy(raw_image).unsqueeze(0).to(device), 55 | "mask": torch.from_numpy(mask).unsqueeze(0).to(device), 56 | "masked_image": torch.from_numpy(masked_image).unsqueeze(0).to(device), 57 | "coordinate": [mask_coordinate], 58 | "chars": [char], 59 | "style": [style], 60 | } 61 | 62 | generation_kwargs = { 63 | "num_inference_steps": opt.num_inference_steps, 64 | "num_sample_per_image": opt.num_sample_per_image, 65 | "guidance_scale": opt.guidance_scale, 66 | "generator": torch.Generator(model.device).manual_seed(opt.seed) 67 | } 68 | 69 | with torch.no_grad(): 70 | results = model.log_images(batch, generation_kwargs) 71 | os.makedirs(opt.out_dir, exist_ok=True) 72 | keys = results.keys() 73 | for i, k in enumerate(keys): 74 | img = torch.cat([ 75 | ((batch["image"][i:i+1].cpu()) / 2. + 0.5).clamp(0., 1.), 76 | ((batch["masked_image"][i:i+1].cpu()) / 2. 
+ 0.5).clamp(0., 1.), 77 | results[k] 78 | ]) 79 | grid = make_grid(img, nrow=5, padding=1) 80 | ToPILImage()(grid).save( 81 | f"{opt.out_dir}/{k}-grid.png" 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = create_parser() 87 | opt = parser.parse_args() 88 | main(opt) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.16.0 2 | datasets==2.9.0 3 | editdistance==0.6.2 4 | einops==0.6.0 5 | flax 6 | huggingface_hub==0.12.0 7 | importlib_metadata==6.0.0 8 | ipdb==0.13.11 9 | jax 10 | matplotlib==3.6.3 11 | modelcards==0.1.6 12 | msgpack_python==0.5.6 13 | numpy==1.24.1 14 | omegaconf==2.3.0 15 | onnxruntime==1.14.1 16 | packaging==23.0 17 | pandas==1.5.3 18 | Pillow==9.4.0 19 | pytorch_lightning==1.9.0 20 | PyYAML==6.0 21 | requests==2.28.2 22 | scipy==1.10.1 23 | torch==1.13.1 24 | torchvision==0.14.1 25 | tqdm==4.64.1 26 | transformers==4.26.0 27 | xformers==0.0.16 28 | tokenizers==0.13.2 29 | torch-tb-profiler==0.4.1 30 | tensorboard==2.12.0 31 | tensorboard-data-server==0.7.0 32 | tensorboard-plugin-wit==1.8.1 33 | tensorboardX==2.6 34 | pyarrow==11.0.0 35 | gdown 36 | -------------------------------------------------------------------------------- /scripts/down_data.sh: -------------------------------------------------------------------------------- 1 | # COCOText 2 | mkdir -p ocr-dataset && cd ocr-dataset 3 | 4 | ( 5 | mkdir -p COCO && cd COCO 6 | wget https://github.com/bgshih/cocotext/releases/download/dl/cocotext.v2.zip 7 | wget http://images.cocodataset.org/zips/train2014.zip 8 | ) 9 | 10 | # ArT 11 | ( 12 | mkdir -p ArT && cd ArT 13 | #* manually download from https://rrc.cvc.uab.es/?ch=14&com=downloads 14 | ) 15 | 16 | # TextOCR 17 | ( 18 | mkdir -p TextOCR && cd TextOCR 19 | wget https://dl.fbaipublicfiles.com/textvqa/data/textocr/TextOCR_0.1_train.json 20 | wget https://dl.fbaipublicfiles.com/textvqa/data/textocr/TextOCR_0.1_val.json 21 | wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip 22 | ) 23 | 24 | # ICDAR13: 25 | ( # Please check https://mmocr.readthedocs.io/en/latest/datasets/det.html?highlight=icdar#icdar-2013-focused-scene-text for details 26 | mkdir icdar2013 && cd icdar2013 27 | mkdir imgs && mkdir annotations 28 | 29 | # Download ICDAR 2013 30 | wget https://rrc.cvc.uab.es/downloads/Challenge2_Training_Task12_Images.zip --no-check-certificate 31 | wget https://rrc.cvc.uab.es/downloads/Challenge2_Test_Task12_Images.zip --no-check-certificate 32 | wget https://rrc.cvc.uab.es/downloads/Challenge2_Training_Task1_GT.zip --no-check-certificate 33 | wget https://rrc.cvc.uab.es/downloads/Challenge2_Test_Task1_GT.zip --no-check-certificate 34 | 35 | # For images 36 | unzip -q Challenge2_Training_Task12_Images.zip -d imgs/training 37 | unzip -q Challenge2_Test_Task12_Images.zip -d imgs/test 38 | # For annotations 39 | unzip -q Challenge2_Training_Task1_GT.zip -d annotations/training 40 | unzip -q Challenge2_Test_Task1_GT.zip -d annotations/test 41 | 42 | rm Challenge2_Training_Task12_Images.zip && rm Challenge2_Test_Task12_Images.zip && rm Challenge2_Training_Task1_GT.zip && rm Challenge2_Test_Task1_GT.zip 43 | ) -------------------------------------------------------------------------------- /scripts/gen_synth.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | BASE_OUT=../ocr-dataset/synth/ 4 | PROCESS_NUM=4 5 | BG_NUM=600000 6 | 7 | cd
synthgenerator 8 | echo $CWD 9 | python generate_synth.py -v -o $BASE_OUT -w $PROCESS_NUM multistyle_template SynthForCharDiffusion synthgen_config.yaml --count $BG_NUM 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/src/__init__.py -------------------------------------------------------------------------------- /src/abinet/__init__.py: -------------------------------------------------------------------------------- 1 | from .abinet_base import get_model, preprocess, postprocess, load, create_ocr_model 2 | from .utils import Config, CharsetMapper, prepare_label 3 | from .modules.model_abinet_iter import ABINetIterModelWrapper 4 | from .modules.losses import MultiLosses 5 | -------------------------------------------------------------------------------- /src/abinet/abinet_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import glob 4 | import torch 5 | import PIL 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | from omegaconf import OmegaConf 9 | from .utils import Config, CharsetMapper 10 | 11 | 12 | BASE_DIR = "src/abinet/" 13 | DEFAULT_OCR_CONFIG = { 14 | "conf": os.path.join(BASE_DIR, "configs/train_abinet.yaml"), 15 | "default_conf": os.path.join(BASE_DIR, "configs/template.yaml"), 16 | "ckpt": os.path.join(BASE_DIR, "checkpoints/abinet/train-abinet/best-train-abinet.pth"), 17 | } 18 | 19 | 20 | def create_ocr_model(device=None): 21 | print("Loading OCR model...") 22 | if device is None: 23 | device = torch.cuda.current_device() 24 | default_conf = OmegaConf.load(DEFAULT_OCR_CONFIG["default_conf"]) 25 | conf = OmegaConf.load(DEFAULT_OCR_CONFIG["conf"]) 26 | config = OmegaConf.merge(default_conf, conf) 27 | OmegaConf.resolve(config) 28 | charset = CharsetMapper( 29 | filename=config.dataset.charset_path, max_length=config.dataset.max_length + 1 30 | ) 31 | config.model_eval = "alignment" 32 | ocr_model = get_model(config.model) 33 | model = load( 34 | ocr_model, 35 | DEFAULT_OCR_CONFIG["ckpt"], 36 | device=None, 37 | strict="Contrast" not in config.model.name, 38 | ) # always load to cpu first 39 | model = model.to(device) 40 | print("OCR Model loaded") 41 | return charset, ocr_model 42 | 43 | 44 | def get_model(config, device="cpu", reload=False): 45 | import importlib 46 | 47 | module, cls = config.name.rsplit(".", 1) 48 | if reload: 49 | module_imp = importlib.import_module(module) 50 | importlib.reload(module_imp) 51 | cls = getattr(importlib.import_module(module, package=None), cls) 52 | 53 | model = cls(config) 54 | logging.info(model) 55 | model = model.eval() 56 | return model 57 | 58 | 59 | def preprocess(img, width=128, height=32): 60 | img = img.resize((width, height), PIL.Image.Resampling.BILINEAR) 61 | img = transforms.ToTensor()(img).unsqueeze(0) 62 | mean = torch.tensor([0.485, 0.456, 0.406]) 63 | std = torch.tensor([0.229, 0.224, 0.225]) 64 | return (img - mean[..., None, None]) / std[..., None, None] 65 | 66 | def postprocess(output, charset, model_eval): 67 | def _get_output(last_output, model_eval): 68 | if isinstance(last_output, (tuple, list)): 69 | for res in last_output: 70 | if res["name"] == model_eval: 71 | output = res 72 | else: 73 | output = last_output 74 | return output 75 | 76 | def _decode(logit): 77 | """Greed 
decode""" 78 | out = F.softmax(logit, dim=2) 79 | pt_text, pt_scores, pt_lengths = [], [], [] 80 | for o in out: 81 | text = charset.get_text(o.argmax(dim=1), padding=False, trim=False) 82 | text = text.split(charset.null_char)[0] # end at end-token 83 | pt_text.append(text) 84 | pt_scores.append(o.max(dim=1)[0]) 85 | pt_lengths.append( 86 | min(len(text) + 1, charset.max_length) 87 | ) # one for end-token 88 | return pt_text, pt_scores, pt_lengths 89 | 90 | output = _get_output(output, model_eval) 91 | logits, pt_lengths = output["logits"], output["pt_lengths"] 92 | pt_text, pt_scores, pt_lengths_ = _decode(logits) 93 | return pt_text, pt_scores, pt_lengths_ 94 | 95 | 96 | def load(model, file, device=None, strict=True): 97 | if device is None: 98 | device = "cpu" 99 | elif isinstance(device, int): 100 | device = torch.device("cuda", device) 101 | 102 | assert os.path.isfile(file) 103 | state = torch.load(file, map_location=device) 104 | if set(state.keys()) == {"model", "opt"}: 105 | state = state["model"] 106 | model.load_state_dict(state, strict=strict) 107 | return model 108 | -------------------------------------------------------------------------------- /src/abinet/configs/pretrain_language_model.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: pretrain-language-model 3 | phase: train 4 | stage: pretrain-language 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: ['data/WikiText-103.csv'], 11 | batch_size: 4096 12 | } 13 | test: { 14 | roots: ['data/WikiText-103_eval_d1.csv'], 15 | batch_size: 4096 16 | } 17 | 18 | training: 19 | epochs: 80 20 | show_iters: 50 21 | eval_iters: 6000 22 | save_iters: 3000 23 | 24 | optimizer: 25 | type: Adam 26 | true_wd: False 27 | wd: 0.0 28 | bn_wd: False 29 | clip_grad: 20 30 | lr: 0.0001 31 | args: { 32 | betas: !!python/tuple [0.9, 0.999], # for default Adam 33 | } 34 | scheduler: { 35 | periods: [70, 10], 36 | gamma: 0.1, 37 | } 38 | 39 | model: 40 | name: 'abinet.modules.model_language.BCNLanguage' 41 | language: { 42 | num_layers: 4, 43 | loss_weight: 1., 44 | use_self_attn: False 45 | } 46 | -------------------------------------------------------------------------------- /src/abinet/configs/pretrain_vision_model.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: pretrain-vision-model 3 | phase: train 4 | stage: pretrain-vision 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: ['data/training/MJ/MJ_train/', 11 | 'data/training/MJ/MJ_test/', 12 | 'data/training/MJ/MJ_valid/', 13 | 'data/training/ST'], 14 | batch_size: 384 15 | } 16 | test: { 17 | roots: ['data/evaluation/IIIT5k_3000', 18 | 'data/evaluation/SVT', 19 | 'data/evaluation/SVTP', 20 | 'data/evaluation/IC13_857', 21 | 'data/evaluation/IC15_1811', 22 | 'data/evaluation/CUTE80'], 23 | batch_size: 384 24 | } 25 | data_aug: True 26 | multiscales: False 27 | num_workers: 14 28 | 29 | training: 30 | epochs: 8 31 | show_iters: 50 32 | eval_iters: 3000 33 | save_iters: 3000 34 | 35 | optimizer: 36 | type: Adam 37 | true_wd: False 38 | wd: 0.0 39 | bn_wd: False 40 | clip_grad: 20 41 | lr: 0.0001 42 | args: { 43 | betas: !!python/tuple [0.9, 0.999], # for default Adam 44 | } 45 | scheduler: { 46 | periods: [6, 2], 47 | gamma: 0.1, 48 | } 49 | 50 | model: 51 | name: 'abinet.modules.model_vision.BaseVision' 52 | checkpoint: ~ 53 | vision: { 54 | loss_weight: 1., 55 | attention: 'position', 56 | backbone: 'transformer', 57 | 
backbone_ln: 3, 58 | } 59 | -------------------------------------------------------------------------------- /src/abinet/configs/pretrain_vision_model_sv.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: pretrain-vision-model-sv 3 | phase: train 4 | stage: pretrain-vision 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: ['data/training/MJ/MJ_train/', 11 | 'data/training/MJ/MJ_test/', 12 | 'data/training/MJ/MJ_valid/', 13 | 'data/training/ST'], 14 | batch_size: 384 15 | } 16 | test: { 17 | roots: ['data/evaluation/IIIT5k_3000', 18 | 'data/evaluation/SVT', 19 | 'data/evaluation/SVTP', 20 | 'data/evaluation/IC13_857', 21 | 'data/evaluation/IC15_1811', 22 | 'data/evaluation/CUTE80'], 23 | batch_size: 384 24 | } 25 | data_aug: True 26 | multiscales: False 27 | num_workers: 14 28 | 29 | training: 30 | epochs: 8 31 | show_iters: 50 32 | eval_iters: 3000 33 | save_iters: 3000 34 | 35 | optimizer: 36 | type: Adam 37 | true_wd: False 38 | wd: 0.0 39 | bn_wd: False 40 | clip_grad: 20 41 | lr: 0.0001 42 | args: { 43 | betas: !!python/tuple [0.9, 0.999], # for default Adam 44 | } 45 | scheduler: { 46 | periods: [6, 2], 47 | gamma: 0.1, 48 | } 49 | 50 | model: 51 | name: 'abinet.modules.model_vision.BaseVision' 52 | checkpoint: ~ 53 | vision: { 54 | loss_weight: 1., 55 | attention: 'attention', 56 | backbone: 'transformer', 57 | backbone_ln: 2, 58 | } 59 | -------------------------------------------------------------------------------- /src/abinet/configs/template.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: exp 3 | phase: train 4 | stage: pretrain-vision 5 | workdir: /tmp/workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: ['data/training/MJ/MJ_train/', 11 | 'data/training/MJ/MJ_test/', 12 | 'data/training/MJ/MJ_valid/', 13 | 'data/training/ST'], 14 | batch_size: 128 15 | } 16 | test: { 17 | roots: ['data/evaluation/IIIT5k_3000', 18 | 'data/evaluation/SVT', 19 | 'data/evaluation/SVTP', 20 | 'data/evaluation/IC13_857', 21 | 'data/evaluation/IC15_1811', 22 | 'data/evaluation/CUTE80'], 23 | batch_size: 128 24 | } 25 | charset_path: abinet/data/charset_36.txt 26 | num_workers: 4 27 | max_length: 25 # 30 28 | image_height: 32 29 | image_width: 128 30 | case_sensitive: False 31 | eval_case_sensitive: False 32 | data_aug: True 33 | multiscales: False 34 | pin_memory: True 35 | smooth_label: False 36 | smooth_factor: 0.1 37 | one_hot_y: True 38 | use_sm: False 39 | 40 | training: 41 | epochs: 6 42 | show_iters: 50 43 | eval_iters: 3000 44 | save_iters: 20000 45 | start_iters: 0 46 | stats_iters: 100000 47 | 48 | optimizer: 49 | type: Adadelta # Adadelta, Adam 50 | true_wd: False 51 | wd: 0. 
# 0.001 52 | bn_wd: False 53 | args: { 54 | # betas: !!python/tuple [0.9, 0.99], # betas=(0.9,0.99) for AdamW 55 | # betas: !!python/tuple [0.9, 0.999], # for default Adam 56 | } 57 | clip_grad: 20 58 | lr: [1.0, 1.0, 1.0] # lr: [0.005, 0.005, 0.005] 59 | scheduler: { 60 | periods: [3, 2, 1], 61 | gamma: 0.1, 62 | } 63 | 64 | model: 65 | name: 'abinet.modules.model_abinet.ABINetModel' 66 | checkpoint: ~ 67 | strict: True 68 | -------------------------------------------------------------------------------- /src/abinet/configs/train_abinet.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: train-abinet 3 | phase: train 4 | stage: train-super 5 | workdir: workdir 6 | seed: ~ 7 | 8 | model: 9 | name: 'abinet.modules.model_abinet_iter.ABINetIterModel' 10 | max_length: ${dataset.max_length} 11 | charset_path: ${dataset.charset_path} 12 | iter_size: 3 13 | ensemble: '' 14 | use_vision: False 15 | vision: 16 | checkpoint: checkpoints/abinet/pretrain-vision-model/best-pretrain-vision-model.pth 17 | loss_weight: 1. 18 | attention: 'position' 19 | backbone: 'transformer' 20 | backbone_ln: 3 21 | max_length: ${model.max_length} 22 | charset_path: ${model.charset_path} 23 | language: 24 | checkpoint: checkpoints/abinet/pretrain-language-model/pretrain-language-model.pth 25 | num_layers: 4 26 | loss_weight: 1. 27 | detach: True 28 | use_self_attn: False 29 | max_length: ${model.max_length} 30 | charset_path: ${model.charset_path} 31 | alignment: 32 | loss_weight: 1. 33 | max_length: ${model.max_length} 34 | charset_path: ${model.charset_path} 35 | 36 | -------------------------------------------------------------------------------- /src/abinet/configs/train_abinet_sv.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | name: train-abinet-sv 3 | phase: train 4 | stage: train-super 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: ['data/training/MJ/MJ_train/', 11 | 'data/training/MJ/MJ_test/', 12 | 'data/training/MJ/MJ_valid/', 13 | 'data/training/ST'], 14 | batch_size: 384 15 | } 16 | test: { 17 | roots: ['data/evaluation/IIIT5k_3000', 18 | 'data/evaluation/SVT', 19 | 'data/evaluation/SVTP', 20 | 'data/evaluation/IC13_857', 21 | 'data/evaluation/IC15_1811', 22 | 'data/evaluation/CUTE80'], 23 | batch_size: 384 24 | } 25 | data_aug: True 26 | multiscales: False 27 | num_workers: 14 28 | 29 | training: 30 | epochs: 10 31 | show_iters: 50 32 | eval_iters: 3000 33 | save_iters: 3000 34 | 35 | optimizer: 36 | type: Adam 37 | true_wd: False 38 | wd: 0.0 39 | bn_wd: False 40 | clip_grad: 20 41 | lr: 0.0001 42 | args: { 43 | betas: !!python/tuple [0.9, 0.999], # for default Adam 44 | } 45 | scheduler: { 46 | periods: [6, 4], 47 | gamma: 0.1, 48 | } 49 | 50 | model: 51 | name: 'abinet.modules.model_abinet_iter.ABINetIterModel' 52 | iter_size: 3 53 | ensemble: '' 54 | use_vision: False 55 | vision: { 56 | checkpoint: workdir/pretrain-vision-model-sv/best-pretrain-vision-model-sv.pth, 57 | loss_weight: 1., 58 | attention: 'attention', 59 | backbone: 'transformer', 60 | backbone_ln: 2, 61 | } 62 | language: { 63 | checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth, 64 | num_layers: 4, 65 | loss_weight: 1., 66 | detach: True, 67 | use_self_attn: False 68 | } 69 | alignment: { 70 | loss_weight: 1., 71 | } 72 | -------------------------------------------------------------------------------- /src/abinet/configs/train_abinet_wo_iter.yaml: 
-------------------------------------------------------------------------------- 1 | global: 2 | name: train-abinet-wo-iter 3 | phase: train 4 | stage: train-super 5 | workdir: workdir 6 | seed: ~ 7 | 8 | dataset: 9 | train: { 10 | roots: ['data/training/MJ/MJ_train/', 11 | 'data/training/MJ/MJ_test/', 12 | 'data/training/MJ/MJ_valid/', 13 | 'data/training/ST'], 14 | batch_size: 384 15 | } 16 | test: { 17 | roots: ['data/evaluation/IIIT5k_3000', 18 | 'data/evaluation/SVT', 19 | 'data/evaluation/SVTP', 20 | 'data/evaluation/IC13_857', 21 | 'data/evaluation/IC15_1811', 22 | 'data/evaluation/CUTE80'], 23 | batch_size: 384 24 | } 25 | data_aug: True 26 | multiscales: False 27 | num_workers: 14 28 | 29 | training: 30 | epochs: 10 31 | show_iters: 50 32 | eval_iters: 3000 33 | save_iters: 3000 34 | 35 | optimizer: 36 | type: Adam 37 | true_wd: False 38 | wd: 0.0 39 | bn_wd: False 40 | clip_grad: 20 41 | lr: 0.0001 42 | args: { 43 | betas: !!python/tuple [0.9, 0.999], # for default Adam 44 | } 45 | scheduler: { 46 | periods: [6, 4], 47 | gamma: 0.1, 48 | } 49 | 50 | model: 51 | name: 'abinet.modules.model_abinet.ABINetModel' 52 | vision: { 53 | checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth, 54 | loss_weight: 1., 55 | attention: 'position', 56 | backbone: 'transformer', 57 | backbone_ln: 3, 58 | } 59 | language: { 60 | checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth, 61 | num_layers: 4, 62 | loss_weight: 1., 63 | detach: True, 64 | use_self_attn: False 65 | } 66 | alignment: { 67 | loss_weight: 1., 68 | } 69 | -------------------------------------------------------------------------------- /src/abinet/configs/train_contrast_abinet.yaml: -------------------------------------------------------------------------------- 1 | lightning: 2 | logger: 3 | callbacks: {} 4 | modelcheckpoint: 5 | monitor: "val/loss" 6 | trainer: 7 | benchmark: true 8 | 9 | trainer: 10 | accelerator: gpu 11 | devices: [6, ] 12 | strategy: ddp 13 | amp_backend: native 14 | log_every_n_steps: 50 # this is global step 15 | precision: 16 16 | max_epochs: 10 17 | check_val_every_n_epoch: 1 18 | accumulate_grad_batches: 1 19 | 20 | 21 | model_eval: alignment 22 | model: 23 | name: 'abinet.modules.model_abinet_iter.ContrastABINetIterModel' 24 | iter_size: 3 25 | ensemble: '' 26 | use_vision: False 27 | max_length: ${dataset.max_length} 28 | charset_path: ${dataset.charset_path} 29 | 30 | source: "raw" 31 | base_learning_rate: 1e-3 32 | precision: ${trainer.precision} 33 | weight_decay: 0.0 34 | adam_epsilon: 1.0e-8 35 | 36 | vision: 37 | checkpoint: checkpoints/abinet/pretrain-vision-model/best-pretrain-vision-model.pth 38 | loss_weight: 1. 39 | attention: 'position' 40 | backbone: 'transformer' 41 | backbone_ln: 3 42 | d_model: 512 43 | charset_path: ${model.charset_path} 44 | 45 | class_num_heads: 8 46 | contrast_hidden_dim: 512 47 | max_length: ${model.max_length} 48 | 49 | language: 50 | checkpoint: checkpoints/abinet/pretrain-language-model/pretrain-language-model.pth 51 | num_layers: 4 52 | loss_weight: 1. 53 | detach: True 54 | use_self_attn: False 55 | max_length: ${model.max_length} 56 | charset_path: ${model.charset_path} 57 | 58 | alignment: 59 | loss_weight: 1. 
60 | max_length: ${model.max_length} 61 | charset_path: ${model.charset_path} 62 | 63 | char_tokenizer: 64 | pretrained_path: "checkpoint/abinet/chartokenizer" 65 | pad_token: "\u2591" 66 | unk_token: "\u2591" 67 | cls_token: "[bos]" 68 | 69 | char_embedder: 70 | vocab_size: 95 # by default 71 | embedding_dim: 32 72 | padding_idx: 0 73 | attention_head_dim: 2 74 | encoder: 75 | contrast_hidden_dim: 512 76 | num_heads: 8 77 | num_encoder_layers: 2 78 | 79 | data: 80 | batch_size: 128 81 | base_dir: dataset/ocr-dataset/SynthText/data_dir/ 82 | train: 83 | target: "train_abinet.ContrastOCRData" 84 | params: 85 | data_csv: ${data.base_dir}/expand_train.csv 86 | img_dir: ${data.base_dir} 87 | is_raw_synth: True 88 | width: 128 # same as ABINet 89 | height: 32 90 | multiscale: True 91 | training: False 92 | 93 | validation: 94 | target: "train_abinet.ContrastOCRData" 95 | params: 96 | data_csv: ${data.base_dir}/expand_val.csv 97 | img_dir: ${data.base_dir} 98 | is_raw_synth: True 99 | width: 128 100 | height: 32 101 | multiscale: False 102 | training: False 103 | -------------------------------------------------------------------------------- /src/abinet/data/charset_36.txt: -------------------------------------------------------------------------------- 1 | 0 a 2 | 1 b 3 | 2 c 4 | 3 d 5 | 4 e 6 | 5 f 7 | 6 g 8 | 7 h 9 | 8 i 10 | 9 j 11 | 10 k 12 | 11 l 13 | 12 m 14 | 13 n 15 | 14 o 16 | 15 p 17 | 16 q 18 | 17 r 19 | 18 s 20 | 19 t 21 | 20 u 22 | 21 v 23 | 22 w 24 | 23 x 25 | 24 y 26 | 25 z 27 | 26 1 28 | 27 2 29 | 28 3 30 | 29 4 31 | 30 5 32 | 31 6 33 | 32 7 34 | 33 8 35 | 34 9 36 | 35 0 -------------------------------------------------------------------------------- /src/abinet/data/charset_62.txt: -------------------------------------------------------------------------------- 1 | 0 0 2 | 1 1 3 | 2 2 4 | 3 3 5 | 4 4 6 | 5 5 7 | 6 6 8 | 7 7 9 | 8 8 10 | 9 9 11 | 10 A 12 | 11 B 13 | 12 C 14 | 13 D 15 | 14 E 16 | 15 F 17 | 16 G 18 | 17 H 19 | 18 I 20 | 19 J 21 | 20 K 22 | 21 L 23 | 22 M 24 | 23 N 25 | 24 O 26 | 25 P 27 | 26 Q 28 | 27 R 29 | 28 S 30 | 29 T 31 | 30 U 32 | 31 V 33 | 32 W 34 | 33 X 35 | 34 Y 36 | 35 Z 37 | 36 a 38 | 37 b 39 | 38 c 40 | 39 d 41 | 40 e 42 | 41 f 43 | 42 g 44 | 43 h 45 | 44 i 46 | 45 j 47 | 46 k 48 | 47 l 49 | 48 m 50 | 49 n 51 | 50 o 52 | 51 p 53 | 52 q 54 | 53 r 55 | 54 s 56 | 55 t 57 | 56 u 58 | 57 v 59 | 58 w 60 | 59 x 61 | 60 y 62 | 61 z -------------------------------------------------------------------------------- /src/abinet/data/charset_vn.txt: -------------------------------------------------------------------------------- 1 | 0 ' 2 | 1 - 3 | 2 ! 4 | 3 " 5 | 4 # 6 | 5 $ 7 | 6 % 8 | 7 & 9 | 8 ( 10 | 9 ) 11 | 10 * 12 | 11 , 13 | 12 . 14 | 13 / 15 | 14 : 16 | 15 ; 17 | 16 ? 
18 | 17 @ 19 | 18 [ 20 | 19 \ 21 | 20 ] 22 | 21 ^ 23 | 22 _ 24 | 23 | 25 | 24 ~ 26 | 25 + 27 | 26 < 28 | 27 = 29 | 28 > 30 | 29 0 31 | 30 1 32 | 31 2 33 | 32 3 34 | 33 4 35 | 34 5 36 | 35 6 37 | 36 7 38 | 37 8 39 | 38 9 40 | 39 a 41 | 40 A 42 | 41 à 43 | 42 À 44 | 43 ả 45 | 44 Ả 46 | 45 ã 47 | 46 Ã 48 | 47 á 49 | 48 Á 50 | 49 ạ 51 | 50 Ạ 52 | 51 ă 53 | 52 Ă 54 | 53 ằ 55 | 54 Ằ 56 | 55 ẳ 57 | 56 Ẳ 58 | 57 ẵ 59 | 58 Ẵ 60 | 59 ắ 61 | 60 Ắ 62 | 61 ặ 63 | 62 Ặ 64 | 63 â 65 | 64 Â 66 | 65 ầ 67 | 66 Ầ 68 | 67 ẩ 69 | 68 Ẩ 70 | 69 ẫ 71 | 70 Ẫ 72 | 71 ấ 73 | 72 Ấ 74 | 73 ậ 75 | 74 Ậ 76 | 75 b 77 | 76 B 78 | 77 c 79 | 78 C 80 | 79 d 81 | 80 D 82 | 81 đ 83 | 82 Đ 84 | 83 e 85 | 84 E 86 | 85 è 87 | 86 È 88 | 87 ẻ 89 | 88 Ẻ 90 | 89 ẽ 91 | 90 Ẽ 92 | 91 é 93 | 92 É 94 | 93 ẹ 95 | 94 Ẹ 96 | 95 ė 97 | 96 ê 98 | 97 Ê 99 | 98 ề 100 | 99 Ề 101 | 100 ể 102 | 101 Ể 103 | 102 ễ 104 | 103 Ễ 105 | 104 ế 106 | 105 Ế 107 | 106 ệ 108 | 107 Ệ 109 | 108 f 110 | 109 F 111 | 110 g 112 | 111 G 113 | 112 h 114 | 113 H 115 | 114 i 116 | 115 I 117 | 116 ì 118 | 117 Ì 119 | 118 ỉ 120 | 119 Ỉ 121 | 120 ĩ 122 | 121 Ĩ 123 | 122 í 124 | 123 Í 125 | 124 ị 126 | 125 Ị 127 | 126 j 128 | 127 J 129 | 128 k 130 | 129 K 131 | 130 l 132 | 131 L 133 | 132 m 134 | 133 M 135 | 134 n 136 | 135 N 137 | 136 o 138 | 137 O 139 | 138 ò 140 | 139 Ò 141 | 140 ỏ 142 | 141 Ỏ 143 | 142 õ 144 | 143 Õ 145 | 144 ó 146 | 145 Ó 147 | 146 ọ 148 | 147 Ọ 149 | 148 ô 150 | 149 Ô 151 | 150 ồ 152 | 151 Ồ 153 | 152 ổ 154 | 153 Ổ 155 | 154 ỗ 156 | 155 Ỗ 157 | 156 ố 158 | 157 Ố 159 | 158 ộ 160 | 159 Ộ 161 | 160 ơ 162 | 161 Ơ 163 | 162 ờ 164 | 163 Ờ 165 | 164 ở 166 | 165 Ở 167 | 166 ỡ 168 | 167 Ỡ 169 | 168 ớ 170 | 169 Ớ 171 | 170 ợ 172 | 171 Ợ 173 | 172 p 174 | 173 P 175 | 174 q 176 | 175 Q 177 | 176 r 178 | 177 R 179 | 178 s 180 | 179 S 181 | 180 t 182 | 181 T 183 | 182 u 184 | 183 U 185 | 184 ù 186 | 185 Ù 187 | 186 ủ 188 | 187 Ủ 189 | 188 ũ 190 | 189 Ũ 191 | 190 ú 192 | 191 Ú 193 | 192 ụ 194 | 193 Ụ 195 | 194 ư 196 | 195 Ư 197 | 196 ừ 198 | 197 Ừ 199 | 198 ử 200 | 199 Ử 201 | 200 ữ 202 | 201 Ữ 203 | 202 ứ 204 | 203 Ứ 205 | 204 ự 206 | 205 Ự 207 | 206 v 208 | 207 V 209 | 208 w 210 | 209 W 211 | 210 x 212 | 211 X 213 | 212 y 214 | 213 Y 215 | 214 ỳ 216 | 215 Ỳ 217 | 216 ỷ 218 | 217 Ỷ 219 | 218 ỹ 220 | 219 Ỹ 221 | 220 ý 222 | 221 Ý 223 | 222 ỵ 224 | 223 Ỵ 225 | 224 z 226 | 225 Z 227 | -------------------------------------------------------------------------------- /src/abinet/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/src/abinet/modules/__init__.py -------------------------------------------------------------------------------- /src/abinet/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .transformer import PositionalEncoding 4 | 5 | class Attention(nn.Module): 6 | def __init__(self, in_channels=512, max_length=25, n_feature=256): 7 | super().__init__() 8 | self.max_length = max_length 9 | 10 | self.f0_embedding = nn.Embedding(max_length, in_channels) 11 | self.w0 = nn.Linear(max_length, n_feature) 12 | self.wv = nn.Linear(in_channels, in_channels) 13 | self.we = nn.Linear(in_channels, max_length) 14 | 15 | self.active = nn.Tanh() 16 | self.softmax = nn.Softmax(dim=2) 17 | 18 | def forward(self, enc_output): 19 | enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) 20 | reading_order = 
torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) 21 | reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) 22 | reading_order_embed = self.f0_embedding(reading_order) # b,25,512 23 | 24 | t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 25 | t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 26 | 27 | attn = self.we(t) # b,256,25 28 | attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 29 | g_output = torch.bmm(attn, enc_output) # b,25,512 30 | return g_output, attn.view(*attn.shape[:2], 8, 32) 31 | 32 | 33 | def encoder_layer(in_c, out_c, k=3, s=2, p=1): 34 | return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), 35 | nn.BatchNorm2d(out_c), 36 | nn.ReLU(True)) 37 | 38 | def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): 39 | align_corners = None if mode=='nearest' else True 40 | return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, 41 | mode=mode, align_corners=align_corners), 42 | nn.Conv2d(in_c, out_c, k, s, p), 43 | nn.BatchNorm2d(out_c), 44 | nn.ReLU(True)) 45 | 46 | 47 | class PositionAttention(nn.Module): 48 | def __init__(self, max_length, in_channels=512, num_channels=64, 49 | h=8, w=32, mode='nearest', **kwargs): 50 | super().__init__() 51 | self.max_length = max_length 52 | self.k_encoder = nn.Sequential( 53 | encoder_layer(in_channels, num_channels, s=(1, 2)), 54 | encoder_layer(num_channels, num_channels, s=(2, 2)), 55 | encoder_layer(num_channels, num_channels, s=(2, 2)), 56 | encoder_layer(num_channels, num_channels, s=(2, 2)) 57 | ) 58 | self.k_decoder = nn.Sequential( 59 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 60 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 61 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 62 | decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) 63 | ) 64 | 65 | self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length) 66 | self.project = nn.Linear(in_channels, in_channels) 67 | 68 | def forward(self, x): 69 | N, E, H, W = x.size() 70 | k, v = x, x # (N, E, H, W) 71 | 72 | # calculate key vector 73 | features = [] 74 | for i in range(0, len(self.k_encoder)): 75 | k = self.k_encoder[i](k) 76 | features.append(k) 77 | for i in range(0, len(self.k_decoder) - 1): 78 | k = self.k_decoder[i](k) 79 | k = k + features[len(self.k_decoder) - 2 - i] 80 | k = self.k_decoder[-1](k) 81 | 82 | # calculate query vector 83 | # TODO q=f(q,k) 84 | zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) 85 | q = self.pos_encoder(zeros) # (T, N, E) 86 | q = q.permute(1, 0, 2) # (N, T, E) 87 | q = self.project(q) # (N, T, E) 88 | 89 | # calculate attention 90 | attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) 91 | attn_scores = attn_scores / (E ** 0.5) 92 | attn_scores = torch.softmax(attn_scores, dim=-1) 93 | 94 | v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) 95 | attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) 96 | 97 | return attn_vecs, attn_scores.view(N, -1, H, W) -------------------------------------------------------------------------------- /src/abinet/modules/backbone.py: -------------------------------------------------------------------------------- 1 | # from fastai.vision import * 2 | 3 | from .model import _default_tfmer_cfg 4 | from .resnet import resnet45 5 | from .transformer import (PositionalEncoding, 6 | TransformerEncoder, 7 | TransformerEncoderLayer) 8 | from 
.module_util import * 9 | 10 | class ResTranformer(nn.Module): 11 | def __init__(self, config): 12 | super().__init__() 13 | self.resnet = resnet45() 14 | 15 | self.d_model = config.get("d_model", _default_tfmer_cfg["d_model"]) 16 | # self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model']) 17 | nhead = config.get("nhead", _default_tfmer_cfg["nhead"]) 18 | d_inner = config.get("d_inner", _default_tfmer_cfg["d_inner"]) 19 | dropout = config.get("dropout", _default_tfmer_cfg["dropout"]) 20 | activation = config.get("activation", _default_tfmer_cfg["activation"]) 21 | num_layers = config.get("backbone_ln", 2) 22 | 23 | self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32) 24 | encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, 25 | dim_feedforward=d_inner, dropout=dropout, activation=activation) 26 | self.transformer = TransformerEncoder(encoder_layer, num_layers) 27 | 28 | def forward(self, images): 29 | feature = self.resnet(images) 30 | n, c, h, w = feature.shape 31 | feature = feature.view(n, c, -1).permute(2, 0, 1) 32 | feature = self.pos_encoder(feature) 33 | feature = self.transformer(feature) 34 | feature = feature.permute(1, 2, 0).view(n, c, h, w) 35 | return feature 36 | -------------------------------------------------------------------------------- /src/abinet/modules/backbone_v2.py: -------------------------------------------------------------------------------- 1 | """ Created by MrBBS """ 2 | # 10/11/2022 3 | # -*-encoding:utf-8-*- 4 | 5 | -------------------------------------------------------------------------------- /src/abinet/modules/losses.py: -------------------------------------------------------------------------------- 1 | # from fastai.vision import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class MultiLosses(nn.Module): 7 | def __init__(self, one_hot=True): 8 | super().__init__() 9 | self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss() 10 | self.bce = torch.nn.BCELoss() 11 | 12 | @property 13 | def last_losses(self): 14 | return self.losses 15 | 16 | def _flatten(self, sources, lengths): 17 | return torch.cat([t[:l] for t, l in zip(sources, lengths)]) 18 | 19 | def _merge_list(self, all_res): 20 | if not isinstance(all_res, (list, tuple)): 21 | return all_res 22 | def merge(items): 23 | if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0) 24 | else: return items[0] 25 | res = dict() 26 | for key in all_res[0].keys(): 27 | items = [r[key] for r in all_res] 28 | res[key] = merge(items) 29 | return res 30 | 31 | def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True): 32 | loss_name = output.get('name') 33 | pt_logits, weight = output['logits'], output['loss_weight'] 34 | 35 | assert pt_logits.shape[0] % gt_labels.shape[0] == 0 36 | iter_size = pt_logits.shape[0] // gt_labels.shape[0] 37 | if iter_size > 1: 38 | gt_labels = gt_labels.repeat(3, 1, 1) 39 | gt_lengths = gt_lengths.repeat(3) 40 | flat_gt_labels = self._flatten(gt_labels, gt_lengths) 41 | flat_pt_logits = self._flatten(pt_logits, gt_lengths) 42 | 43 | nll = output.get('nll') 44 | if nll is not None: 45 | loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight 46 | else: 47 | loss = self.ce(flat_pt_logits, flat_gt_labels) * weight 48 | if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss 49 | 50 | return loss 51 | 52 | def forward(self, outputs, *args): 53 | self.losses = {} 54 | if isinstance(outputs, 
(tuple, list)): 55 | outputs = [self._merge_list(o) for o in outputs] 56 | return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.]) 57 | # return torch.mean([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.]) 58 | else: 59 | return self._ce_loss(outputs, *args, record=False) 60 | 61 | 62 | class SoftCrossEntropyLoss(nn.Module): 63 | def __init__(self, reduction="mean"): 64 | super().__init__() 65 | self.reduction = reduction 66 | 67 | def forward(self, input, target, softmax=True): 68 | if softmax: log_prob = F.log_softmax(input, dim=-1) 69 | else: log_prob = torch.log(input) 70 | loss = -(target * log_prob).sum(dim=-1) 71 | if self.reduction == "mean": return loss.mean() 72 | elif self.reduction == "sum": return loss.sum() 73 | else: return loss 74 | -------------------------------------------------------------------------------- /src/abinet/modules/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from ..utils import CharsetMapper 6 | 7 | _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024 8 | dropout=0.1, activation='relu') 9 | 10 | 11 | class Model(nn.Module): 12 | 13 | def __init__(self, config): 14 | super().__init__() 15 | self.max_length = config.max_length + 1 16 | self.charset = CharsetMapper(config.charset_path, 17 | max_length=self.max_length) 18 | 19 | def load(self, source, device=None, strict=True): 20 | state = torch.load(source, map_location=device) 21 | self.load_state_dict(state['model'], strict=strict) 22 | 23 | def _get_length(self, logit): 24 | """ Greed decoder to obtain length from logit""" 25 | out = (logit.argmax(dim=-1) == self.charset.null_label) 26 | out = self.first_nonzero(out.int()) + 1 27 | return out 28 | 29 | @staticmethod 30 | def first_nonzero(x): 31 | non_zero_mask = x != 0 32 | mask_max_values, mask_max_indices = torch.max( 33 | non_zero_mask.int(), dim=-1) 34 | mask_max_indices[mask_max_values == 0] = -1 35 | return mask_max_indices 36 | 37 | @staticmethod 38 | def _get_padding_mask(length, max_length): 39 | length = length.unsqueeze(-1) 40 | grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) 41 | return grid >= length 42 | 43 | @staticmethod 44 | def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True): 45 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 46 | Unmasked positions are filled with float(0.0). 
47 | """ 48 | mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1) 49 | if fw: 50 | mask = mask.transpose(0, 1) 51 | mask = mask.float().masked_fill(mask == 0, float( 52 | '-inf')).masked_fill(mask == 1, float(0.0)) 53 | return mask 54 | 55 | @staticmethod 56 | def _get_location_mask(sz, device=None): 57 | mask = torch.eye(sz, device=device) 58 | mask = mask.float().masked_fill(mask == 1, float('-inf')) 59 | return mask 60 | -------------------------------------------------------------------------------- /src/abinet/modules/model_abinet.py: -------------------------------------------------------------------------------- 1 | # from fastai.vision import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .model_alignment import BaseAlignment 7 | from .model_language import BCNLanguage 8 | from .model_vision import BaseVision, ContrastVision 9 | 10 | from .module_util import * 11 | 12 | class ABINetModel(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | self.use_alignment = ifnone(config.model_use_alignment, True) 16 | self.max_length = config.dataset_max_length + 1 # additional stop token 17 | self.vision = BaseVision(config) 18 | self.language = BCNLanguage(config) 19 | if self.use_alignment: self.alignment = BaseAlignment(config) 20 | 21 | def forward(self, images, *args): 22 | v_res = self.vision(images) 23 | v_tokens = torch.softmax(v_res['logits'], dim=-1) 24 | v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO:move to langauge model 25 | 26 | l_res = self.language(v_tokens, v_lengths) 27 | if not self.use_alignment: 28 | return l_res, v_res 29 | l_feature, v_feature = l_res['feature'], v_res['feature'] 30 | 31 | a_res = self.alignment(l_feature, v_feature) 32 | return a_res, l_res, v_res 33 | -------------------------------------------------------------------------------- /src/abinet/modules/model_abinet_iter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision.transforms.functional as vision_F 4 | 5 | from .model_vision import BaseVision, ContrastVision 6 | from .model_language import BCNLanguage 7 | from .model_alignment import BaseAlignment 8 | from .module_util import * 9 | from .losses import MultiLosses 10 | from typing import Tuple 11 | 12 | 13 | class ABINetIterModel(nn.Module): 14 | def __init__(self, config): 15 | super().__init__() 16 | self.iter_size = config.iter_size 17 | self.max_length = config.max_length + 1 # additional stop token 18 | self.vision = BaseVision(config.vision) 19 | self.language = BCNLanguage(config.language) 20 | self.alignment = BaseAlignment(config.alignment) 21 | self.export = config.get("export", False) 22 | 23 | def forward(self, images, mode="train", *args): 24 | v_res = self.vision(images) 25 | a_res = v_res 26 | all_l_res, all_a_res = [], [] 27 | for _ in range(self.iter_size): 28 | tokens = torch.softmax(a_res["logits"], dim=-1) 29 | lengths = a_res["pt_lengths"] 30 | lengths.clamp_(2, self.max_length) # TODO:move to langauge model 31 | l_res = self.language(tokens, lengths) 32 | all_l_res.append(l_res) 33 | a_res = self.alignment(l_res["feature"], v_res["feature"]) 34 | all_a_res.append(a_res) 35 | if self.export: 36 | return F.softmax(a_res["logits"], dim=2), a_res["pt_lengths"] 37 | if mode == "train": 38 | return all_a_res, all_l_res, v_res 39 | elif mode == "validation": 40 | return all_a_res, all_l_res, v_res, (a_res, all_l_res[-1], v_res) 41 | else: 42 | 
return a_res, all_l_res[-1], v_res 43 | 44 | 45 | class ABINetIterModelWrapper(nn.Module): 46 | # wrapper for ABINetIterModel to make loss_computation in this 47 | def __init__(self, config, width, height) -> None: 48 | super().__init__() 49 | # TODO: accomodate ContrastABINetIterModel 50 | self.abinet = ABINetIterModel(config) 51 | self.width = width 52 | self.height = height 53 | self.loss_fn = MultiLosses(True) 54 | 55 | def preprocess_char(self, char_tokenizer, labels, device): 56 | # convert label strings to char_input_ids, one_hot_label for loss computation 57 | inputs = char_tokenizer( 58 | labels, 59 | return_tensors="pt", 60 | padding="max_length", 61 | truncation=True, 62 | add_special_tokens=True, 63 | ) 64 | char_input_ids = inputs["input_ids"] 65 | abi_num_classes = len(char_tokenizer) - 1 66 | abi_labels = char_input_ids[:, 1:] 67 | gt_labels = F.one_hot(abi_labels, abi_num_classes) 68 | gt_lengths = torch.sum(inputs.attention_mask, dim=1) 69 | return char_input_ids.to(device), gt_labels.to(device), gt_lengths.to(device) 70 | 71 | def preprocess(self, image): 72 | # images: (C, H, W) 73 | # this method resize images to self.w self.h 74 | return vision_F.resize(image, (self.height, self.width)) 75 | 76 | def forward(self, images, char_inputs: Tuple, mode="train", *args): 77 | char_input_ids, gt_labels, gt_lengths = char_inputs 78 | assert images.device == char_input_ids.device == gt_labels.device 79 | outputs = self.abinet(images, char_input_ids, mode) 80 | celoss_inputs = outputs[:3] 81 | # TODO: add contrast loss later 82 | celoss = self.loss_fn(celoss_inputs, gt_labels, gt_lengths) 83 | if mode == "train": 84 | return celoss 85 | elif mode == "test" or mode == "validation": 86 | #! TODO: not compatible with tokenizer 87 | text_preds = outputs[-1] 88 | pt_text, a, b = postprocess( 89 | text_preds, 90 | ) 91 | return celoss, outputs[-1] 92 | -------------------------------------------------------------------------------- /src/abinet/modules/model_alignment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .model import Model, _default_tfmer_cfg 5 | from .module_util import * 6 | 7 | 8 | class BaseAlignment(Model): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | d_model = config.get("d_model", _default_tfmer_cfg["d_model"]) 12 | 13 | self.loss_weight = config.get("loss_weight", 1.0) 14 | self.w_att = nn.Linear(2 * d_model, d_model) 15 | self.cls = nn.Linear(d_model, self.charset.num_classes) 16 | 17 | def forward(self, l_feature, v_feature): 18 | """ 19 | Args: 20 | l_feature: (N, T, E) where T is length, N is batch size and d is dim of model 21 | v_feature: (N, T, E) shape the same as l_feature 22 | l_lengths: (N,) 23 | v_lengths: (N,) 24 | """ 25 | f = torch.cat((l_feature, v_feature), dim=2) 26 | f_att = torch.sigmoid(self.w_att(f)) 27 | output = f_att * v_feature + (1 - f_att) * l_feature 28 | 29 | logits = self.cls(output) # (N, T, C) 30 | pt_lengths = self._get_length(logits) 31 | 32 | return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight, 33 | 'name': 'alignment'} 34 | -------------------------------------------------------------------------------- /src/abinet/modules/model_language.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from .model import _default_tfmer_cfg 5 | from .model import Model 6 | from .transformer import (PositionalEncoding, 
7 | TransformerDecoder, 8 | TransformerDecoderLayer) 9 | 10 | from .module_util import * 11 | 12 | class BCNLanguage(Model): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | d_model = config.get("d_model", _default_tfmer_cfg['d_model']) 16 | nhead = config.get("nhead", _default_tfmer_cfg['nhead']) 17 | d_inner = config.get("d_inner", _default_tfmer_cfg['d_inner']) 18 | dropout = config.get("dropout", _default_tfmer_cfg['dropout']) 19 | activation = config.get("activation", _default_tfmer_cfg['activation']) 20 | num_layers = config.get("num_layers", 4) 21 | 22 | self.d_model = d_model 23 | self.detach = config.get("detach", True) 24 | self.use_self_attn = config.get("use_self_attn", False) 25 | self.loss_weight = config.get("loss_weight", 1.0) 26 | # self.max_length = self.max_length 27 | 28 | self.proj = nn.Linear(self.charset.num_classes, d_model, False) 29 | self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length) 30 | self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length) 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, 32 | activation, self_attn=self.use_self_attn) 33 | self.model = TransformerDecoder(decoder_layer, num_layers) 34 | 35 | self.cls = nn.Linear(d_model, self.charset.num_classes) 36 | 37 | if config.checkpoint is not None: 38 | logging.info(f'Read language model from {config.checkpoint}.') 39 | self.load(config.checkpoint, device="cpu") 40 | 41 | def forward(self, tokens, lengths): 42 | """ 43 | Args: 44 | tokens: (N, T, C) where T is length, N is batch size and C is classes number 45 | lengths: (N,) 46 | """ 47 | if self.detach: tokens = tokens.detach() 48 | embed = self.proj(tokens) # (N, T, E) 49 | embed = embed.permute(1, 0, 2) # (T, N, E) 50 | embed = self.token_encoder(embed) # (T, N, E) 51 | padding_mask = self._get_padding_mask(lengths, self.max_length) 52 | 53 | zeros = embed.new_zeros(*embed.shape) 54 | qeury = self.pos_encoder(zeros) 55 | location_mask = self._get_location_mask(self.max_length, tokens.device) 56 | output = self.model(qeury, embed, 57 | tgt_key_padding_mask=padding_mask, 58 | memory_mask=location_mask, 59 | memory_key_padding_mask=padding_mask) # (T, N, E) 60 | output = output.permute(1, 0, 2) # (N, T, E) 61 | 62 | logits = self.cls(output) # (N, T, C) 63 | pt_lengths = self._get_length(logits) 64 | 65 | res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths, 66 | 'loss_weight':self.loss_weight, 'name': 'language'} 67 | return res 68 | -------------------------------------------------------------------------------- /src/abinet/modules/model_vision.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from .attention import * 5 | from .backbone import ResTranformer 6 | from .model import Model 7 | from .resnet import resnet45 8 | 9 | from .module_util import * 10 | 11 | class BaseVision(Model): 12 | def __init__(self, config): 13 | super().__init__(config) 14 | self.loss_weight = config.get("loss_weight", 1.0) 15 | self.out_channels = config.get("d_model", 512) 16 | 17 | if config.backbone == 'transformer': 18 | self.backbone = ResTranformer(config) 19 | else: 20 | self.backbone = resnet45() 21 | 22 | if config.attention == 'position': 23 | mode = config.get("attention_mode", 'nearest') 24 | self.attention = PositionAttention( 25 | max_length=config.max_length + 1, # additional stop token 26 | mode=mode, 27 | ) 28 | elif config.attention == 'attention': 29 | 
self.attention = Attention( 30 | max_length=config.max_length + 1, # additional stop token 31 | n_feature=8*32, 32 | ) 33 | else: 34 | raise Exception(f'{config.attention} is not valid.') 35 | self.cls = nn.Linear(self.out_channels, self.charset.num_classes) 36 | 37 | if config.checkpoint is not None: 38 | logging.info(f'Read vision model from {config.checkpoint}.') 39 | self.load(config.checkpoint, device="cpu") # always cpu first and then convert 40 | 41 | def forward(self, images, *args): 42 | features = self.backbone(images) # (N, E, H, W) 43 | attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) 44 | logits = self.cls(attn_vecs) # (N, T, C) 45 | pt_lengths = self._get_length(logits) 46 | 47 | return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths, 48 | 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision'} 49 | 50 | class ContrastVision(BaseVision): 51 | def __init__(self, config): 52 | assert config.attention == 'position', "Contrastive learning only supports position attention in this model." 53 | super().__init__(config) # backbone is not changed 54 | 55 | # gather the information from attn_vecs provided by features 56 | self.class_embedding_q = nn.Parameter(torch.randn(self.out_channels)) 57 | self.class_encoder = nn.MultiheadAttention( 58 | embed_dim=self.out_channels, num_heads=config.class_num_heads, batch_first=True, 59 | ) 60 | 61 | def forward(self, images, *args): 62 | # 1. Extract features 63 | features = self.backbone(images) # (N, E, H, W) 64 | attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W) 65 | 66 | # 2. logits as before 67 | logits = self.cls(attn_vecs) 68 | pt_lengths = self._get_length(logits) 69 | 70 | # 3. Compute the class embedding for contrastive learning 71 | # attn_vecs already carry position information (position embedding), 72 | # so we use an attention mechanism to compute a weighted sum over them 73 | 74 | # pt_lengths contains the length of each sequence, so we can use it to mask out the padding part 75 | mask = torch.arange(attn_vecs.shape[1], device=logits.device)[None, :] < pt_lengths[:, None] # True = valid token 76 | class_embedding_q = self.class_embedding_q.expand(attn_vecs.shape[0], 1, -1) # expand to (N, 1, E) 77 | # attention weighted sum of attn_vecs, (N, 1, E) 78 | class_feature, _ = self.class_encoder( 79 | query=class_embedding_q, 80 | key=attn_vecs, 81 | value=attn_vecs, 82 | key_padding_mask=~mask, # nn.MultiheadAttention ignores positions that are True, so pass the inverted valid-token mask 83 | ) # we only want the weighted value 84 | class_feature = class_feature[:, 0, :] # (N, E) 85 | 86 | return {'feature': attn_vecs, 'logits': logits, 87 | 'pt_lengths': pt_lengths, 'class_feature' : class_feature, 88 | 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision',} 89 | -------------------------------------------------------------------------------- /src/abinet/modules/module_util.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | from typing import Any 4 | def ifnone(a:Any,b:Any)->Any: 5 | "`a` if `a` is not None, otherwise `b`." 
6 | return b if a is None else a 7 | 8 | def postprocess(output, charset, model_eval): 9 | def _get_output(last_output, model_eval): 10 | if isinstance(last_output, (tuple, list)): 11 | for res in last_output: 12 | if res['name'] == model_eval: 13 | output = res 14 | else: 15 | output = last_output 16 | return output 17 | 18 | def _decode(logit): 19 | """ Greed decode """ 20 | out = F.softmax(logit, dim=2) 21 | pt_text, pt_scores, pt_lengths = [], [], [] 22 | for o in out: 23 | text = charset.get_text(o.argmax(dim=1), padding=False, trim=False) 24 | text = text.split(charset.null_char)[0] # end at end-token 25 | pt_text.append(text) 26 | pt_scores.append(o.max(dim=1)[0]) 27 | pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token 28 | return pt_text, pt_scores, pt_lengths 29 | 30 | output = _get_output(output, model_eval) 31 | logits, pt_lengths = output['logits'], output['pt_lengths'] 32 | pt_text, pt_scores, pt_lengths_ = _decode(logits) 33 | return pt_text, pt_scores, pt_lengths_ 34 | -------------------------------------------------------------------------------- /src/abinet/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | def conv1x1(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv1x1(inplanes, planes) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes, stride) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class ResNet(nn.Module): 51 | 52 | def __init__(self, block, layers): 53 | self.inplanes = 32 54 | super(ResNet, self).__init__() 55 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, 56 | bias=False) 57 | self.bn1 = nn.BatchNorm2d(32) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 61 | self.layer2 = self._make_layer(block, 64, layers[1], stride=1) 62 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 63 | self.layer4 = self._make_layer(block, 256, layers[3], stride=1) 64 | self.layer5 = self._make_layer(block, 512, layers[4], stride=1) 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | def _make_layer(self, block, planes, blocks, stride=1): 75 | downsample = None 76 | if stride != 1 or self.inplanes != planes * block.expansion: 77 | downsample = nn.Sequential( 78 | nn.Conv2d(self.inplanes, planes * block.expansion, 79 | kernel_size=1, stride=stride, bias=False), 80 | nn.BatchNorm2d(planes * block.expansion), 81 | ) 82 | 83 | layers = [] 84 | layers.append(block(self.inplanes, planes, stride, downsample)) 85 | self.inplanes = planes * block.expansion 86 | for i in range(1, blocks): 87 | layers.append(block(self.inplanes, planes)) 88 | 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | x = self.conv1(x) 93 | x = self.bn1(x) 94 | x = self.relu(x) 95 | x = self.layer1(x) 96 | x = self.layer2(x) 97 | x = self.layer3(x) 98 | x = self.layer4(x) 99 | x = self.layer5(x) 100 | return x 101 | 102 | 103 | def resnet45(): 104 | return ResNet(BasicBlock, [3, 4, 6, 6, 3]) 105 | -------------------------------------------------------------------------------- /src/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import torch 3 | import numpy as np 4 | from typing import List 5 | from torch.utils.data import default_collate 6 | 7 | 8 | def normalize_image(image: PIL.Image): 9 | size = image.size[0] 10 | image = np.array(image.convert("RGB"), dtype=np.float32) 11 | image = image.transpose(2, 0, 1) 12 | image = image / 127.5 - 1.0 13 | return image 14 | 15 | 16 | def prepare_npy_image_mask(image: PIL.Image, mask): 17 | size = image.size[0] 18 | image = np.array(image.convert("RGB"), dtype=np.float32) 19 | image = image.transpose(2, 0, 1) 20 | image = image / 127.5 - 1.0 21 | mask = np.array(mask.convert("L")) 22 | mask = mask.astype(np.float32) / 255.0 23 | mask = mask[None] 24 | mask[mask < 0.5] = 0 25 | mask[mask >= 0.5] = 1 26 | masked_image = image * (mask < 0.5) 27 | 28 | nonzeros = mask[0].nonzero() # (2, N) 29 | minx, maxx = min(nonzeros[0], default=0), max(nonzeros[0], default=size) 30 | miny, maxy = min(nonzeros[1], default=0), max(nonzeros[1], default=size) 31 | mask_coordinate = np.array((minx, maxx, miny, maxy), dtype=np.int16) 32 | return image, mask, masked_image, mask_coordinate 33 | 34 | 35 | def char_inpaint_collate_fn(features): 36 | """this collate function concate list/set into a list instead of merging into a tensor""" 37 | feature_keys = features[0].keys() 38 | collated = {k: [] for k in feature_keys} 39 | for feature in features: 40 | for k, v in feature.items(): 41 | collated[k].append(v) 42 | for k, v in collated.items(): 43 | if not isinstance(v[0], list) and not isinstance(v[0], set): 44 | collated[k] = default_collate(v) 45 | return collated 46 | 47 | 48 | class LenCounter: 49 | def __init__(self, min_len=1, max_len=15, eachnum=10, inf=False) -> None: 50 | self.bucket = {k: eachnum for k in range(min_len, max_len + 1)} 51 | self.inf = inf 52 | 53 | def ended(self): 54 | if self.inf: 55 | return False 56 | else: 57 | return sum(list(self.bucket.values())) == 0 58 | 59 | def __call__(self, label_str): 60 | if self.inf: 61 | return True 62 | else: 63 | propose_len = len(label_str) 64 | if propose_len not in self.bucket or self.bucket[propose_len] == 0: 65 | return False # not adding anything 66 | self.bucket[propose_len] -= 1 # adding one to this bucket 67 | return True 68 | 69 | 70 | def sample_random_angle( 71 | cat_prob: List, 72 | angle_list: List, 73 | 
rotate_range: int, 74 | generator: torch.Generator = None, 75 | ): 76 | """Return a random angle according to the probability distribution of each category. 77 | 78 | Args: 79 | cat_prob (List): 3-element list, the probability of each category for stay/rotate in angle_list/rotate in random angle 80 | angle_list (List): possible angles for category 1 81 | rotate_range (int): maximum possible angle for category 2 82 | generator (torch.Generator, optional): let the function be deterministic. Defaults to None. 83 | """ 84 | assert len(cat_prob) == 3 85 | # sample category 86 | cat_sample = torch.rand(size=(), generator=generator).item() 87 | if cat_sample < cat_prob[0]: 88 | # no rotate 89 | angle = 0 90 | elif cat_sample < (cat_prob[0] + cat_prob[1]): 91 | # rotate in angle_list 92 | angle_list = list(angle_list) 93 | angle = angle_list[ 94 | torch.randint(0, len(angle_list), size=(), generator=generator) 95 | ] 96 | else: 97 | # rotate in random angle 98 | angle = torch.randint( 99 | -rotate_range, rotate_range, size=(), generator=generator 100 | ).item() 101 | return angle 102 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import MaskMSELoss, convert_fourchannel_unet, convert_single_cond_unet 2 | from .charencoder import * 3 | from .unet_2d_multicondition import UNet2DMultiConditionModel 4 | -------------------------------------------------------------------------------- /src/model/utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import torch 3 | from torch import Tensor 4 | import torch.nn as nn 5 | from torch.nn import Module 6 | import torch.nn.functional as F 7 | 8 | 9 | class MaskMSELoss(Module): 10 | def __init__(self, alpha=1, reduction="mean"): 11 | super().__init__() 12 | self.alpha = alpha 13 | self.reduction = reduction 14 | 15 | def forward(self, input: Tensor, target: Tensor, mask: Tensor): 16 | mask_loss = F.mse_loss( 17 | input[mask == 1], target[mask == 1], reduction="sum") 18 | non_mask_loss = F.mse_loss( 19 | input[mask == 0], target[mask == 0], reduction="sum") 20 | return (self.alpha * mask_loss + non_mask_loss) / torch.numel(mask) 21 | 22 | 23 | @torch.no_grad() 24 | def convert_fourchannel_unet(oldunet, newunet): 25 | # 1. replace conv_in weight since they have different in_channels 26 | old_conv_in = oldunet.conv_in 27 | oldunet.conv_in = newunet.conv_in 28 | # 2. put old conv_in weight back 29 | oldunet.conv_in.weight[:, :4, :, :] = old_conv_in.weight 30 | oldunet.conv_in.bias = old_conv_in.bias 31 | # 3. copy other weights to newunet 32 | convert_single_cond_unet(oldunet, newunet) 33 | 34 | 35 | @torch.no_grad() 36 | def convert_single_cond_unet(oldunet, newunet): 37 | if type(newunet.config.cross_attention_dim) == dict: 38 | #! convert old unet to MultiCondition2DUnet 39 | # 1. replace resnets and other modules not cross-attention related 40 | newunet.load_state_dict(oldunet.state_dict(), strict=False) 41 | # 2. 
replace old text cross-attention modules into new model attentions['text'] module 42 | if "text" in newunet.config.cross_attention_dim: 43 | for i in range(len(oldunet.down_blocks)): 44 | if hasattr(oldunet.down_blocks[i], "attentions"): 45 | newunet.down_blocks[i].attentions["text"].load_state_dict( 46 | oldunet.down_blocks[i].attentions.state_dict() 47 | ) 48 | else: 49 | newunet.down_blocks[i].load_state_dict( 50 | oldunet.down_blocks[i].state_dict() 51 | ) 52 | newunet.mid_block.attentions["text"].load_state_dict( 53 | oldunet.mid_block.attentions.state_dict() 54 | ) 55 | for i in range(len(oldunet.up_blocks)): 56 | if hasattr(oldunet.up_blocks[i], "attentions"): 57 | newunet.up_blocks[i].attentions["text"].load_state_dict( 58 | oldunet.up_blocks[i].attentions.state_dict() 59 | ) 60 | else: 61 | newunet.up_blocks[i].load_state_dict( 62 | oldunet.up_blocks[i].state_dict()) 63 | else: 64 | #! convert old unet to 2DConditionUNet 65 | newunet.load_state_dict(oldunet.state_dict(), strict=True) 66 | # 3. we don't replace attentions['char'] module because it is not present in old model 67 | -------------------------------------------------------------------------------- /src/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .datawrapper import WrappedDataModule 3 | from .vae_trainer import VAETrainer, VAEImageLogger 4 | from .inpaint_trainer import ( 5 | CharInpaintTrainer, 6 | CharInpaintImageLogger, 7 | CharInpaintModelWrapper, 8 | ) 9 | from .callbacks import * 10 | -------------------------------------------------------------------------------- /src/trainers/datawrapper.py: -------------------------------------------------------------------------------- 1 | from .utils import get_obj_from_str 2 | from torch.utils.data import DataLoader 3 | from ..dataset import char_inpaint_collate_fn, CharInpaintDataset 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.distributed as torchdist 7 | 8 | 9 | class WrappedDataModule(pl.LightningDataModule): 10 | def __init__(self, data_config, **kwargs): 11 | super().__init__() 12 | self.save_hyperparameters() 13 | self.config = data_config 14 | self.batch_size = data_config.batch_size 15 | 16 | def setup(self, stage: str): 17 | if stage == "fit": 18 | self.train = CharInpaintDataset(self.config.train) 19 | self.val = CharInpaintDataset(self.config.validation) 20 | if stage == "test" or stage == "predict": 21 | self.val = CharInpaintDataset(self.config.test) 22 | 23 | def train_dataloader(self): 24 | return DataLoader( 25 | self.train, 26 | batch_size=self.batch_size, 27 | shuffle=True, 28 | num_workers=4, 29 | collate_fn=char_inpaint_collate_fn, 30 | ) 31 | 32 | def val_dataloader(self): 33 | return DataLoader( 34 | self.val, 35 | batch_size=self.batch_size, 36 | shuffle=True, 37 | num_workers=4, 38 | collate_fn=char_inpaint_collate_fn, 39 | ) 40 | -------------------------------------------------------------------------------- /src/trainers/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import importlib 5 | import torch.nn as nn 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger 8 | from pytorch_lightning import Trainer, Callback 9 | from pytorch_lightning.utilities import rank_zero_only 10 | 11 | 12 | def count_params(model): 13 | total_params = sum(p.numel() for p in model.parameters()) 14 | 
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 15 | 16 | 17 | def get_obj_from_str(string, reload=False): 18 | module, cls = string.rsplit(".", 1) 19 | if reload: 20 | module_imp = importlib.import_module(module) 21 | importlib.reload(module_imp) 22 | return getattr(importlib.import_module(module, package=None), cls) 23 | 24 | 25 | def instantiate_from_config(config): 26 | if not "target" in config: 27 | raise KeyError("Expected key `target` to instantiate.") 28 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 29 | 30 | 31 | def nondefault_trainer_args(opt): 32 | parser = argparse.ArgumentParser() 33 | parser = Trainer.add_argparse_args(parser) 34 | args = parser.parse_args([]) 35 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 36 | 37 | 38 | def module_requires_grad(module: nn.Module): 39 | for param in module.parameters(): 40 | if not param.requires_grad: 41 | return False 42 | return True 43 | 44 | @rank_zero_only 45 | def pl_on_train_tart(pl_module: pl.LightningModule): 46 | wandb_logger = pl_module.logger.experiment 47 | if isinstance(wandb_logger, WandbLogger): 48 | print("Logging code") 49 | wandb_logger.log_code( 50 | os.getcwd(), 51 | include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb") or path.endswith(".yaml") 52 | ) 53 | elif isinstance(wandb_logger, TensorBoardLogger): 54 | print("Logging git info") 55 | wandb_logger.log_hyperparams({"git_version": os.popen("git log -1").read().split("\n")[0]}) 56 | 57 | print("***** Start training *****") 58 | num_samples = len(pl_module.trainer.train_dataloader.dataset) 59 | max_epoch = pl_module.trainer.max_epochs 60 | total_step = pl_module.trainer.estimated_stepping_batches 61 | total_batch_size = round(num_samples * max_epoch / total_step) 62 | print(f" Num examples = {num_samples}") 63 | print(f" Num Epochs = {max_epoch}") 64 | print(f" Total GPU device number: {pl_module.trainer.num_devices}") 65 | print(f" Gradient Accumulation steps = {pl_module.trainer.accumulate_grad_batches}") 66 | #? this seems to be not right 67 | print(f" Instant batch size: {round(total_batch_size * pl_module.trainer.num_devices)}") 68 | print(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 69 | print(f" Total optimization steps = {total_step}") 70 | -------------------------------------------------------------------------------- /src/trainers/vae_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Any, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.optim import Optimizer 7 | from torchvision.utils import make_grid 8 | from einops import rearrange 9 | 10 | import pytorch_lightning as pl 11 | from pytorch_lightning import Callback 12 | from pytorch_lightning.utilities import rank_zero_only 13 | from pytorch_lightning.utilities.types import STEP_OUTPUT 14 | 15 | from .utils import count_params, pl_on_train_tart 16 | from diffusers import AutoencoderKL 17 | 18 | 19 | class VAETrainer(pl.LightningModule): 20 | def __init__(self, config): 21 | super().__init__() 22 | self.save_hyperparameters() 23 | self.config = config 24 | pretrained_model_path = config.pretrained_model_path 25 | self.vae = AutoencoderKL.from_pretrained( 26 | pretrained_model_path, subfolder="vae") 27 | count_params(self.vae) 28 | if config.precision == 16: 29 | self.data_dtype = torch.float16 30 | elif config.precision == 32: 31 | self.data_dtype = torch.float32 32 | self.loss_fn = nn.MSELoss() 33 | 34 | @rank_zero_only 35 | def on_train_start(self): 36 | pl_on_train_tart(self) 37 | 38 | def forward(self, batch, mode="train"): 39 | image = batch["image"].to(self.vae.dtype) 40 | latents = self.vae.encode(image).latent_dist.sample() 41 | latents = latents * 0.18215 42 | latents = 1.0 / 0.18215 * latents # introduce analytical error 43 | decoded = self.vae.decode(latents).sample 44 | loss = self.loss_fn(image, decoded) 45 | if mode == "train": 46 | return loss 47 | elif mode == "validation": 48 | return loss, decoded 49 | 50 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 51 | loss, decoded = self(batch, mode="validation") 52 | self.log( 53 | "train/loss", 54 | loss, 55 | batch_size=len(batch["image"]), 56 | prog_bar=True, 57 | sync_dist=True, 58 | ) 59 | return {"loss": loss, "image": batch["image"], "decoded": decoded} 60 | 61 | def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]: 62 | loss, decoded = self(batch, mode="validation") 63 | self.log( 64 | "validation/loss", 65 | loss, 66 | batch_size=len(batch["image"]), 67 | prog_bar=True, 68 | sync_dist=True, 69 | ) 70 | return {"loss": loss, "image": batch["image"], "decoded": decoded} 71 | 72 | def configure_optimizers(self) -> Any: 73 | lr = self.learning_rate 74 | params = [{"params": self.vae.parameters()}] 75 | print( 76 | f"Initialize optimizer with: lr: {lr}, weight_decay: {self.config.weight_decay}, eps: {self.config.adam_epsilon}" 77 | ) 78 | opt = torch.optim.AdamW( 79 | params, 80 | lr=lr, 81 | weight_decay=self.config.weight_decay, 82 | eps=self.config.adam_epsilon, 83 | ) 84 | return opt 85 | 86 | def configure_gradient_clipping( 87 | self, 88 | optimizer: Optimizer, 89 | optimizer_idx: int, 90 | gradient_clip_val: Optional[Union[int, float]] = None, 91 | gradient_clip_algorithm: Optional[str] = None, 92 | ) -> None: 93 | self.clip_gradients( 94 | optimizer, 95 | gradient_clip_val=gradient_clip_val, 96 | gradient_clip_algorithm=gradient_clip_algorithm, 97 | ) 98 | 99 | 100 | class VAEImageLogger(Callback): 101 | def __init__(self, train_batch_frequency, val_batch_frequency): 102 | super().__init__() 103 | self.batch_freq = { 104 | "train": 
train_batch_frequency, 105 | "validation": val_batch_frequency, 106 | } 107 | 108 | @rank_zero_only 109 | def _wandb_image(self, pl_module, results, batch_idx, split): 110 | print(f"Log images to wandb at: {split}/{batch_idx}") 111 | raw_image = results["image"] 112 | reconstructed_image = results["decoded"] 113 | batch_size = raw_image.shape[0] 114 | with torch.no_grad(): 115 | # 2 x B x 3 x H x W 116 | grids = torch.stack([raw_image, reconstructed_image]) 117 | grids = torch.clamp((grids + 1.0) / 2.0, min=0.0, max=1.0) 118 | grids = rearrange(grids, "g b c h w -> c (g h) (b w)", g=2) 119 | # 4 pairs in a group 120 | num_groups = batch_size // 4 # do not shadow the train/validation split name 121 | groups = torch.tensor_split(grids, num_groups, dim=2) 122 | 123 | def reshape_for_grid(group): return rearrange( 124 | group, "c (g h) (b w) -> (g b) c h w", g=2, b=4 125 | ) 126 | groups = [make_grid(reshape_for_grid(group), nrow=4) 127 | for group in groups] 128 | pl_module.logger.log_image( 129 | key=f"image-{split}/{batch_idx}", 130 | images=groups, 131 | step=batch_idx, 132 | ) 133 | 134 | def check_freq(self, batch_idx, split="train"): 135 | return (batch_idx + 1) % self.batch_freq[split] == 0 136 | 137 | def on_train_batch_end( 138 | self, 139 | trainer: "pl.Trainer", 140 | pl_module: "pl.LightningModule", 141 | outputs: STEP_OUTPUT, 142 | batch: Any, 143 | batch_idx: int, 144 | ) -> None: 145 | split = "train" 146 | if self.check_freq(batch_idx, split=split): 147 | self._wandb_image(pl_module, outputs, batch_idx, split=split) 148 | 149 | def on_validation_batch_end( 150 | self, 151 | trainer: "pl.Trainer", 152 | pl_module: "pl.LightningModule", 153 | outputs: Optional[STEP_OUTPUT], 154 | batch: Any, 155 | batch_idx: int, 156 | dataloader_idx: int, 157 | ) -> None: 158 | split = "validation" 159 | if self.check_freq(batch_idx, split=split): 160 | self._wandb_image(pl_module, outputs, batch_idx, split=split) 161 | -------------------------------------------------------------------------------- /synthgenerator/README.md: -------------------------------------------------------------------------------- 1 | # Synthetic Data Generator 2 | 3 | This is the synthetic data generator based on [SynthTiger](https://github.com/clovaai/synthtiger). 4 | 5 | Install `synthtiger`: 6 | ```bash 7 | pip install synthtiger 8 | ``` 9 | 10 | Generate synthetic data by: 11 | ```bash 12 | synthtiger -o $outdir -w 8 synth_template.py SynthForCharDiffusion $config_file --count $max_num_of_samples 13 | ``` 14 | 15 | NOTICE: Please download background images from https://github.com/ankush-me/SynthText and put them in `ocr-dataset/SynthText/bg_data`, and download the fonts from Google Fonts and put them in `synthgenerator/resources/100fonts`. 
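
For reference, here is a minimal sketch of the directory layout that the default paths in `synthgen_config.yaml` (`RESOURCE_DIR` and the `texture` paths) expect, assuming the generation command is run from inside `synthgenerator/`; if your layout differs, adjust those entries in the config:

```
<repo root>/
├── synthgenerator/
│   ├── resources/
│   │   ├── 100fonts/      # downloaded font files + chosen_fonts.json
│   │   ├── charset/
│   │   ├── colormap/
│   │   └── corpus/
│   └── synthgen_config.yaml
└── ocr-dataset/
    └── SynthText/bg_data/bg_img/   # SynthText background images
```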
-------------------------------------------------------------------------------- /synthgenerator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCSB-NLP-Chang/DiffSTE/b6421cd491e8d22c4b35124d91193573f5b117f8/synthgenerator/__init__.py -------------------------------------------------------------------------------- /synthgenerator/generate_synth.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process, Queue 2 | from omegaconf import OmegaConf 3 | from synthtiger.gen import _task_generator, _run, _generate 4 | from synthtiger.main import parse_args 5 | import time 6 | import pprint 7 | import importlib 8 | import itertools 9 | import synthtiger 10 | 11 | 12 | def read_template(path, name, config): 13 | template_cls = get_obj(path + "." + name, reload=True) 14 | template = template_cls(config=config) 15 | return template 16 | 17 | 18 | synthtiger.read_template = read_template 19 | synthtiger.gen.read_template = read_template 20 | 21 | 22 | def get_obj(string, reload=False): 23 | module, cls = string.rsplit(".", 1) 24 | if reload: 25 | module_imp = importlib.import_module(module) 26 | importlib.reload(module_imp) 27 | return getattr(importlib.import_module(module, package=None), cls) 28 | 29 | 30 | def run(args): 31 | if args.config is not None: 32 | config = OmegaConf.load(args.config) 33 | OmegaConf.resolve(config) 34 | 35 | if args.seed is None: 36 | if 'seed' not in config: 37 | args.seed = 42 38 | else: 39 | args.seed = config['seed'] 40 | if 'seed' not in config: 41 | config['seed'] = args.seed 42 | 43 | pprint.pprint(config) 44 | synthtiger.set_global_random_seed(args.seed) 45 | template = read_template(args.script, args.name, config) 46 | generator = synthtiger.generator( 47 | args.script, 48 | args.name, 49 | config=config, 50 | count=args.count, 51 | worker=args.worker, 52 | seed=args.seed, 53 | retry=True, 54 | verbose=args.verbose, 55 | ) 56 | 57 | if args.output is not None: 58 | template.init_save(args.output) 59 | 60 | from tqdm import tqdm 61 | for idx, (task_idx, data) in tqdm(enumerate(generator)): 62 | if args.output is not None: 63 | template.save(args.output, data, task_idx) 64 | # print(f"Generated {idx + 1} data (task {task_idx})") 65 | 66 | if args.output is not None: 67 | template.end_save(args.output) 68 | 69 | 70 | def main(): 71 | start_time = time.time() 72 | args = parse_args() 73 | run(args) 74 | end_time = time.time() 75 | print(f"{end_time - start_time:.2f} seconds elapsed") 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /synthgenerator/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.6.3 2 | numpy==1.24.1 3 | omegaconf==2.3.0 4 | Pillow==9.4.0 5 | synthtiger==1.2.1 6 | tqdm==4.64.1 7 | pandas==1.5.3 -------------------------------------------------------------------------------- /synthgenerator/resources/100fonts/chosen_fonts.json: -------------------------------------------------------------------------------- 1 | ["mina", "imfellenglish", "zillaslabhighlight", "scopeone", "nuosusil", "unna", "sairaextracondensed", "azeretmono", "shadowsintolighttwo", "moul", "imfelldoublepica", "armata", "frankruhllibre", "hennypenny", "belleza", "leaguegothic", "pacifico", "condiment", "ribeyemarrow", "arizonia", "notosanshanifirohingya", "battambang", "inconsolata", 
"kantumruypro", "breeserif", "sulphurpoint", "notorashihebrew", "macondo", "elmessiri", "zenkurenaido", "novaround", "redhatmono", "markazitext", "synetactile", "windsong", "sahitya", "londrinasketch", "zenmarugothic", "amethysta", "kellyslab", "mrdehaviland", "domine", "kavoon", "sora", "caladea", "palanquin", "overlock", "explora", "meowscript", "suwannaphum", "ramaraja", "sendflowers", "anekgurmukhi", "spectral", "cabin", "palanquindark", "manuale", "petitformalscript", "poppins", "galada", "notosanstagalog", "tiltneon", "lancelot", "notosansmiao", "convergence", "montserratsubrayada", "mplus1code", "daysone", "splinesansmono", "zeyada", "trirong", "pridi", "notoserifnphmong", "paprika", "pavanam", "exo2", "flowblock", "josefinsans", "imfellgreatprimer", "silkscreen", "bigshouldersinlinetext", "solway", "climatecrisis", "atkinsonhyperlegible", "sourcesanspro", "missfajardose", "redhatdisplay", "justmeagaindownhere", "edunswactfoundation", "loversquarrel", "mergeone", "ranchers", "shortstack", "fasterone", "imprima", "thasadith", "bilboswashcaps", "librebaskerville", "epilogue", "gidugu"] -------------------------------------------------------------------------------- /synthgenerator/resources/charset/alphanum.txt: -------------------------------------------------------------------------------- 1 | 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz -------------------------------------------------------------------------------- /synthgenerator/resources/charset/alphanum_lower.txt: -------------------------------------------------------------------------------- 1 | 0123456789abcdefghijklmnopqrstuvwxyz -------------------------------------------------------------------------------- /synthgenerator/resources/charset/alphanum_special.txt: -------------------------------------------------------------------------------- 1 | !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ -------------------------------------------------------------------------------- /synthgenerator/synthgen_config.yaml: -------------------------------------------------------------------------------- 1 | coord_output: true 2 | mask_output: true 3 | glyph_coord_output: true 4 | glyph_mask_output: true 5 | 6 | vertical: false 7 | 8 | BASE_DIR: ./ 9 | RESOURCE_DIR: ${BASE_DIR}/resources 10 | visibility_check: true 11 | 12 | level: 2 13 | 14 | corpus: 15 | paths: ["${RESOURCE_DIR}/corpus/wikicorpus.json"] 16 | weights: [1] 17 | min_length: 1 18 | max_length: 25 19 | textcase: [lower, upper, capitalize] 20 | 21 | word_sampler: 22 | max_num: 5 23 | max_len: 16 24 | corpus_conf: 25 | ${corpus} 26 | 27 | font: 28 | paths: ["${RESOURCE_DIR}/100fonts"] 29 | weights: [1] 30 | size: [30, 80] 31 | bold: 0.0 32 | mode: same 33 | 34 | named_color: 35 | mode: same 36 | pallate_name: XKCD 37 | 38 | color: 39 | rgb: [[0, 256], [0, 256], [0, 256]] 40 | alpha: [1, 1] 41 | grayscale: 0 42 | 43 | colormap2: 44 | paths: ["${RESOURCE_DIR}/colormap/iiit5k_gray.txt"] 45 | weights: [1] 46 | k: 2 47 | alpha: [1, 1] 48 | colorize: 1 49 | 50 | colormap3: 51 | paths: ["${RESOURCE_DIR}/colormap/iiit5k_gray.txt"] 52 | weights: [1] 53 | k: 3 54 | alpha: [1, 1] 55 | colorize: 1 56 | 57 | shape: 58 | prob: 0 59 | args: 60 | weights: [1, 1] 61 | args: 62 | # elastic distortion 63 | - alpha: [15, 30] 64 | sigma: [4, 12] 65 | # elastic distortion 66 | - alpha: [0, 2] 67 | sigma: [0, 0.6] 68 | 69 | style: 70 | prob: 0 71 | args: 72 | weights: [1, 0, 0] 73 | args: 74 | # text border 75 | - size: [1, 5] 76 | 
alpha: [1, 1] 77 | grayscale: 0 78 | # text shadow 79 | - distance: [2, 4] 80 | angle: [0, 360] 81 | alpha: [0.3, 0.7] 82 | grayscale: 0 83 | # text extrusion 84 | - length: [2, 5] 85 | angle: [0, 360] 86 | rgb: [[0, 0], [0, 0], [0, 0]] 87 | grayscale: 0 88 | 89 | pad: 90 | prob: 1 91 | args: 92 | pxs: [[2, 10], [2, 10], [2, 10], [2, 10]] 93 | 94 | texture: 95 | prob: 1. 96 | args: 97 | paths: ["${BASE_DIR}/../ocr-dataset/SynthText/bg_data/bg_img"] 98 | weights: [1] 99 | alpha: [0.7, 1] 100 | grayscale: 0 101 | crop: 1 102 | 103 | layout: 104 | space: [4, 8] 105 | line_space: [1, 5] 106 | align: [left, "center", "right"] 107 | line_align: [left] 108 | ltr: true 109 | ttb: true 110 | vertical: false 111 | 112 | text_layout: 113 | space: [5, 10] 114 | line_space: [3, 10] 115 | align: [left] 116 | line_align: [top, bottom, middle] 117 | ltr: true 118 | ttb: true 119 | vertical: true 120 | 121 | transform: 122 | prob: 0.8 123 | args: 124 | # weights: [1, 1, 1, 1, 1, 1, 1., 1.] 125 | weights: [0, 0, 0, 0, 0, 0, 1., 1.] 126 | args: 127 | # perspective x 128 | - percents: [[0.5, 1], [1, 1]] 129 | aligns: [[0, 0], [0, 0]] 130 | # perspective y 131 | - percents: [[1, 1], [0.5, 1]] 132 | aligns: [[0, 0], [0, 0]] 133 | # trapezoidate x # don't use 134 | - weights: [1, 0, 1, 0] 135 | percent: [0.8, 1] 136 | align: [-1, 1] 137 | # trapezoidate y # don't use 138 | - weights: [0, 1, 0, 1] 139 | percent: [0.5, 1] 140 | align: [-1, 1] 141 | # skew x 142 | - weights: [1, 0] 143 | angle: [0, 20] 144 | ccw: 0.5 145 | # skew y 146 | - weights: [0, 1] 147 | angle: [0, 20] 148 | ccw: 0.5 149 | # rotate 150 | - angle: [0, 30] 151 | ccw: 0.5 152 | - angle: [-90, -15, 15, 90] 153 | ccw: 0.5 154 | 155 | 156 | postprocess: 157 | args: 158 | # gaussian noise 159 | - prob: 0.5 160 | args: 161 | scale: [0, 0.5] 162 | per_channel: 0 163 | # gaussian blur 164 | - prob: 0.0 165 | args: 166 | sigma: [0, 1] 167 | - prob: 0.0 168 | args: 169 | k: [1, 1] -------------------------------------------------------------------------------- /synthgenerator/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | from synthtiger import utils 5 | 6 | 7 | def expand2square(pil_img, background_color): 8 | width, height = pil_img.size 9 | if width == height: 10 | return pil_img 11 | elif width > height: 12 | result = Image.new(pil_img.mode, (width, width), background_color) 13 | result.paste(pil_img, (0, (width - height) // 2)) 14 | return result 15 | else: 16 | result = Image.new(pil_img.mode, (height, height), background_color) 17 | result.paste(pil_img, ((height - width) // 2, 0)) 18 | return result 19 | 20 | 21 | def _create_poly_mask(image, pad=0): 22 | height, width = image.shape[:2] 23 | alpha = image[..., 3].astype(np.uint8) 24 | mask = np.zeros((height, width), dtype=np.float32) 25 | cts, _ = cv2.findContours( 26 | alpha, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 27 | cts = sorted(cts, key=lambda ct: sum(cv2.boundingRect(ct)[:2])) 28 | if len(cts) == 1: 29 | hull = cv2.convexHull(cts[0]) 30 | cv2.fillConvexPoly(mask, hull, 255) 31 | for idx in range(len(cts) - 1): 32 | pts = np.concatenate((cts[idx], cts[idx + 1]), axis=0) 33 | hull = cv2.convexHull(pts) 34 | cv2.fillConvexPoly(mask, hull, 255) 35 | mask = utils.dilate_image(mask, pad) 36 | out = utils.create_image((width, height)) 37 | out[..., 3] = mask 38 | return out 39 | 40 | 41 | BLEND_MODES = [ 42 | # "normal", 43 | # "hard_light", 44 | # "soft_light" 45 | # "overlay" 46 
| # "multiply" 47 | # "normal", 48 | # "overlay", 49 | # "screen", 50 | # "darken_only" 51 | # "lighten_only", 52 | 53 | "normal", 54 | # "overlay", "multiply", "screen", "overlay", "hard_light", "soft_light", 55 | # "dodge", "divide", "addition", "difference", "darken_only", "lighten_only", 56 | ] 57 | 58 | 59 | def _blend_images(src, dst, visibility_check=False): 60 | blend_modes = np.random.permutation(BLEND_MODES) 61 | # print(blend_modes) 62 | for blend_mode in blend_modes: 63 | out = utils.blend_image(src, dst, mode=blend_mode) 64 | if not visibility_check or _check_visibility(out, src[..., 3]): 65 | break 66 | else: 67 | raise RuntimeError("Text is not visible") 68 | return out 69 | 70 | 71 | def _check_visibility(image, mask): 72 | gray = utils.to_gray(image[..., :3]).astype(np.uint8) 73 | mask = mask.astype(np.uint8) 74 | height, width = mask.shape 75 | peak = (mask > 127).astype(np.uint8) 76 | kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) 77 | bound = (mask > 0).astype(np.uint8) 78 | bound = cv2.dilate(bound, kernel, iterations=1) 79 | visit = bound.copy() 80 | visit ^= 1 81 | visit = np.pad(visit, 1, constant_values=1) 82 | border = bound.copy() 83 | border[mask > 0] = 0 84 | flag = 4 | cv2.FLOODFILL_FIXED_RANGE | cv2.FLOODFILL_MASK_ONLY 85 | for y in range(height): 86 | for x in range(width): 87 | if peak[y][x]: 88 | cv2.floodFill(gray, visit, (x, y), 1, 16, 16, flag) 89 | visit = visit[1:-1, 1:-1] 90 | count = np.sum(visit & border) 91 | total = np.sum(border) 92 | return total > 0 and count <= total * 0.1 93 | -------------------------------------------------------------------------------- /tools/create_mask/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Mask Creating Tool 3 | 4 | This tool processes images for creating DiffSTE samples using the Doctr library. It detects text regions in an image, randomly crops these regions, and saves the cropped images along with their masks. 5 | 6 | ## Features 7 | 8 | - Allows configuration of various parameters via command-line arguments. 9 | - Well suited for documents. 10 | 11 | ## Requirements 12 | 13 | - Python (tested on 3.10) 14 | - OpenCV 15 | - NumPy 16 | - Matplotlib 17 | - Doctr 18 | 19 | ## Setup 20 | 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ## Usage 26 | 27 | Run the script with the required parameters: 28 | 29 | ```bash 30 | python main.py --input_image /path/to/image.jpg --output_dir /path/to/output/dir 31 | ``` 32 | 33 | You can also specify additional parameters as needed: 34 | 35 | ```bash 36 | python main.py --input_image /path/to/image.jpg --output_dir /path/to/output/dir --arch db_resnet50 --batch_size 4 --bin_thresh 0.5 --device cuda 37 | ``` 38 | 39 | ### Command-Line Arguments 40 | 41 | - `--input_image`: Path to the input image (required). 42 | - `--output_dir`: Directory to save the output images (required). 43 | - `--arch`: Model architecture (default: `db_resnet50`). 44 | - `--pretrained`: Use pretrained model (default: `True`). 45 | - `--pretrained_backbone`: Use pretrained backbone (default: `True`). 46 | - `--batch_size`: Batch size for the model (default: `2`). 47 | - `--assume_straight_pages`: Assume straight pages (default: `False`). 48 | - `--preserve_aspect_ratio`: Preserve aspect ratio (default: `False`). 49 | - `--symmetric_pad`: Use symmetric padding (default: `True`). 50 | - `--bin_thresh`: Binarization threshold for postprocessor (default: `0.3`). 
51 | - `--device`: Device to use for computation, e.g., "cuda" or "cpu" (default: `cuda`). 52 | 53 | ## Acknowledgements 54 | 55 | This tool uses the [Doctr](https://github.com/mindee/doctr) project for document text detection. 56 | 57 | ### Citation 58 | 59 | ```bibtex 60 | @misc{doctr2021, 61 | title={docTR: Document Text Recognition}, 62 | author={Mindee}, 63 | year={2021}, 64 | publisher = {GitHub}, 65 | howpublished = {\url{https://github.com/mindee/doctr}} 66 | } 67 | ``` 68 | 69 | ## Contributing 70 | 71 | Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes. 72 | -------------------------------------------------------------------------------- /tools/create_mask/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import cv2 5 | import numpy as np 6 | from matplotlib import pyplot as plt 7 | 8 | from doctr.io import DocumentFile # type: ignore 9 | from doctr.models import detection_predictor # type: ignore 10 | 11 | 12 | def plot(image, si=[12, 12]): 13 | fig, ax = plt.subplots(figsize=si) 14 | ax.imshow(image, cmap='gray') 15 | ax.get_xaxis().set_visible(False) 16 | ax.get_yaxis().set_visible(False) 17 | plt.show() 18 | 19 | 20 | def crop(img, cxcy, crop_size=256): 21 | H, W, _ = img.shape 22 | cx, cy = cxcy 23 | half_crop_size = crop_size // 2 24 | 25 | # Calculate the top-left and bottom-right corners of the crop 26 | x1 = max(cx - half_crop_size, 0) 27 | y1 = max(cy - half_crop_size, 0) 28 | x2 = min(cx + half_crop_size, W) 29 | y2 = min(cy + half_crop_size, H) 30 | 31 | # Adjust the corners if the crop size is smaller than 255x255 pixels 32 | if x2 - x1 < crop_size: 33 | if x1 == 0: 34 | x2 = min(crop_size, W) 35 | else: 36 | x1 = max(x2 - crop_size, 0) 37 | 38 | if y2 - y1 < crop_size: 39 | if y1 == 0: 40 | y2 = min(crop_size, H) 41 | else: 42 | y1 = max(y2 - crop_size, 0) 43 | 44 | # Crop the image 45 | cropped_image = img[y1:y2, x1:x2] 46 | 47 | # If the crop is smaller than 255x255, pad it 48 | if cropped_image.shape[0] < crop_size or cropped_image.shape[1] < crop_size: 49 | padded_image = np.zeros((crop_size, crop_size, 3), dtype=np.uint8) 50 | padded_image[:cropped_image.shape[0], :cropped_image.shape[1]] = cropped_image 51 | cropped_image = padded_image 52 | 53 | return cropped_image 54 | 55 | 56 | def main(args): 57 | 58 | det_predictor = detection_predictor( 59 | arch=args.arch, 60 | pretrained=True, 61 | pretrained_backbone=True, 62 | batch_size=1, 63 | assume_straight_pages=False, 64 | preserve_aspect_ratio=False, 65 | symmetric_pad=True 66 | ) 67 | 68 | if args.device == "cuda": 69 | det_predictor = det_predictor.cuda().half() 70 | 71 | det_predictor.model.postprocessor.bin_thresh = args.bin_thresh 72 | 73 | input = DocumentFile.from_images(args.input_image) 74 | img = cv2.imread(args.input_image) 75 | H, W, _ = img.shape 76 | print(f"{H}, {W}") 77 | 78 | bboxes = det_predictor(input, return_maps=False)[0]['words'] 79 | bboxes[:, :, 0] *= W 80 | bboxes[:, :, 1] *= H 81 | bboxes = bboxes.astype(np.int32) 82 | 83 | idx = np.random.randint(0, len(bboxes), size=1).item() 84 | bbox = bboxes[idx] 85 | cxcy = np.mean(bbox, axis=0, dtype=np.int32) 86 | 87 | mask = np.zeros_like(img, dtype=np.uint8) 88 | cv2.fillConvexPoly(mask, bbox, [255, 255, 255]) 89 | 90 | imgc = crop(img, cxcy) 91 | maskc = crop(mask, cxcy) 92 | 93 | plot(mask) 94 | 95 | cv2.imwrite(os.path.join(args.output_dir, "sample.png"), imgc) 96 | 
cv2.imwrite(os.path.join(args.output_dir, "mask.png"), maskc) 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser(description='Process images for OCR using Doctr.') 101 | parser.add_argument('--input_image', type=str, required=True, help='Path to the input image') 102 | parser.add_argument('--output_dir', type=str, required=True, help='Directory to save the output images') 103 | parser.add_argument('--arch', type=str, default='db_resnet50', help='Model architecture') 104 | parser.add_argument('--batch_size', type=int, default=2, help='Batch size for the model') 105 | parser.add_argument('--bin_thresh', type=float, default=0.3, help='Binarization threshold for postprocessor') 106 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for computation (e.g., "cuda" or "cpu")') 107 | 108 | args = parser.parse_args() 109 | main(args) 110 | -------------------------------------------------------------------------------- /tools/create_mask/requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.9.0.80 2 | numpy==1.26.4 3 | matplotlib==3.8.2 4 | python-doctr[torch] --------------------------------------------------------------------------------