├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── docs └── source │ └── imgs │ └── diffusers_library.jpg ├── examples ├── README.md ├── experimental │ ├── train_glide_text_to_image.py │ └── train_latent_text_to_image.py └── train_unconditional.py ├── pyproject.toml ├── run.py ├── scripts ├── conversion_bddm.py ├── conversion_glide.py └── conversion_ldm_uncond.py ├── setup.cfg ├── setup.py ├── src └── diffusers │ ├── __init__.py │ ├── configuration_utils.py │ ├── dependency_versions_check.py │ ├── dependency_versions_table.py │ ├── dynamic_modules_utils.py │ ├── hub_utils.py │ ├── modeling_utils.py │ ├── models │ ├── README.md │ ├── __init__.py │ ├── attention.py │ ├── embeddings.py │ ├── resnet.py │ ├── unet.py │ ├── unet_glide.py │ ├── unet_grad_tts.py │ ├── unet_ldm.py │ ├── unet_new.py │ ├── unet_rl.py │ ├── unet_sde_score_estimation.py │ ├── unet_unconditional.py │ └── vae.py │ ├── optimization.py │ ├── pipeline_utils.py │ ├── pipelines │ ├── README.md │ ├── __init__.py │ ├── bddm │ │ ├── __init__.py │ │ └── pipeline_bddm.py │ ├── ddim │ │ ├── __init__.py │ │ └── pipeline_ddim.py │ ├── ddpm │ │ ├── __init__.py │ │ └── pipeline_ddpm.py │ ├── glide │ │ ├── __init__.py │ │ └── pipeline_glide.py │ ├── grad_tts │ │ ├── __init__.py │ │ ├── grad_tts_utils.py │ │ └── pipeline_grad_tts.py │ ├── latent_diffusion │ │ ├── __init__.py │ │ └── pipeline_latent_diffusion.py │ ├── latent_diffusion_uncond │ │ ├── __init__.py │ │ └── pipeline_latent_diffusion_uncond.py │ ├── pndm │ │ ├── __init__.py │ │ └── pipeline_pndm.py │ ├── score_sde_ve │ │ ├── __init__.py │ │ └── pipeline_score_sde_ve.py │ └── score_sde_vp │ │ ├── __init__.py │ │ └── pipeline_score_sde_vp.py │ ├── schedulers │ ├── README.md │ ├── __init__.py │ ├── scheduling_ddim.py │ ├── scheduling_ddpm.py │ ├── scheduling_grad_tts.py │ ├── scheduling_pndm.py │ ├── scheduling_sde_ve.py │ ├── scheduling_sde_vp.py │ └── scheduling_utils.py │ ├── testing_utils.py │ ├── training_utils.py │ └── utils │ ├── __init__.py │ ├── dummy_transformers_and_inflect_and_unidecode_objects.py │ ├── dummy_transformers_objects.py │ ├── logging.py │ └── model_card_template.md ├── tests ├── __init__.py ├── test_layers_utils.py ├── test_modeling_utils.py └── test_scheduler.py └── utils ├── check_config_docstrings.py ├── check_copies.py ├── check_dummies.py ├── check_inits.py ├── check_repo.py ├── check_table.py ├── check_tf_ops.py └── custom_init_isort.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # tests and logs 12 | tests/fixtures/cached_*_text.txt 13 | logs/ 14 | lightning_logs/ 15 | lang_code_data/ 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | # vscode 125 | .vs 126 | .vscode 127 | 128 | # Pycharm 129 | .idea 130 | 131 | # TF code 132 | tensorflow_code 133 | 134 | # Models 135 | proc_data 136 | 137 | # examples 138 | runs 139 | /runs_old 140 | /wandb 141 | /examples/runs 142 | /examples/**/*.args 143 | /examples/rag/sweep 144 | 145 | # data 146 | /data 147 | serialization_dir 148 | 149 | # emacs 150 | *.*~ 151 | debug.env 152 | 153 | # vim 154 | .*.swp 155 | 156 | #ctags 157 | tags 158 | 159 | # pre-commit 160 | .pre-commit* 161 | 162 | # .lock 163 | *.lock 164 | 165 | # DS_Store (MacOS) 166 | .DS_Store -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 
4 | export PYTHONPATH = src 5 | 6 | check_dirs := examples tests src utils 7 | 8 | modified_only_fixup: 9 | $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) 10 | @if test -n "$(modified_py_files)"; then \ 11 | echo "Checking/fixing $(modified_py_files)"; \ 12 | black --preview $(modified_py_files); \ 13 | isort $(modified_py_files); \ 14 | flake8 $(modified_py_files); \ 15 | else \ 16 | echo "No library .py files were modified"; \ 17 | fi 18 | 19 | # Update src/diffusers/dependency_versions_table.py 20 | 21 | deps_table_update: 22 | @python setup.py deps_table_update 23 | 24 | deps_table_check_updated: 25 | @md5sum src/diffusers/dependency_versions_table.py > md5sum.saved 26 | @python setup.py deps_table_update 27 | @md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1) 28 | @rm md5sum.saved 29 | 30 | # autogenerating code 31 | 32 | autogenerate_code: deps_table_update 33 | 34 | # Check that the repo is in a good state 35 | 36 | repo-consistency: 37 | python utils/check_dummies.py 38 | python utils/check_repo.py 39 | python utils/check_inits.py 40 | 41 | # this target runs checks on all files 42 | 43 | quality: 44 | black --check --preview $(check_dirs) 45 | isort --check-only $(check_dirs) 46 | flake8 $(check_dirs) 47 | doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source 48 | 49 | # Format source code automatically and check if there are any problems left that need manual fixing 50 | 51 | extra_style_checks: 52 | python utils/custom_init_isort.py 53 | doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source 54 | 55 | # this target runs checks on all files and potentially modifies some of them 56 | 57 | style: 58 | black --preview $(check_dirs) 59 | isort $(check_dirs) 60 | ${MAKE} autogenerate_code 61 | ${MAKE} extra_style_checks 62 | 63 | # Super fast fix and check target that only works on relevant modified files since the branch was made 64 | 65 | fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency 66 | 67 | # Make marked copies of snippets of code conform to the original 68 | 69 | fix-copies: 70 | python utils/check_dummies.py --fix_and_overwrite 71 | 72 | # Run tests for the library 73 | 74 | test: 75 | python -m pytest -n auto --dist=loadfile -s -v ./tests/ 76 | 77 | # Run tests for examples 78 | 79 | test-examples: 80 | python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/ 81 | 82 | # Run tests for SageMaker DLC release 83 | 84 | test-sagemaker: # install sagemaker dependencies in advance with pip install .[sagemaker] 85 | TEST_SAGEMAKER=True python -m pytest -n auto -s -v ./tests/sagemaker 86 | 87 | 88 | # Release stuff 89 | 90 | pre-release: 91 | python utils/release.py 92 | 93 | pre-patch: 94 | python utils/release.py --patch 95 | 96 | post-release: 97 | python utils/release.py --post_release 98 | 99 | post-patch: 100 | python utils/release.py --post_release --patch 101 | -------------------------------------------------------------------------------- /docs/source/imgs/diffusers_library.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/diffusers_all/c8c0c0e846c8afc07602c44180278a2f7f15331d/docs/source/imgs/diffusers_library.jpg -------------------------------------------------------------------------------- /examples/README.md:
-------------------------------------------------------------------------------- 1 | ## Training examples 2 | 3 | ### Unconditional Flowers 4 | 5 | The command to train a DDPM UNet model on the Oxford Flowers dataset: 6 | 7 | ```bash 8 | python -m torch.distributed.launch \ 9 | --nproc_per_node 4 \ 10 | train_unconditional.py \ 11 | --dataset="huggan/flowers-102-categories" \ 12 | --resolution=64 \ 13 | --output_dir="flowers-ddpm" \ 14 | --batch_size=16 \ 15 | --num_epochs=100 \ 16 | --gradient_accumulation_steps=1 \ 17 | --lr=1e-4 \ 18 | --warmup_steps=500 \ 19 | --mixed_precision=no 20 | ``` 21 | 22 | A full training run takes 2 hours on 4xV100 GPUs. 23 | 24 | 25 | 26 | 27 | ### Unconditional Pokemon 28 | 29 | The command to train a DDPM UNet model on the Pokemon dataset: 30 | 31 | ```bash 32 | python -m torch.distributed.launch \ 33 | --nproc_per_node 4 \ 34 | train_unconditional.py \ 35 | --dataset="huggan/pokemon" \ 36 | --resolution=64 \ 37 | --output_dir="pokemon-ddpm" \ 38 | --batch_size=16 \ 39 | --num_epochs=100 \ 40 | --gradient_accumulation_steps=1 \ 41 | --lr=1e-4 \ 42 | --warmup_steps=500 \ 43 | --mixed_precision=no 44 | ``` 45 | 46 | A full training run takes 2 hours on 4xV100 GPUs. 47 | 48 | 49 | -------------------------------------------------------------------------------- /examples/experimental/train_glide_text_to_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import bitsandbytes as bnb 8 | import PIL.Image 9 | from accelerate import Accelerator 10 | from datasets import load_dataset 11 | from diffusers import DDPMScheduler, Glide, GlideUNetModel 12 | from diffusers.hub_utils import init_git_repo, push_to_hub 13 | from diffusers.optimization import get_scheduler 14 | from diffusers.utils import logging 15 | from torchvision.transforms import ( 16 | CenterCrop, 17 | Compose, 18 | InterpolationMode, 19 | Normalize, 20 | RandomHorizontalFlip, 21 | Resize, 22 | ToTensor, 23 | ) 24 | from tqdm.auto import tqdm 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def main(args): 31 | accelerator = Accelerator(mixed_precision=args.mixed_precision) 32 | 33 | pipeline = Glide.from_pretrained("fusing/glide-base") 34 | model = pipeline.text_unet 35 | noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") 36 | optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr) 37 | 38 | augmentations = Compose( 39 | [ 40 | Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), 41 | CenterCrop(args.resolution), 42 | RandomHorizontalFlip(), 43 | ToTensor(), 44 | Normalize([0.5], [0.5]), 45 | ] 46 | ) 47 | dataset = load_dataset(args.dataset, split="train") 48 | 49 | text_encoder = pipeline.text_encoder.eval() 50 | 51 | def transforms(examples): 52 | images = [augmentations(image.convert("RGB")) for image in examples["image"]] 53 | text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt") 54 | text_inputs = text_inputs.input_ids.to(accelerator.device) 55 | with torch.no_grad(): 56 | text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state 57 | return {"images": images, "text_embeddings": text_embeddings} 58 | 59 | dataset.set_transform(transforms) 60 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 61 | 62 | lr_scheduler = get_scheduler( 63 | "linear", 64 | 
optimizer=optimizer, 65 | num_warmup_steps=args.warmup_steps, 66 | num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, 67 | ) 68 | 69 | model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 70 | model, text_encoder, optimizer, train_dataloader, lr_scheduler 71 | ) 72 | 73 | if args.push_to_hub: 74 | repo = init_git_repo(args, at_init=True) 75 | 76 | # Train! 77 | is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() 78 | world_size = torch.distributed.get_world_size() if is_distributed else 1 79 | total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size 80 | max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs 81 | logger.info("***** Running training *****") 82 | logger.info(f" Num examples = {len(train_dataloader.dataset)}") 83 | logger.info(f" Num Epochs = {args.num_epochs}") 84 | logger.info(f" Instantaneous batch size per device = {args.batch_size}") 85 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") 86 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 87 | logger.info(f" Total optimization steps = {max_steps}") 88 | 89 | for epoch in range(args.num_epochs): 90 | model.train() 91 | with tqdm(total=len(train_dataloader), unit="ba") as pbar: 92 | pbar.set_description(f"Epoch {epoch}") 93 | for step, batch in enumerate(train_dataloader): 94 | clean_images = batch["images"] 95 | batch_size, n_channels, height, width = clean_images.shape 96 | noise_samples = torch.randn(clean_images.shape).to(clean_images.device) 97 | timesteps = torch.randint( 98 | 0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device 99 | ).long() 100 | 101 | # add noise onto the clean images according to the noise magnitude at each timestep 102 | # (this is the forward diffusion process) 103 | noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) 104 | 105 | if step % args.gradient_accumulation_steps != 0: 106 | with accelerator.no_sync(model): 107 | model_output = model(noisy_images, timesteps, batch["text_embeddings"]) 108 | model_output, model_var_values = torch.split(model_output, n_channels, dim=1) 109 | # Learn the variance using the variational bound, but don't let 110 | # it affect our mean prediction. 111 | frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) 112 | 113 | # predict the noise residual 114 | loss = F.mse_loss(model_output, noise_samples) 115 | 116 | loss = loss / args.gradient_accumulation_steps 117 | 118 | accelerator.backward(loss) 119 | optimizer.step() 120 | else: 121 | model_output = model(noisy_images, timesteps, batch["text_embeddings"]) 122 | model_output, model_var_values = torch.split(model_output, n_channels, dim=1) 123 | # Learn the variance using the variational bound, but don't let 124 | # it affect our mean prediction. 
125 | frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) 126 | 127 | # predict the noise residual 128 | loss = F.mse_loss(model_output, noise_samples) 129 | loss = loss / args.gradient_accumulation_steps 130 | accelerator.backward(loss) 131 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 132 | optimizer.step() 133 | lr_scheduler.step() 134 | optimizer.zero_grad() 135 | pbar.update(1) 136 | pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) 137 | 138 | accelerator.wait_for_everyone() 139 | 140 | # Generate a sample image for visual inspection 141 | if accelerator.is_main_process: 142 | model.eval() 143 | with torch.no_grad(): 144 | pipeline.unet = accelerator.unwrap_model(model) 145 | 146 | generator = torch.manual_seed(0) 147 | # run pipeline in inference (sample random noise and denoise) 148 | image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50) 149 | 150 | # process image to PIL 151 | image_processed = image.squeeze(0) 152 | image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() 153 | image_pil = PIL.Image.fromarray(image_processed) 154 | 155 | # save image 156 | test_dir = os.path.join(args.output_dir, "test_samples") 157 | os.makedirs(test_dir, exist_ok=True) 158 | image_pil.save(f"{test_dir}/{epoch:04d}.png") 159 | 160 | # save the model 161 | if args.push_to_hub: 162 | push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) 163 | else: 164 | pipeline.save_pretrained(args.output_dir) 165 | accelerator.wait_for_everyone() 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 170 | parser.add_argument("--local_rank", type=int, default=-1) 171 | parser.add_argument("--dataset", type=str, default="fusing/dog_captions") 172 | parser.add_argument("--output_dir", type=str, default="glide-text2image") 173 | parser.add_argument("--overwrite_output_dir", action="store_true") 174 | parser.add_argument("--resolution", type=int, default=64) 175 | parser.add_argument("--batch_size", type=int, default=4) 176 | parser.add_argument("--num_epochs", type=int, default=100) 177 | parser.add_argument("--gradient_accumulation_steps", type=int, default=4) 178 | parser.add_argument("--lr", type=float, default=1e-4) 179 | parser.add_argument("--warmup_steps", type=int, default=500) 180 | parser.add_argument("--push_to_hub", action="store_true") 181 | parser.add_argument("--hub_token", type=str, default=None) 182 | parser.add_argument("--hub_model_id", type=str, default=None) 183 | parser.add_argument("--hub_private_repo", action="store_true") 184 | parser.add_argument( 185 | "--mixed_precision", 186 | type=str, 187 | default="no", 188 | choices=["no", "fp16", "bf16"], 189 | help=( 190 | "Whether to use mixed precision. Choose" 191 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 192 | "and an Nvidia Ampere GPU." 
193 | ), 194 | ) 195 | 196 | args = parser.parse_args() 197 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 198 | if env_local_rank != -1 and env_local_rank != args.local_rank: 199 | args.local_rank = env_local_rank 200 | 201 | main(args) 202 | -------------------------------------------------------------------------------- /examples/experimental/train_latent_text_to_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import bitsandbytes as bnb 8 | import PIL.Image 9 | from accelerate import Accelerator 10 | from datasets import load_dataset 11 | from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel 12 | from diffusers.hub_utils import init_git_repo, push_to_hub 13 | from diffusers.optimization import get_scheduler 14 | from diffusers.utils import logging 15 | from torchvision.transforms import ( 16 | CenterCrop, 17 | Compose, 18 | InterpolationMode, 19 | Normalize, 20 | RandomHorizontalFlip, 21 | Resize, 22 | ToTensor, 23 | ) 24 | from tqdm.auto import tqdm 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def main(args): 31 | accelerator = Accelerator(mixed_precision=args.mixed_precision) 32 | 33 | pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large") 34 | pipeline.unet = None # this model will be trained from scratch now 35 | model = UNetLDMModel( 36 | attention_resolutions=[4, 2, 1], 37 | channel_mult=[1, 2, 4, 4], 38 | context_dim=1280, 39 | conv_resample=True, 40 | dims=2, 41 | dropout=0, 42 | image_size=8, 43 | in_channels=4, 44 | model_channels=320, 45 | num_heads=8, 46 | num_res_blocks=2, 47 | out_channels=4, 48 | resblock_updown=False, 49 | transformer_depth=1, 50 | use_new_attention_order=False, 51 | use_scale_shift_norm=False, 52 | use_spatial_transformer=True, 53 | legacy=False, 54 | ) 55 | noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") 56 | optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr) 57 | 58 | augmentations = Compose( 59 | [ 60 | Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), 61 | CenterCrop(args.resolution), 62 | RandomHorizontalFlip(), 63 | ToTensor(), 64 | Normalize([0.5], [0.5]), 65 | ] 66 | ) 67 | dataset = load_dataset(args.dataset, split="train") 68 | 69 | text_encoder = pipeline.bert.eval() 70 | vqvae = pipeline.vqvae.eval() 71 | 72 | def transforms(examples): 73 | images = [augmentations(image.convert("RGB")) for image in examples["image"]] 74 | text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt") 75 | with torch.no_grad(): 76 | text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state 77 | images = 1 / 0.18215 * torch.stack(images, dim=0) 78 | latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode() 79 | return {"images": images, "text_embeddings": text_embeddings, "latents": latents} 80 | 81 | dataset.set_transform(transforms) 82 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 83 | 84 | lr_scheduler = get_scheduler( 85 | "linear", 86 | optimizer=optimizer, 87 | num_warmup_steps=args.warmup_steps, 88 | num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, 89 | ) 90 | 91 | model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 92 | model, text_encoder, 
vqvae, optimizer, train_dataloader, lr_scheduler 93 | ) 94 | text_encoder = text_encoder.cpu() 95 | vqvae = vqvae.cpu() 96 | 97 | if args.push_to_hub: 98 | repo = init_git_repo(args, at_init=True) 99 | 100 | # Train! 101 | is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() 102 | world_size = torch.distributed.get_world_size() if is_distributed else 1 103 | total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size 104 | max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs 105 | logger.info("***** Running training *****") 106 | logger.info(f" Num examples = {len(train_dataloader.dataset)}") 107 | logger.info(f" Num Epochs = {args.num_epochs}") 108 | logger.info(f" Instantaneous batch size per device = {args.batch_size}") 109 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") 110 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 111 | logger.info(f" Total optimization steps = {max_steps}") 112 | 113 | global_step = 0 114 | for epoch in range(args.num_epochs): 115 | model.train() 116 | with tqdm(total=len(train_dataloader), unit="ba") as pbar: 117 | pbar.set_description(f"Epoch {epoch}") 118 | for step, batch in enumerate(train_dataloader): 119 | clean_latents = batch["latents"] 120 | noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device) 121 | bsz = clean_latents.shape[0] 122 | timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long() 123 | 124 | # add noise onto the clean latents according to the noise magnitude at each timestep 125 | # (this is the forward diffusion process) 126 | noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps) 127 | 128 | if step % args.gradient_accumulation_steps != 0: 129 | with accelerator.no_sync(model): 130 | output = model(noisy_latents, timesteps, context=batch["text_embeddings"]) 131 | # predict the noise residual 132 | loss = F.mse_loss(output, noise_samples) 133 | loss = loss / args.gradient_accumulation_steps 134 | accelerator.backward(loss) 135 | optimizer.step() 136 | else: 137 | output = model(noisy_latents, timesteps, context=batch["text_embeddings"]) 138 | # predict the noise residual 139 | loss = F.mse_loss(output, noise_samples) 140 | loss = loss / args.gradient_accumulation_steps 141 | accelerator.backward(loss) 142 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 143 | optimizer.step() 144 | lr_scheduler.step() 145 | optimizer.zero_grad() 146 | pbar.update(1) 147 | pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) 148 | global_step += 1 149 | 150 | accelerator.wait_for_everyone() 151 | 152 | # Generate a sample image for visual inspection 153 | if accelerator.is_main_process: 154 | model.eval() 155 | with torch.no_grad(): 156 | pipeline.unet = accelerator.unwrap_model(model) 157 | 158 | generator = torch.manual_seed(0) 159 | # run pipeline in inference (sample random noise and denoise) 160 | image = pipeline( 161 | ["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50 162 | ) 163 | 164 | # process image to PIL 165 | image_processed = image.cpu().permute(0, 2, 3, 1) 166 | image_processed = image_processed * 255.0 167 | image_processed = image_processed.type(torch.uint8).numpy() 168 | image_pil = PIL.Image.fromarray(image_processed[0]) 169 | 170 | # save image 171 | test_dir = 
os.path.join(args.output_dir, "test_samples") 172 | os.makedirs(test_dir, exist_ok=True) 173 | image_pil.save(f"{test_dir}/{epoch:04d}.png") 174 | 175 | # save the model 176 | if args.push_to_hub: 177 | push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) 178 | else: 179 | pipeline.save_pretrained(args.output_dir) 180 | accelerator.wait_for_everyone() 181 | 182 | 183 | if __name__ == "__main__": 184 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 185 | parser.add_argument("--local_rank", type=int, default=-1) 186 | parser.add_argument("--dataset", type=str, default="fusing/dog_captions") 187 | parser.add_argument("--output_dir", type=str, default="ldm-text2image") 188 | parser.add_argument("--overwrite_output_dir", action="store_true") 189 | parser.add_argument("--resolution", type=int, default=128) 190 | parser.add_argument("--batch_size", type=int, default=1) 191 | parser.add_argument("--num_epochs", type=int, default=100) 192 | parser.add_argument("--gradient_accumulation_steps", type=int, default=16) 193 | parser.add_argument("--lr", type=float, default=1e-4) 194 | parser.add_argument("--warmup_steps", type=int, default=500) 195 | parser.add_argument("--push_to_hub", action="store_true") 196 | parser.add_argument("--hub_token", type=str, default=None) 197 | parser.add_argument("--hub_model_id", type=str, default=None) 198 | parser.add_argument("--hub_private_repo", action="store_true") 199 | parser.add_argument( 200 | "--mixed_precision", 201 | type=str, 202 | default="no", 203 | choices=["no", "fp16", "bf16"], 204 | help=( 205 | "Whether to use mixed precision. Choose" 206 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 207 | "and an Nvidia Ampere GPU." 208 | ), 209 | ) 210 | 211 | args = parser.parse_args() 212 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 213 | if env_local_rank != -1 and env_local_rank != args.local_rank: 214 | args.local_rank = env_local_rank 215 | 216 | main(args) 217 | -------------------------------------------------------------------------------- /examples/train_unconditional.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from accelerate import Accelerator, DistributedDataParallelKwargs 8 | from accelerate.logging import get_logger 9 | from datasets import load_dataset 10 | from diffusers import DDIMPipeline, DDIMScheduler, UNetModel 11 | from diffusers.hub_utils import init_git_repo, push_to_hub 12 | from diffusers.optimization import get_scheduler 13 | from diffusers.training_utils import EMAModel 14 | from torchvision.transforms import ( 15 | CenterCrop, 16 | Compose, 17 | InterpolationMode, 18 | Normalize, 19 | RandomHorizontalFlip, 20 | Resize, 21 | ToTensor, 22 | ) 23 | from tqdm.auto import tqdm 24 | 25 | 26 | logger = get_logger(__name__) 27 | 28 | 29 | def main(args): 30 | ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True) 31 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 32 | accelerator = Accelerator( 33 | mixed_precision=args.mixed_precision, 34 | log_with="tensorboard", 35 | logging_dir=logging_dir, 36 | kwargs_handlers=[ddp_unused_params], 37 | ) 38 | 39 | model = UNetModel( 40 | attn_resolutions=(16,), 41 | ch=128, 42 | ch_mult=(1, 2, 4, 8), 43 | dropout=0.0, 44 | num_res_blocks=2, 45 | resamp_with_conv=True, 46 | resolution=args.resolution, 47 | ) 48 | 
noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt") 49 | optimizer = torch.optim.AdamW( 50 | model.parameters(), 51 | lr=args.learning_rate, 52 | betas=(args.adam_beta1, args.adam_beta2), 53 | weight_decay=args.adam_weight_decay, 54 | eps=args.adam_epsilon, 55 | ) 56 | 57 | augmentations = Compose( 58 | [ 59 | Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), 60 | CenterCrop(args.resolution), 61 | RandomHorizontalFlip(), 62 | ToTensor(), 63 | Normalize([0.5], [0.5]), 64 | ] 65 | ) 66 | dataset = load_dataset(args.dataset, split="train") 67 | 68 | def transforms(examples): 69 | images = [augmentations(image.convert("RGB")) for image in examples["image"]] 70 | return {"input": images} 71 | 72 | dataset.set_transform(transforms) 73 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) 74 | 75 | lr_scheduler = get_scheduler( 76 | args.lr_scheduler, 77 | optimizer=optimizer, 78 | num_warmup_steps=args.lr_warmup_steps, 79 | num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, 80 | ) 81 | 82 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 83 | model, optimizer, train_dataloader, lr_scheduler 84 | ) 85 | 86 | ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) 87 | 88 | if args.push_to_hub: 89 | repo = init_git_repo(args, at_init=True) 90 | 91 | if accelerator.is_main_process: 92 | run = os.path.split(__file__)[-1].split(".")[0] 93 | accelerator.init_trackers(run) 94 | 95 | # Train! 96 | is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() 97 | world_size = torch.distributed.get_world_size() if is_distributed else 1 98 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size 99 | max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs 100 | logger.info("***** Running training *****") 101 | logger.info(f" Num examples = {len(train_dataloader.dataset)}") 102 | logger.info(f" Num Epochs = {args.num_epochs}") 103 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 104 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") 105 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 106 | logger.info(f" Total optimization steps = {max_steps}") 107 | 108 | global_step = 0 109 | for epoch in range(args.num_epochs): 110 | model.train() 111 | progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process) 112 | progress_bar.set_description(f"Epoch {epoch}") 113 | for step, batch in enumerate(train_dataloader): 114 | clean_images = batch["input"] 115 | noise_samples = torch.randn(clean_images.shape).to(clean_images.device) 116 | bsz = clean_images.shape[0] 117 | timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() 118 | 119 | # add noise onto the clean images according to the noise magnitude at each timestep 120 | # (this is the forward diffusion process) 121 | noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps) 122 | 123 | if step % args.gradient_accumulation_steps != 0: 124 | with accelerator.no_sync(model): 125 | output = model(noisy_images, timesteps) 126 | # predict the noise residual 127 | loss = F.mse_loss(output, noise_samples) 128 | loss = loss / args.gradient_accumulation_steps 129 | accelerator.backward(loss) 130 | else: 131 | output = model(noisy_images, timesteps) 132 | # predict the noise residual 133 | loss = F.mse_loss(output, noise_samples) 134 | loss = loss / args.gradient_accumulation_steps 135 | accelerator.backward(loss) 136 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 137 | optimizer.step() 138 | lr_scheduler.step() 139 | ema_model.step(model) 140 | optimizer.zero_grad() 141 | progress_bar.update(1) 142 | progress_bar.set_postfix( 143 | loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay 144 | ) 145 | accelerator.log( 146 | { 147 | "train_loss": loss.detach().item(), 148 | "epoch": epoch, 149 | "ema_decay": ema_model.decay, 150 | "step": global_step, 151 | }, 152 | step=global_step, 153 | ) 154 | global_step += 1 155 | progress_bar.close() 156 | 157 | accelerator.wait_for_everyone() 158 | 159 | # Generate a sample image for visual inspection 160 | if accelerator.is_main_process: 161 | with torch.no_grad(): 162 | pipeline = DDIMPipeline( 163 | unet=accelerator.unwrap_model(ema_model.averaged_model), 164 | noise_scheduler=noise_scheduler, 165 | ) 166 | 167 | generator = torch.manual_seed(0) 168 | # run pipeline in inference (sample random noise and denoise) 169 | images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50) 170 | 171 | # denormalize the images and save to tensorboard 172 | images_processed = (images.cpu() + 1.0) * 127.5 173 | images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy() 174 | 175 | accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch) 176 | 177 | # save the model 178 | if args.push_to_hub: 179 | push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) 180 | else: 181 | pipeline.save_pretrained(args.output_dir) 182 | accelerator.wait_for_everyone() 183 | 184 | accelerator.end_training() 185 | 186 | 187 | if __name__ == "__main__": 188 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 189 | parser.add_argument("--local_rank", type=int, default=-1) 190 | parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories") 191 | parser.add_argument("--output_dir", 
type=str, default="ddpm-model") 192 | parser.add_argument("--overwrite_output_dir", action="store_true") 193 | parser.add_argument("--resolution", type=int, default=64) 194 | parser.add_argument("--train_batch_size", type=int, default=16) 195 | parser.add_argument("--eval_batch_size", type=int, default=16) 196 | parser.add_argument("--num_epochs", type=int, default=100) 197 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 198 | parser.add_argument("--learning_rate", type=float, default=1e-4) 199 | parser.add_argument("--lr_scheduler", type=str, default="cosine") 200 | parser.add_argument("--lr_warmup_steps", type=int, default=500) 201 | parser.add_argument("--adam_beta1", type=float, default=0.95) 202 | parser.add_argument("--adam_beta2", type=float, default=0.999) 203 | parser.add_argument("--adam_weight_decay", type=float, default=1e-6) 204 | parser.add_argument("--adam_epsilon", type=float, default=1e-3) 205 | parser.add_argument("--ema_inv_gamma", type=float, default=1.0) 206 | parser.add_argument("--ema_power", type=float, default=3 / 4) 207 | parser.add_argument("--ema_max_decay", type=float, default=0.9999) 208 | parser.add_argument("--push_to_hub", action="store_true") 209 | parser.add_argument("--hub_token", type=str, default=None) 210 | parser.add_argument("--hub_model_id", type=str, default=None) 211 | parser.add_argument("--hub_private_repo", action="store_true") 212 | parser.add_argument("--logging_dir", type=str, default="logs") 213 | parser.add_argument( 214 | "--mixed_precision", 215 | type=str, 216 | default="no", 217 | choices=["no", "fp16", "bf16"], 218 | help=( 219 | "Whether to use mixed precision. Choose " 220 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 " 221 | "and an Nvidia Ampere GPU."
222 | ), 223 | ) 224 | 225 | args = parser.parse_args() 226 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 227 | if env_local_rank != -1 and env_local_rank != args.local_rank: 228 | args.local_rank = env_local_rank 229 | 230 | main(args) 231 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py36'] 4 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | import PIL 4 | import torch 5 | #from configs.ve import ffhq_ncsnpp_continuous as configs 6 | # from configs.ve import cifar10_ncsnpp_continuous as configs 7 | 8 | 9 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 10 | 11 | torch.backends.cuda.matmul.allow_tf32 = False 12 | torch.manual_seed(0) 13 | 14 | 15 | class NewReverseDiffusionPredictor: 16 | def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0): 17 | super().__init__() 18 | self.sigma_min = sigma_min 19 | self.sigma_max = sigma_max 20 | self.N = N 21 | self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) 22 | 23 | self.probability_flow = probability_flow 24 | self.score_fn = score_fn 25 | 26 | def discretize(self, x, t): 27 | timestep = (t * (self.N - 1)).long() 28 | sigma = self.discrete_sigmas.to(t.device)[timestep] 29 | adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), 30 | self.discrete_sigmas[timestep - 1].to(t.device)) 31 | f = torch.zeros_like(x) 32 | G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) 33 | 34 | labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 35 | result = self.score_fn(x, labels) 36 | 37 | rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.) 
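        # With probability_flow=True, the score term above is halved and rev_G (next line)
        # is zeroed, i.e. this becomes one step of the deterministic probability-flow ODE.
        # With probability_flow=False it is the usual reverse-time VE-SDE discretization:
        # drift -G^2 * score and diffusion G, with the noise added later in update_fn.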
38 | rev_G = torch.zeros_like(G) if self.probability_flow else G 39 | return rev_f, rev_G 40 | 41 | def update_fn(self, x, t): 42 | f, G = self.discretize(x, t) 43 | z = torch.randn_like(x) 44 | x_mean = x - f 45 | x = x_mean + G[:, None, None, None] * z 46 | return x, x_mean 47 | 48 | 49 | class NewLangevinCorrector: 50 | def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0): 51 | super().__init__() 52 | self.score_fn = score_fn 53 | self.snr = snr 54 | self.n_steps = n_steps 55 | 56 | self.sigma_min = sigma_min 57 | self.sigma_max = sigma_max 58 | 59 | def update_fn(self, x, t): 60 | score_fn = self.score_fn 61 | n_steps = self.n_steps 62 | target_snr = self.snr 63 | # if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE): 64 | # timestep = (t * (sde.N - 1) / sde.T).long() 65 | # alpha = sde.alphas.to(t.device)[timestep] 66 | # else: 67 | alpha = torch.ones_like(t) 68 | 69 | for i in range(n_steps): 70 | labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 71 | grad = score_fn(x, labels) 72 | noise = torch.randn_like(x) 73 | grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() 74 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() 75 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha 76 | x_mean = x + step_size[:, None, None, None] * grad 77 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise 78 | 79 | return x, x_mean 80 | 81 | 82 | 83 | def save_image(x): 84 | image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) 85 | image_pil = PIL.Image.fromarray(image_processed[0]) 86 | image_pil.save("../images/hey.png") 87 | 88 | 89 | # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" 90 | #ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" 91 | # Note usually we need to restore ema etc... 92 | # ema restored checkpoint used from below 93 | 94 | N = 2 95 | sigma_min = 0.01 96 | sigma_max = 1348 97 | sampling_eps = 1e-5 98 | batch_size = 1 99 | centered = False 100 | 101 | from diffusers import NCSNpp 102 | 103 | model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device) 104 | model = torch.nn.DataParallel(model) 105 | 106 | img_size = model.module.config.image_size 107 | channels = model.module.config.num_channels 108 | shape = (batch_size, channels, img_size, img_size) 109 | probability_flow = False 110 | snr = 0.15 111 | n_steps = 1 112 | 113 | 114 | new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max) 115 | new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N) 116 | 117 | with torch.no_grad(): 118 | # Initial sample 119 | x = torch.randn(*shape) * sigma_max 120 | x = x.to(device) 121 | timesteps = torch.linspace(1, sampling_eps, N, device=device) 122 | 123 | for i in range(N): 124 | t = timesteps[i] 125 | vec_t = torch.ones(shape[0], device=t.device) * t 126 | x, x_mean = new_corrector.update_fn(x, vec_t) 127 | x, x_mean = new_predictor.update_fn(x, vec_t) 128 | 129 | x = x_mean 130 | if centered: 131 | x = (x + 1.) / 2. 
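# At this point `x` holds the final samples from the predictor-corrector loop above:
# at each of the N timesteps the Langevin corrector nudges x along the score with a
# step size set by the target signal-to-noise ratio, and the reverse-diffusion
# predictor then takes one discretized reverse-time VE-SDE step. Taking x = x_mean at
# the end drops the last noise injection, and the rescale above maps samples from
# [-1, 1] to [0, 1] when `centered` is set. The hard-coded sums/means below are
# reference values for specific checkpoints and step counts, checked by
# check_x_sum_x_mean to verify this port against the original sampler.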
132 | 133 | 134 | # save_image(x) 135 | 136 | # for 5 cifar10 137 | x_sum = 106071.9922 138 | x_mean = 34.52864456176758 139 | 140 | # for 1000 cifar10 141 | x_sum = 461.9700 142 | x_mean = 0.1504 143 | 144 | # for 2 for 1024 145 | x_sum = 3382810112.0 146 | x_mean = 1075.366455078125 147 | 148 | def check_x_sum_x_mean(x, x_sum, x_mean): 149 | assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" 150 | assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" 151 | 152 | 153 | check_x_sum_x_mean(x, x_sum, x_mean) 154 | -------------------------------------------------------------------------------- /scripts/conversion_bddm.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import torch 4 | 5 | from diffusers.pipelines.bddm import DiffWave, BDDMPipeline 6 | from diffusers import DDPMScheduler 7 | 8 | 9 | def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path): 10 | sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"] 11 | noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu") 12 | 13 | model = DiffWave() 14 | model.load_state_dict(sd, strict=False) 15 | 16 | ts, _, betas, _ = noise_scheduler_sd 17 | ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist()) 18 | 19 | noise_scheduler = DDPMScheduler( 20 | timesteps=12, 21 | trained_betas=betas, 22 | timestep_values=ts, 23 | clip_sample=False, 24 | tensor_format="np", 25 | ) 26 | 27 | pipeline = BDDMPipeline(model, noise_scheduler) 28 | pipeline.save_pretrained(output_path) 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--checkpoint_path", type=str, required=True) 34 | parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True) 35 | parser.add_argument("--output_path", type=str, required=True) 36 | args = parser.parse_args() 37 | 38 | convert_bddm_orginal(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path) 39 | 40 | 41 | -------------------------------------------------------------------------------- /scripts/conversion_glide.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel 5 | from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel 6 | from transformers import CLIPTextConfig, GPT2Tokenizer 7 | 8 | 9 | # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt 10 | state_dict = torch.load("base.pt", map_location="cpu") 11 | state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()} 12 | 13 | ### Convert the text encoder 14 | 15 | config = CLIPTextConfig( 16 | vocab_size=50257, 17 | max_position_embeddings=128, 18 | hidden_size=512, 19 | intermediate_size=2048, 20 | num_hidden_layers=16, 21 | num_attention_heads=8, 22 | use_padding_embeddings=True, 23 | ) 24 | model = CLIPTextModel(config).eval() 25 | tokenizer = GPT2Tokenizer( 26 | "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>" 27 | ) 28 | 29 | hf_encoder = model.text_model 30 | 31 | hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"] 32 | hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"] 33 | 
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"] 34 | 35 | hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"] 36 | hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"] 37 | 38 | for layer_idx in range(config.num_hidden_layers): 39 | hf_layer = hf_encoder.encoder.layers[layer_idx] 40 | hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"] 41 | hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"] 42 | 43 | hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"] 44 | hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"] 45 | 46 | hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"] 47 | hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"] 48 | hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"] 49 | hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"] 50 | 51 | hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"] 52 | hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"] 53 | hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"] 54 | hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"] 55 | 56 | ### Convert the Text-to-Image UNet 57 | 58 | text2im_model = GlideTextToImageUNetModel( 59 | in_channels=3, 60 | model_channels=192, 61 | out_channels=6, 62 | num_res_blocks=3, 63 | attention_resolutions=(2, 4, 8), 64 | dropout=0.1, 65 | channel_mult=(1, 2, 3, 4), 66 | num_heads=1, 67 | num_head_channels=64, 68 | num_heads_upsample=1, 69 | use_scale_shift_norm=True, 70 | resblock_updown=True, 71 | transformer_dim=512, 72 | ) 73 | 74 | text2im_model.load_state_dict(state_dict, strict=False) 75 | 76 | text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2") 77 | 78 | ### Convert the Super-Resolution UNet 79 | 80 | # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt 81 | ups_state_dict = torch.load("upsample.pt", map_location="cpu") 82 | 83 | superres_model = GlideSuperResUNetModel( 84 | in_channels=6, 85 | model_channels=192, 86 | out_channels=6, 87 | num_res_blocks=2, 88 | attention_resolutions=(8, 16, 32), 89 | dropout=0.1, 90 | channel_mult=(1, 1, 2, 2, 4, 4), 91 | num_heads=1, 92 | num_head_channels=64, 93 | num_heads_upsample=1, 94 | use_scale_shift_norm=True, 95 | resblock_updown=True, 96 | ) 97 | 98 | superres_model.load_state_dict(ups_state_dict, strict=False) 99 | 100 | upscale_scheduler = DDIMScheduler( 101 | timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt" 102 | ) 103 | 104 | glide = Glide( 105 | text_unet=text2im_model, 106 | text_noise_scheduler=text_scheduler, 107 | text_encoder=model, 108 | tokenizer=tokenizer, 109 | upscale_unet=superres_model, 110 | upscale_noise_scheduler=upscale_scheduler, 111 | ) 112 | 113 | glide.save_pretrained("./glide-base") 114 | -------------------------------------------------------------------------------- /scripts/conversion_ldm_uncond.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import OmegaConf 4 | import torch 5 | 6 | from diffusers 
import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler 7 | 8 | def convert_ldm_original(checkpoint_path, config_path, output_path): 9 | config = OmegaConf.load(config_path) 10 | state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] 11 | keys = list(state_dict.keys()) 12 | 13 | # extract state_dict for VQVAE 14 | first_stage_dict = {} 15 | first_stage_key = "first_stage_model." 16 | for key in keys: 17 | if key.startswith(first_stage_key): 18 | first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key] 19 | 20 | # extract state_dict for UNetLDM 21 | unet_state_dict = {} 22 | unet_key = "model.diffusion_model." 23 | for key in keys: 24 | if key.startswith(unet_key): 25 | unet_state_dict[key.replace(unet_key, "")] = state_dict[key] 26 | 27 | vqvae_init_args = config.model.params.first_stage_config.params 28 | unet_init_args = config.model.params.unet_config.params 29 | 30 | vqvae = VQModel(**vqvae_init_args).eval() 31 | vqvae.load_state_dict(first_stage_dict) 32 | 33 | unet = UNetLDMModel(**unet_init_args).eval() 34 | unet.load_state_dict(unet_state_dict) 35 | 36 | noise_scheduler = DDIMScheduler( 37 | timesteps=config.model.params.timesteps, 38 | beta_schedule="scaled_linear", 39 | beta_start=config.model.params.linear_start, 40 | beta_end=config.model.params.linear_end, 41 | clip_sample=False, 42 | ) 43 | 44 | pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler) 45 | pipeline.save_pretrained(output_path) 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--checkpoint_path", type=str, required=True) 51 | parser.add_argument("--config_path", type=str, required=True) 52 | parser.add_argument("--output_path", type=str, required=True) 53 | args = parser.parse_args() 54 | 55 | convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path) 56 | 57 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = accelerate 7 | known_third_party = 8 | numpy 9 | torch 10 | torch_xla 11 | 12 | line_length = 119 13 | lines_after_imports = 2 14 | multi_line_output = 3 15 | use_parentheses = True 16 | 17 | [flake8] 18 | ignore = E203, E722, E501, E741, W503, W605 19 | max-line-length = 119 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py 17 | 18 | To create the package for pypi. 19 | 20 | 1. 
Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the 21 | documentation. 22 | 23 | If releasing on a special branch, copy the updated README.md on the main branch for your the commit you will make 24 | for the post-release and run `make fix-copies` on the main branch as well. 25 | 26 | 2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid. 27 | 28 | 3. Unpin specific versions from setup.py that use a git install. 29 | 30 | 4. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the 31 | message: "Release: " and push. 32 | 33 | 5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs) 34 | 35 | 6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for pypi' " 36 | Push the tag to git: git push --tags origin v-release 37 | 38 | 7. Build both the sources and the wheel. Do not change anything in setup.py between 39 | creating the wheel and the source distribution (obviously). 40 | 41 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory. 42 | (this will build a wheel for the python version you use to build it). 43 | 44 | For the sources, run: "python setup.py sdist" 45 | You should now have a /dist directory with both .whl and .tar.gz source versions. 46 | 47 | 8. Check that everything looks correct by uploading the package to the pypi test server: 48 | 49 | twine upload dist/* -r pypitest 50 | (pypi suggest using twine as other methods upload files via plaintext.) 51 | You may have to specify the repository url, use the following command then: 52 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 53 | 54 | Check that you can install it in a virtualenv by running: 55 | pip install -i https://testpypi.python.org/pypi diffusers 56 | 57 | Check you can run the following commands: 58 | python -c "from diffusers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))" 59 | python -c "from diffusers import *" 60 | 61 | 9. Upload the final version to actual pypi: 62 | twine upload dist/* -r pypi 63 | 64 | 10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 65 | 66 | 11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release, 67 | you need to go back to main before executing this. 68 | """ 69 | 70 | import re 71 | from distutils.core import Command 72 | 73 | from setuptools import find_packages, setup 74 | 75 | # IMPORTANT: 76 | # 1. all dependencies should be listed here with their version requirements if any 77 | # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py 78 | _deps = [ 79 | "Pillow", 80 | "black~=22.0,>=22.3", 81 | "filelock", 82 | "flake8>=3.8.3", 83 | "huggingface-hub", 84 | "isort>=5.5.4", 85 | "numpy", 86 | "pytest", 87 | "regex!=2019.12.17", 88 | "requests", 89 | "torch>=1.4", 90 | "tensorboard", 91 | "modelcards==0.1.4" 92 | ] 93 | 94 | # this is a lookup table with items like: 95 | # 96 | # tokenizers: "huggingface-hub==0.8.0" 97 | # packaging: "packaging" 98 | # 99 | # some of the values are versioned whereas others aren't. 
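# Concretely, the comprehension below matches each entry of `_deps` into a
# (full_spec, bare_name) pair (group 2 of the regex stops at the first version operator)
# and builds {bare_name: full_spec}, e.g. {"black": "black~=22.0,>=22.3", "numpy": "numpy"}.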
100 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)} 101 | 102 | # since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from 103 | # anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with: 104 | # 105 | # python -c 'import sys; from diffusers.dependency_versions_table import deps; \ 106 | # print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets 107 | # 108 | # Just pass the desired package names to that script as it's shown with 2 packages above. 109 | # 110 | # If diffusers is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above 111 | # 112 | # You can then feed this for example to `pip`: 113 | # 114 | # pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \ 115 | # print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets) 116 | # 117 | 118 | 119 | def deps_list(*pkgs): 120 | return [deps[pkg] for pkg in pkgs] 121 | 122 | 123 | class DepsTableUpdateCommand(Command): 124 | """ 125 | A custom distutils command that updates the dependency table. 126 | usage: python setup.py deps_table_update 127 | """ 128 | 129 | description = "build runtime dependency table" 130 | user_options = [ 131 | # format: (long option, short option, description). 132 | ("dep-table-update", None, "updates src/diffusers/dependency_versions_table.py"), 133 | ] 134 | 135 | def initialize_options(self): 136 | pass 137 | 138 | def finalize_options(self): 139 | pass 140 | 141 | def run(self): 142 | entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()]) 143 | content = [ 144 | "# THIS FILE HAS BEEN AUTOGENERATED. To update:", 145 | "# 1. modify the `_deps` dict in setup.py", 146 | "# 2. 
run `make deps_table_update``", 147 | "deps = {", 148 | entries, 149 | "}", 150 | "", 151 | ] 152 | target = "src/diffusers/dependency_versions_table.py" 153 | print(f"updating {target}") 154 | with open(target, "w", encoding="utf-8", newline="\n") as f: 155 | f.write("\n".join(content)) 156 | 157 | 158 | extras = {} 159 | 160 | 161 | extras = {} 162 | extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"] 163 | extras["docs"] = [] 164 | extras["test"] = [ 165 | "pytest", 166 | ] 167 | extras["dev"] = extras["quality"] + extras["test"] 168 | 169 | install_requires = [ 170 | deps["filelock"], 171 | deps["huggingface-hub"], 172 | deps["numpy"], 173 | deps["regex"], 174 | deps["requests"], 175 | deps["torch"], 176 | deps["Pillow"], 177 | deps["tensorboard"], 178 | deps["modelcards"], 179 | ] 180 | 181 | setup( 182 | name="diffusers", 183 | version="0.0.4", 184 | description="Diffusers", 185 | long_description=open("README.md", "r", encoding="utf-8").read(), 186 | long_description_content_type="text/markdown", 187 | keywords="deep learning", 188 | license="Apache", 189 | author="The HuggingFace team", 190 | author_email="patrick@huggingface.co", 191 | url="https://github.com/huggingface/diffusers", 192 | package_dir={"": "src"}, 193 | packages=find_packages("src"), 194 | python_requires=">=3.6.0", 195 | install_requires=install_requires, 196 | extras_require=extras, 197 | classifiers=[ 198 | "Development Status :: 5 - Production/Stable", 199 | "Intended Audience :: Developers", 200 | "Intended Audience :: Education", 201 | "Intended Audience :: Science/Research", 202 | "License :: OSI Approved :: Apache Software License", 203 | "Operating System :: OS Independent", 204 | "Programming Language :: Python :: 3", 205 | "Programming Language :: Python :: 3.6", 206 | "Programming Language :: Python :: 3.7", 207 | "Programming Language :: Python :: 3.8", 208 | "Programming Language :: Python :: 3.9", 209 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 210 | ], 211 | cmdclass={"deps_table_update": DepsTableUpdateCommand}, 212 | ) 213 | 214 | # Release checklist 215 | # 1. Change the version in __init__.py and setup.py. 216 | # 2. Commit these changes with the message: "Release: Release" 217 | # 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for pypi' " 218 | # Push the tag to git: git push --tags origin main 219 | # 4. Run the following commands in the top-level directory: 220 | # python setup.py bdist_wheel 221 | # python setup.py sdist 222 | # 5. Upload the package to the pypi test server first: 223 | # twine upload dist/* -r pypitest 224 | # twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 225 | # 6. Check that you can install it in a virtualenv by running: 226 | # pip install -i https://testpypi.python.org/pypi diffusers 227 | # diffusers env 228 | # diffusers test 229 | # 7. Upload the final version to actual pypi: 230 | # twine upload dist/* -r pypi 231 | # 8. Add release notes to the tag in github once everything is looking hunky-dory. 232 | # 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master 233 | -------------------------------------------------------------------------------- /src/diffusers/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. 
So, don't check this module at all. 4 | from .utils import is_inflect_available, is_transformers_available, is_unidecode_available 5 | 6 | 7 | __version__ = "0.0.4" 8 | 9 | from .modeling_utils import ModelMixin 10 | from .models import AutoencoderKL, NCSNpp, TemporalUNet, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel 11 | from .pipeline_utils import DiffusionPipeline 12 | from .pipelines import ( 13 | BDDMPipeline, 14 | DDIMPipeline, 15 | DDPMPipeline, 16 | LatentDiffusionUncondPipeline, 17 | PNDMPipeline, 18 | ScoreSdeVePipeline, 19 | ScoreSdeVpPipeline, 20 | ) 21 | from .schedulers import ( 22 | DDIMScheduler, 23 | DDPMScheduler, 24 | GradTTSScheduler, 25 | PNDMScheduler, 26 | SchedulerMixin, 27 | ScoreSdeVeScheduler, 28 | ScoreSdeVpScheduler, 29 | ) 30 | 31 | 32 | if is_transformers_available(): 33 | from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel 34 | from .models.unet_grad_tts import UNetGradTTSModel 35 | from .pipelines import GlidePipeline, LatentDiffusionPipeline 36 | else: 37 | from .utils.dummy_transformers_objects import * 38 | 39 | 40 | if is_transformers_available() and is_inflect_available() and is_unidecode_available(): 41 | from .pipelines import GradTTSPipeline 42 | else: 43 | from .utils.dummy_transformers_and_inflect_and_unidecode_objects import * 44 | -------------------------------------------------------------------------------- /src/diffusers/dependency_versions_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
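# This module verifies at import time that the core runtime dependencies listed in
# dependency_versions_table.py are installed in compatible versions, and exposes
# dep_version_check() for ad-hoc checks of a single package.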
14 | import sys 15 | 16 | from .dependency_versions_table import deps 17 | from .utils.versions import require_version, require_version_core 18 | 19 | 20 | # define which module versions we always want to check at run time 21 | # (usually the ones defined in `install_requires` in setup.py) 22 | # 23 | # order specific notes: 24 | # - tqdm must be checked before tokenizers 25 | 26 | pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() 27 | if sys.version_info < (3, 7): 28 | pkgs_to_check_at_runtime.append("dataclasses") 29 | if sys.version_info < (3, 8): 30 | pkgs_to_check_at_runtime.append("importlib_metadata") 31 | 32 | for pkg in pkgs_to_check_at_runtime: 33 | if pkg in deps: 34 | if pkg == "tokenizers": 35 | # must be loaded here, or else tqdm check may fail 36 | from .utils import is_tokenizers_available 37 | 38 | if not is_tokenizers_available(): 39 | continue # not required, check version only if installed 40 | 41 | require_version_core(deps[pkg]) 42 | else: 43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 44 | 45 | 46 | def dep_version_check(pkg, hint=None): 47 | require_version(deps[pkg], hint) 48 | -------------------------------------------------------------------------------- /src/diffusers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. run `make deps_table_update`` 4 | deps = { 5 | "Pillow": "Pillow", 6 | "black": "black~=22.0,>=22.3", 7 | "filelock": "filelock", 8 | "flake8": "flake8>=3.8.3", 9 | "huggingface-hub": "huggingface-hub", 10 | "isort": "isort>=5.5.4", 11 | "numpy": "numpy", 12 | "pytest": "pytest", 13 | "regex": "regex!=2019.12.17", 14 | "requests": "requests", 15 | "torch": "torch>=1.4", 16 | "tensorboard": "tensorboard", 17 | "modelcards": "modelcards==0.1.4", 18 | } 19 | -------------------------------------------------------------------------------- /src/diffusers/hub_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
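# Hub helpers: init_git_repo() clones or creates the target repo, push_to_hub() uploads a
# saved DiffusionPipeline, and create_model_card() renders the README from
# utils/model_card_template.md with the training arguments.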
15 | 16 | 17 | import os 18 | import shutil 19 | from pathlib import Path 20 | from typing import Optional 21 | 22 | from diffusers import DiffusionPipeline 23 | from huggingface_hub import HfFolder, Repository, whoami 24 | from modelcards import CardData, ModelCard 25 | 26 | from .utils import logging 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" 33 | 34 | 35 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 36 | if token is None: 37 | token = HfFolder.get_token() 38 | if organization is None: 39 | username = whoami(token)["name"] 40 | return f"{username}/{model_id}" 41 | else: 42 | return f"{organization}/{model_id}" 43 | 44 | 45 | def init_git_repo(args, at_init: bool = False): 46 | """ 47 | Args: 48 | Initializes a git repo in `args.hub_model_id`. 49 | at_init (`bool`, *optional*, defaults to `False`): 50 | Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` 51 | and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. 52 | """ 53 | if args.local_rank not in [-1, 0]: 54 | return 55 | use_auth_token = True if args.hub_token is None else args.hub_token 56 | if args.hub_model_id is None: 57 | repo_name = Path(args.output_dir).absolute().name 58 | else: 59 | repo_name = args.hub_model_id 60 | if "/" not in repo_name: 61 | repo_name = get_full_repo_name(repo_name, token=args.hub_token) 62 | 63 | try: 64 | repo = Repository( 65 | args.output_dir, 66 | clone_from=repo_name, 67 | use_auth_token=use_auth_token, 68 | private=args.hub_private_repo, 69 | ) 70 | except EnvironmentError: 71 | if args.overwrite_output_dir and at_init: 72 | # Try again after wiping output_dir 73 | shutil.rmtree(args.output_dir) 74 | repo = Repository( 75 | args.output_dir, 76 | clone_from=repo_name, 77 | use_auth_token=use_auth_token, 78 | ) 79 | else: 80 | raise 81 | 82 | repo.git_pull() 83 | 84 | # By default, ignore the checkpoint folders 85 | if not os.path.exists(os.path.join(args.output_dir, ".gitignore")): 86 | with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: 87 | writer.writelines(["checkpoint-*/"]) 88 | 89 | return repo 90 | 91 | 92 | def push_to_hub( 93 | args, 94 | pipeline: DiffusionPipeline, 95 | repo: Repository, 96 | commit_message: Optional[str] = "End of training", 97 | blocking: bool = True, 98 | **kwargs, 99 | ) -> str: 100 | """ 101 | Parameters: 102 | Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. 103 | commit_message (`str`, *optional*, defaults to `"End of training"`): 104 | Message to commit while pushing. 105 | blocking (`bool`, *optional*, defaults to `True`): 106 | Whether the function should return only when the `git push` has finished. 107 | kwargs: 108 | Additional keyword arguments passed along to [`create_model_card`]. 
109 | Returns: 110 | The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the 111 | commit and an object to track the progress of the commit if `blocking=True` 112 | """ 113 | 114 | if args.hub_model_id is None: 115 | model_name = Path(args.output_dir).name 116 | else: 117 | model_name = args.hub_model_id.split("/")[-1] 118 | 119 | output_dir = args.output_dir 120 | os.makedirs(output_dir, exist_ok=True) 121 | logger.info(f"Saving pipeline checkpoint to {output_dir}") 122 | pipeline.save_pretrained(output_dir) 123 | 124 | # Only push from one node. 125 | if args.local_rank not in [-1, 0]: 126 | return 127 | 128 | # Cancel any async push in progress if blocking=True. The commits will all be pushed together. 129 | if ( 130 | blocking 131 | and len(repo.command_queue) > 0 132 | and repo.command_queue[-1] is not None 133 | and not repo.command_queue[-1].is_done 134 | ): 135 | repo.command_queue[-1]._process.kill() 136 | 137 | git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True) 138 | # push separately the model card to be independent from the rest of the model 139 | create_model_card(args, model_name=model_name) 140 | try: 141 | repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True) 142 | except EnvironmentError as exc: 143 | logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") 144 | 145 | return git_head_commit_url 146 | 147 | 148 | def create_model_card(args, model_name): 149 | if args.local_rank not in [-1, 0]: 150 | return 151 | 152 | repo_name = get_full_repo_name(model_name, token=args.hub_token) 153 | 154 | model_card = ModelCard.from_template( 155 | card_data=CardData( # Card metadata object that will be converted to YAML block 156 | language="en", 157 | license="apache-2.0", 158 | library_name="diffusers", 159 | tags=[], 160 | datasets=args.dataset, 161 | metrics=[], 162 | ), 163 | template_path=MODEL_CARD_TEMPLATE_PATH, 164 | model_name=model_name, 165 | repo_name=repo_name, 166 | dataset_name=args.dataset, 167 | learning_rate=args.learning_rate, 168 | train_batch_size=args.train_batch_size, 169 | eval_batch_size=args.eval_batch_size, 170 | gradient_accumulation_steps=args.gradient_accumulation_steps, 171 | adam_beta1=args.adam_beta1, 172 | adam_beta2=args.adam_beta2, 173 | adam_weight_decay=args.adam_weight_decay, 174 | adam_epsilon=args.adam_epsilon, 175 | lr_scheduler=args.lr_scheduler, 176 | lr_warmup_steps=args.lr_warmup_steps, 177 | ema_inv_gamma=args.ema_inv_gamma, 178 | ema_power=args.ema_power, 179 | ema_max_decay=args.ema_max_decay, 180 | mixed_precision=args.mixed_precision, 181 | ) 182 | 183 | card_path = os.path.join(args.output_dir, "README.md") 184 | model_card.save(card_path) 185 | -------------------------------------------------------------------------------- /src/diffusers/models/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | - Models: Neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to denoise a noisy input to an image. 
Examples: UNet, Conditioned UNet, 3D UNet, Transformer UNet 4 | 5 | ## API 6 | 7 | TODO(Suraj, Patrick) 8 | 9 | ## Examples 10 | 11 | TODO(Suraj, Patrick) 12 | -------------------------------------------------------------------------------- /src/diffusers/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2022 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from .unet import UNetModel 20 | from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel 21 | from .unet_grad_tts import UNetGradTTSModel 22 | from .unet_ldm import UNetLDMModel 23 | from .unet_rl import TemporalUNet 24 | from .unet_sde_score_estimation import NCSNpp 25 | from .unet_unconditional import UNetUnconditionalModel 26 | from .vae import AutoencoderKL, VQModel 27 | -------------------------------------------------------------------------------- /src/diffusers/models/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | 16 | import numpy as np 17 | import torch 18 | from torch import nn 19 | 20 | 21 | def get_timestep_embedding( 22 | timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000 23 | ): 24 | """ 25 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 26 | 27 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 28 | These may be fractional. 29 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 30 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 
31 | """ 32 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 33 | 34 | half_dim = embedding_dim // 2 35 | 36 | emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift) 37 | emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) 38 | emb = torch.exp(emb * emb_coeff) 39 | emb = timesteps[:, None].float() * emb[None, :] 40 | 41 | # scale embeddings 42 | emb = scale * emb 43 | 44 | # concat sine and cosine embeddings 45 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 46 | 47 | # flip sine and cosine embeddings 48 | if flip_sin_to_cos: 49 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 50 | 51 | # zero pad 52 | if embedding_dim % 2 == 1: 53 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 54 | return emb 55 | 56 | 57 | # unet_sde_score_estimation.py 58 | class GaussianFourierProjection(nn.Module): 59 | """Gaussian Fourier embeddings for noise levels.""" 60 | 61 | def __init__(self, embedding_size=256, scale=1.0): 62 | super().__init__() 63 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 64 | 65 | def forward(self, x): 66 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 67 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 68 | -------------------------------------------------------------------------------- /src/diffusers/models/unet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | 14 | # limitations under the License. 
15 | 16 | # helpers functions 17 | 18 | import torch 19 | from torch import nn 20 | 21 | from ..configuration_utils import ConfigMixin 22 | from ..modeling_utils import ModelMixin 23 | from .attention import AttentionBlock 24 | from .embeddings import get_timestep_embedding 25 | from .resnet import Downsample2D, ResnetBlock2D, Upsample2D 26 | from .unet_new import UNetMidBlock2D 27 | 28 | 29 | def nonlinearity(x): 30 | # swish 31 | return x * torch.sigmoid(x) 32 | 33 | 34 | def Normalize(in_channels): 35 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 36 | 37 | 38 | class UNetModel(ModelMixin, ConfigMixin): 39 | def __init__( 40 | self, 41 | ch=128, 42 | out_ch=3, 43 | ch_mult=(1, 1, 2, 2, 4, 4), 44 | num_res_blocks=2, 45 | attn_resolutions=(16,), 46 | dropout=0.0, 47 | resamp_with_conv=True, 48 | in_channels=3, 49 | resolution=256, 50 | ): 51 | super().__init__() 52 | self.register_to_config( 53 | ch=ch, 54 | out_ch=out_ch, 55 | ch_mult=ch_mult, 56 | num_res_blocks=num_res_blocks, 57 | attn_resolutions=attn_resolutions, 58 | dropout=dropout, 59 | resamp_with_conv=resamp_with_conv, 60 | in_channels=in_channels, 61 | resolution=resolution, 62 | ) 63 | ch_mult = tuple(ch_mult) 64 | self.ch = ch 65 | self.temb_ch = self.ch * 4 66 | self.num_resolutions = len(ch_mult) 67 | self.num_res_blocks = num_res_blocks 68 | self.resolution = resolution 69 | self.in_channels = in_channels 70 | 71 | # timestep embedding 72 | self.temb = nn.Module() 73 | self.temb.dense = nn.ModuleList( 74 | [ 75 | torch.nn.Linear(self.ch, self.temb_ch), 76 | torch.nn.Linear(self.temb_ch, self.temb_ch), 77 | ] 78 | ) 79 | 80 | # downsampling 81 | self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 82 | 83 | curr_res = resolution 84 | in_ch_mult = (1,) + ch_mult 85 | self.down = nn.ModuleList() 86 | for i_level in range(self.num_resolutions): 87 | block = nn.ModuleList() 88 | attn = nn.ModuleList() 89 | block_in = ch * in_ch_mult[i_level] 90 | block_out = ch * ch_mult[i_level] 91 | for i_block in range(self.num_res_blocks): 92 | block.append( 93 | ResnetBlock2D( 94 | in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout 95 | ) 96 | ) 97 | block_in = block_out 98 | if curr_res in attn_resolutions: 99 | attn.append(AttentionBlock(block_in, overwrite_qkv=True)) 100 | down = nn.Module() 101 | down.block = block 102 | down.attn = attn 103 | if i_level != self.num_resolutions - 1: 104 | down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0) 105 | curr_res = curr_res // 2 106 | self.down.append(down) 107 | 108 | # middle 109 | self.mid = nn.Module() 110 | self.mid.block_1 = ResnetBlock2D( 111 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout 112 | ) 113 | self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) 114 | self.mid.block_2 = ResnetBlock2D( 115 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout 116 | ) 117 | self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 118 | self.mid_new.resnets[0] = self.mid.block_1 119 | self.mid_new.attentions[0] = self.mid.attn_1 120 | self.mid_new.resnets[1] = self.mid.block_2 121 | 122 | # upsampling 123 | self.up = nn.ModuleList() 124 | for i_level in reversed(range(self.num_resolutions)): 125 | block = nn.ModuleList() 126 | attn = nn.ModuleList() 127 | block_out = ch * ch_mult[i_level] 128 | skip_in = ch * 
ch_mult[i_level] 129 | for i_block in range(self.num_res_blocks + 1): 130 | if i_block == self.num_res_blocks: 131 | skip_in = ch * in_ch_mult[i_level] 132 | block.append( 133 | ResnetBlock2D( 134 | in_channels=block_in + skip_in, 135 | out_channels=block_out, 136 | temb_channels=self.temb_ch, 137 | dropout=dropout, 138 | ) 139 | ) 140 | block_in = block_out 141 | if curr_res in attn_resolutions: 142 | attn.append(AttentionBlock(block_in, overwrite_qkv=True)) 143 | up = nn.Module() 144 | up.block = block 145 | up.attn = attn 146 | if i_level != 0: 147 | up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv) 148 | curr_res = curr_res * 2 149 | self.up.insert(0, up) # prepend to get consistent order 150 | 151 | # end 152 | self.norm_out = Normalize(block_in) 153 | self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 154 | 155 | def forward(self, sample, timesteps): 156 | x = sample 157 | assert x.shape[2] == x.shape[3] == self.resolution 158 | 159 | if not torch.is_tensor(timesteps): 160 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) 161 | 162 | # timestep embedding 163 | temb = get_timestep_embedding(timesteps, self.ch) 164 | temb = self.temb.dense[0](temb) 165 | temb = nonlinearity(temb) 166 | temb = self.temb.dense[1](temb) 167 | 168 | # downsampling 169 | hs = [self.conv_in(x)] 170 | for i_level in range(self.num_resolutions): 171 | for i_block in range(self.num_res_blocks): 172 | h = self.down[i_level].block[i_block](hs[-1], temb) 173 | if len(self.down[i_level].attn) > 0: 174 | h = self.down[i_level].attn[i_block](h) 175 | hs.append(h) 176 | if i_level != self.num_resolutions - 1: 177 | hs.append(self.down[i_level].downsample(hs[-1])) 178 | 179 | # middle 180 | h = self.mid_new(hs[-1], temb) 181 | 182 | # upsampling 183 | for i_level in reversed(range(self.num_resolutions)): 184 | for i_block in range(self.num_res_blocks + 1): 185 | h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) 186 | if len(self.up[i_level].attn) > 0: 187 | h = self.up[i_level].attn[i_block](h) 188 | if i_level != 0: 189 | h = self.up[i_level].upsample(h) 190 | 191 | # end 192 | h = self.norm_out(h) 193 | h = nonlinearity(h) 194 | h = self.conv_out(h) 195 | return h 196 | -------------------------------------------------------------------------------- /src/diffusers/models/unet_grad_tts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..configuration_utils import ConfigMixin 4 | from ..modeling_utils import ModelMixin 5 | from .attention import LinearAttention 6 | from .embeddings import get_timestep_embedding 7 | from .resnet import Downsample2D, ResnetBlock2D, Upsample2D 8 | from .unet_new import UNetMidBlock2D 9 | 10 | 11 | class Mish(torch.nn.Module): 12 | def forward(self, x): 13 | return x * torch.tanh(torch.nn.functional.softplus(x)) 14 | 15 | 16 | class Rezero(torch.nn.Module): 17 | def __init__(self, fn): 18 | super(Rezero, self).__init__() 19 | self.fn = fn 20 | self.g = torch.nn.Parameter(torch.zeros(1)) 21 | 22 | def forward(self, x, encoder_out=None): 23 | return self.fn(x, encoder_out) * self.g 24 | 25 | 26 | class Block(torch.nn.Module): 27 | def __init__(self, dim, dim_out, groups=8): 28 | super(Block, self).__init__() 29 | self.block = torch.nn.Sequential( 30 | torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish() 31 | ) 32 | 33 | def forward(self, x, mask): 34 | output = self.block(x * mask) 35 | return 
output * mask 36 | 37 | 38 | class Residual(torch.nn.Module): 39 | def __init__(self, fn): 40 | super(Residual, self).__init__() 41 | self.fn = fn 42 | 43 | def forward(self, x, *args, **kwargs): 44 | output = self.fn(x, *args, **kwargs) + x 45 | return output 46 | 47 | 48 | class UNetGradTTSModel(ModelMixin, ConfigMixin): 49 | def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): 50 | super(UNetGradTTSModel, self).__init__() 51 | 52 | self.register_to_config( 53 | dim=dim, 54 | dim_mults=dim_mults, 55 | groups=groups, 56 | n_spks=n_spks, 57 | spk_emb_dim=spk_emb_dim, 58 | n_feats=n_feats, 59 | pe_scale=pe_scale, 60 | ) 61 | 62 | self.dim = dim 63 | self.dim_mults = dim_mults 64 | self.groups = groups 65 | self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1 66 | self.spk_emb_dim = spk_emb_dim 67 | self.pe_scale = pe_scale 68 | 69 | if n_spks > 1: 70 | self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) 71 | self.spk_mlp = torch.nn.Sequential( 72 | torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) 73 | ) 74 | 75 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) 76 | 77 | dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] 78 | in_out = list(zip(dims[:-1], dims[1:])) 79 | self.downs = torch.nn.ModuleList([]) 80 | self.ups = torch.nn.ModuleList([]) 81 | num_resolutions = len(in_out) 82 | 83 | for ind, (dim_in, dim_out) in enumerate(in_out): 84 | is_last = ind >= (num_resolutions - 1) 85 | self.downs.append( 86 | torch.nn.ModuleList( 87 | [ 88 | ResnetBlock2D( 89 | in_channels=dim_in, 90 | out_channels=dim_out, 91 | temb_channels=dim, 92 | groups=8, 93 | pre_norm=False, 94 | eps=1e-5, 95 | non_linearity="mish", 96 | overwrite_for_grad_tts=True, 97 | ), 98 | ResnetBlock2D( 99 | in_channels=dim_out, 100 | out_channels=dim_out, 101 | temb_channels=dim, 102 | groups=8, 103 | pre_norm=False, 104 | eps=1e-5, 105 | non_linearity="mish", 106 | overwrite_for_grad_tts=True, 107 | ), 108 | Residual(Rezero(LinearAttention(dim_out))), 109 | Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(), 110 | ] 111 | ) 112 | ) 113 | 114 | mid_dim = dims[-1] 115 | 116 | self.mid = UNetMidBlock2D( 117 | in_channels=mid_dim, 118 | temb_channels=dim, 119 | resnet_groups=8, 120 | resnet_pre_norm=False, 121 | resnet_eps=1e-5, 122 | resnet_act_fn="mish", 123 | attention_layer_type="linear", 124 | ) 125 | 126 | self.mid_block1 = ResnetBlock2D( 127 | in_channels=mid_dim, 128 | out_channels=mid_dim, 129 | temb_channels=dim, 130 | groups=8, 131 | pre_norm=False, 132 | eps=1e-5, 133 | non_linearity="mish", 134 | overwrite_for_grad_tts=True, 135 | ) 136 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) 137 | self.mid_block2 = ResnetBlock2D( 138 | in_channels=mid_dim, 139 | out_channels=mid_dim, 140 | temb_channels=dim, 141 | groups=8, 142 | pre_norm=False, 143 | eps=1e-5, 144 | non_linearity="mish", 145 | overwrite_for_grad_tts=True, 146 | ) 147 | self.mid.resnets[0] = self.mid_block1 148 | self.mid.attentions[0] = self.mid_attn 149 | self.mid.resnets[1] = self.mid_block2 150 | 151 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 152 | self.ups.append( 153 | torch.nn.ModuleList( 154 | [ 155 | ResnetBlock2D( 156 | in_channels=dim_out * 2, 157 | out_channels=dim_in, 158 | temb_channels=dim, 159 | groups=8, 160 | pre_norm=False, 161 | eps=1e-5, 162 | non_linearity="mish", 163 | 
overwrite_for_grad_tts=True, 164 | ), 165 | ResnetBlock2D( 166 | in_channels=dim_in, 167 | out_channels=dim_in, 168 | temb_channels=dim, 169 | groups=8, 170 | pre_norm=False, 171 | eps=1e-5, 172 | non_linearity="mish", 173 | overwrite_for_grad_tts=True, 174 | ), 175 | Residual(Rezero(LinearAttention(dim_in))), 176 | Upsample2D(dim_in, use_conv_transpose=True), 177 | ] 178 | ) 179 | ) 180 | self.final_block = Block(dim, dim) 181 | self.final_conv = torch.nn.Conv2d(dim, 1, 1) 182 | 183 | def forward(self, sample, timesteps, mu, mask, spk=None): 184 | x = sample 185 | if self.n_spks > 1: 186 | # Get speaker embedding 187 | spk = self.spk_emb(spk) 188 | 189 | if not isinstance(spk, type(None)): 190 | s = self.spk_mlp(spk) 191 | 192 | t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale) 193 | t = self.mlp(t) 194 | 195 | if self.n_spks < 2: 196 | x = torch.stack([mu, x], 1) 197 | else: 198 | s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1]) 199 | x = torch.stack([mu, x, s], 1) 200 | mask = mask.unsqueeze(1) 201 | 202 | hiddens = [] 203 | masks = [mask] 204 | for resnet1, resnet2, attn, downsample in self.downs: 205 | mask_down = masks[-1] 206 | x = resnet1(x, t, mask_down) 207 | x = resnet2(x, t, mask_down) 208 | x = attn(x) 209 | hiddens.append(x) 210 | x = downsample(x * mask_down) 211 | masks.append(mask_down[:, :, :, ::2]) 212 | 213 | masks = masks[:-1] 214 | mask_mid = masks[-1] 215 | 216 | x = self.mid(x, t, mask=mask_mid) 217 | 218 | for resnet1, resnet2, attn, upsample in self.ups: 219 | mask_up = masks.pop() 220 | x = torch.cat((x, hiddens.pop()), dim=1) 221 | x = resnet1(x, t, mask_up) 222 | x = resnet2(x, t, mask_up) 223 | x = attn(x) 224 | x = upsample(x * mask_up) 225 | 226 | x = self.final_block(x, mask) 227 | output = self.final_conv(x * mask) 228 | 229 | return (output * mask).squeeze(1) 230 | -------------------------------------------------------------------------------- /src/diffusers/models/unet_rl.py: -------------------------------------------------------------------------------- 1 | # model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..configuration_utils import ConfigMixin 7 | from ..modeling_utils import ModelMixin 8 | from .embeddings import get_timestep_embedding 9 | from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D 10 | 11 | 12 | class SinusoidalPosEmb(nn.Module): 13 | def __init__(self, dim): 14 | super().__init__() 15 | self.dim = dim 16 | 17 | def forward(self, x): 18 | return get_timestep_embedding(x, self.dim) 19 | 20 | 21 | class RearrangeDim(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward(self, tensor): 26 | if len(tensor.shape) == 2: 27 | return tensor[:, :, None] 28 | if len(tensor.shape) == 3: 29 | return tensor[:, :, None, :] 30 | elif len(tensor.shape) == 4: 31 | return tensor[:, :, 0, :] 32 | else: 33 | raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") 34 | 35 | 36 | class Conv1dBlock(nn.Module): 37 | """ 38 | Conv1d --> GroupNorm --> Mish 39 | """ 40 | 41 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 42 | super().__init__() 43 | 44 | self.block = nn.Sequential( 45 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 46 | RearrangeDim(), 47 | # Rearrange("batch channels horizon -> batch channels 1 horizon"), 48 | nn.GroupNorm(n_groups, out_channels), 49 | RearrangeDim(), 50 | # Rearrange("batch channels 1 
horizon -> batch channels horizon"), 51 | nn.Mish(), 52 | ) 53 | 54 | def forward(self, x): 55 | return self.block(x) 56 | 57 | 58 | class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): 59 | def __init__( 60 | self, 61 | training_horizon=128, 62 | transition_dim=14, 63 | cond_dim=3, 64 | predict_epsilon=False, 65 | clip_denoised=True, 66 | dim=32, 67 | dim_mults=(1, 4, 8), 68 | ): 69 | super().__init__() 70 | 71 | self.transition_dim = transition_dim 72 | self.cond_dim = cond_dim 73 | self.predict_epsilon = predict_epsilon 74 | self.clip_denoised = clip_denoised 75 | 76 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 77 | in_out = list(zip(dims[:-1], dims[1:])) 78 | 79 | time_dim = dim 80 | self.time_mlp = nn.Sequential( 81 | SinusoidalPosEmb(dim), 82 | nn.Linear(dim, dim * 4), 83 | nn.Mish(), 84 | nn.Linear(dim * 4, dim), 85 | ) 86 | 87 | self.downs = nn.ModuleList([]) 88 | self.ups = nn.ModuleList([]) 89 | num_resolutions = len(in_out) 90 | 91 | for ind, (dim_in, dim_out) in enumerate(in_out): 92 | is_last = ind >= (num_resolutions - 1) 93 | 94 | self.downs.append( 95 | nn.ModuleList( 96 | [ 97 | ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon), 98 | ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon), 99 | Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(), 100 | ] 101 | ) 102 | ) 103 | 104 | if not is_last: 105 | training_horizon = training_horizon // 2 106 | 107 | mid_dim = dims[-1] 108 | self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) 109 | self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) 110 | 111 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 112 | is_last = ind >= (num_resolutions - 1) 113 | 114 | self.ups.append( 115 | nn.ModuleList( 116 | [ 117 | ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon), 118 | ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon), 119 | Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(), 120 | ] 121 | ) 122 | ) 123 | 124 | if not is_last: 125 | training_horizon = training_horizon * 2 126 | 127 | self.final_conv = nn.Sequential( 128 | Conv1dBlock(dim, dim, kernel_size=5), 129 | nn.Conv1d(dim, transition_dim, 1), 130 | ) 131 | 132 | def forward(self, sample, timesteps): 133 | """ 134 | x : [ batch x horizon x transition ] 135 | """ 136 | x = sample 137 | 138 | x = x.permute(0, 2, 1) 139 | 140 | t = self.time_mlp(timesteps) 141 | h = [] 142 | 143 | for resnet, resnet2, downsample in self.downs: 144 | x = resnet(x, t) 145 | x = resnet2(x, t) 146 | h.append(x) 147 | x = downsample(x) 148 | 149 | x = self.mid_block1(x, t) 150 | x = self.mid_block2(x, t) 151 | 152 | for resnet, resnet2, upsample in self.ups: 153 | x = torch.cat((x, h.pop()), dim=1) 154 | x = resnet(x, t) 155 | x = resnet2(x, t) 156 | x = upsample(x) 157 | 158 | x = self.final_conv(x) 159 | 160 | x = x.permute(0, 2, 1) 161 | return x 162 | 163 | 164 | class TemporalValue(nn.Module): 165 | def __init__( 166 | self, 167 | horizon, 168 | transition_dim, 169 | cond_dim, 170 | dim=32, 171 | time_dim=None, 172 | out_dim=1, 173 | dim_mults=(1, 2, 4, 8), 174 | ): 175 | super().__init__() 176 | 177 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 178 | in_out = list(zip(dims[:-1], dims[1:])) 179 | 180 | time_dim = time_dim or dim 181 | self.time_mlp = 
nn.Sequential( 182 | SinusoidalPosEmb(dim), 183 | nn.Linear(dim, dim * 4), 184 | nn.Mish(), 185 | nn.Linear(dim * 4, dim), 186 | ) 187 | 188 | self.blocks = nn.ModuleList([]) 189 | 190 | print(in_out) 191 | for dim_in, dim_out in in_out: 192 | self.blocks.append( 193 | nn.ModuleList( 194 | [ 195 | ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon), 196 | ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon), 197 | Downsample1d(dim_out), 198 | ] 199 | ) 200 | ) 201 | 202 | horizon = horizon // 2 203 | 204 | fc_dim = dims[-1] * max(horizon, 1) 205 | 206 | self.final_block = nn.Sequential( 207 | nn.Linear(fc_dim + time_dim, fc_dim // 2), 208 | nn.Mish(), 209 | nn.Linear(fc_dim // 2, out_dim), 210 | ) 211 | 212 | def forward(self, x, cond, time, *args): 213 | """ 214 | x : [ batch x horizon x transition ] 215 | """ 216 | x = x.permute(0, 2, 1) 217 | 218 | t = self.time_mlp(time) 219 | 220 | for resnet, resnet2, downsample in self.blocks: 221 | x = resnet(x, t) 222 | x = resnet2(x, t) 223 | x = downsample(x) 224 | 225 | x = x.view(len(x), -1) 226 | out = self.final_block(torch.cat([x, t], dim=-1)) 227 | return out 228 | -------------------------------------------------------------------------------- /src/diffusers/pipeline_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
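# DiffusionPipeline (defined below) is the base class for all pipelines: register_modules()
# records each sub-module (model, scheduler, tokenizer, ...) in the model_index.json config,
# save_pretrained() writes every sub-module to its own sub-folder, and from_pretrained()
# downloads a checkpoint from the Hub (or loads a local folder) and re-assembles the pipeline.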
16 | 17 | import importlib 18 | import os 19 | from typing import Optional, Union 20 | 21 | from huggingface_hub import snapshot_download 22 | 23 | from .configuration_utils import ConfigMixin 24 | from .utils import DIFFUSERS_CACHE, logging 25 | 26 | 27 | INDEX_FILE = "diffusion_model.pt" 28 | 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | 33 | LOADABLE_CLASSES = { 34 | "diffusers": { 35 | "ModelMixin": ["save_pretrained", "from_pretrained"], 36 | "SchedulerMixin": ["save_config", "from_config"], 37 | "DiffusionPipeline": ["save_pretrained", "from_pretrained"], 38 | }, 39 | "transformers": { 40 | "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], 41 | "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], 42 | "PreTrainedModel": ["save_pretrained", "from_pretrained"], 43 | }, 44 | } 45 | 46 | ALL_IMPORTABLE_CLASSES = {} 47 | for library in LOADABLE_CLASSES: 48 | ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) 49 | 50 | 51 | class DiffusionPipeline(ConfigMixin): 52 | 53 | config_name = "model_index.json" 54 | 55 | def register_modules(self, **kwargs): 56 | # import it here to avoid circular import 57 | from diffusers import pipelines 58 | 59 | for name, module in kwargs.items(): 60 | # retrive library 61 | library = module.__module__.split(".")[0] 62 | 63 | # check if the module is a pipeline module 64 | pipeline_file = module.__module__.split(".")[-1] 65 | pipeline_dir = module.__module__.split(".")[-2] 66 | is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir) 67 | 68 | # if library is not in LOADABLE_CLASSES, then it is a custom module. 69 | # Or if it's a pipeline module, then the module is inside the pipeline 70 | # folder so we set the library to module name. 
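# e.g. (illustrative): a scheduler defined in diffusers.schedulers keeps library="diffusers",
# while a module defined in diffusers.pipelines.ddpm.pipeline_ddpm gets library="ddpm",
# the name of its pipeline folder.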
71 | if library not in LOADABLE_CLASSES or is_pipeline_module: 72 | library = pipeline_dir 73 | 74 | # retrive class_name 75 | class_name = module.__class__.__name__ 76 | 77 | register_dict = {name: (library, class_name)} 78 | 79 | # save model index config 80 | self.register_to_config(**register_dict) 81 | 82 | # set models 83 | setattr(self, name, module) 84 | 85 | def save_pretrained(self, save_directory: Union[str, os.PathLike]): 86 | self.save_config(save_directory) 87 | 88 | model_index_dict = dict(self.config) 89 | model_index_dict.pop("_class_name") 90 | model_index_dict.pop("_diffusers_version") 91 | model_index_dict.pop("_module", None) 92 | 93 | for pipeline_component_name in model_index_dict.keys(): 94 | sub_model = getattr(self, pipeline_component_name) 95 | model_cls = sub_model.__class__ 96 | 97 | save_method_name = None 98 | # search for the model's base class in LOADABLE_CLASSES 99 | for library_name, library_classes in LOADABLE_CLASSES.items(): 100 | library = importlib.import_module(library_name) 101 | for base_class, save_load_methods in library_classes.items(): 102 | class_candidate = getattr(library, base_class) 103 | if issubclass(model_cls, class_candidate): 104 | # if we found a suitable base class in LOADABLE_CLASSES then grab its save method 105 | save_method_name = save_load_methods[0] 106 | break 107 | if save_method_name is not None: 108 | break 109 | 110 | save_method = getattr(sub_model, save_method_name) 111 | save_method(os.path.join(save_directory, pipeline_component_name)) 112 | 113 | @classmethod 114 | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): 115 | r""" 116 | Add docstrings 117 | """ 118 | cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) 119 | resume_download = kwargs.pop("resume_download", False) 120 | proxies = kwargs.pop("proxies", None) 121 | local_files_only = kwargs.pop("local_files_only", False) 122 | use_auth_token = kwargs.pop("use_auth_token", None) 123 | 124 | # 1. Download the checkpoints and configs 125 | # use snapshot download here to get it working from from_pretrained 126 | if not os.path.isdir(pretrained_model_name_or_path): 127 | cached_folder = snapshot_download( 128 | pretrained_model_name_or_path, 129 | cache_dir=cache_dir, 130 | resume_download=resume_download, 131 | proxies=proxies, 132 | local_files_only=local_files_only, 133 | use_auth_token=use_auth_token, 134 | ) 135 | else: 136 | cached_folder = pretrained_model_name_or_path 137 | 138 | config_dict = cls.get_config_dict(cached_folder) 139 | 140 | # 2. Load the pipeline class, if using custom module then load it from the hub 141 | # if we load from explicit class, let's use it 142 | if cls != DiffusionPipeline: 143 | pipeline_class = cls 144 | else: 145 | diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) 146 | pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) 147 | 148 | init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) 149 | 150 | init_kwargs = {} 151 | 152 | # import it here to avoid circular import 153 | from diffusers import pipelines 154 | 155 | # 3. 
Load each module in the pipeline 156 | for name, (library_name, class_name) in init_dict.items(): 157 | is_pipeline_module = hasattr(pipelines, library_name) 158 | # if the model is in a pipeline module, then we load it from the pipeline 159 | if is_pipeline_module: 160 | pipeline_module = getattr(pipelines, library_name) 161 | class_obj = getattr(pipeline_module, class_name) 162 | importable_classes = ALL_IMPORTABLE_CLASSES 163 | class_candidates = {c: class_obj for c in importable_classes.keys()} 164 | else: 165 | # else we just import it from the library. 166 | library = importlib.import_module(library_name) 167 | class_obj = getattr(library, class_name) 168 | importable_classes = LOADABLE_CLASSES[library_name] 169 | class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} 170 | 171 | load_method_name = None 172 | for class_name, class_candidate in class_candidates.items(): 173 | if issubclass(class_obj, class_candidate): 174 | load_method_name = importable_classes[class_name][1] 175 | 176 | load_method = getattr(class_obj, load_method_name) 177 | 178 | # check if the module is in a subdirectory 179 | if os.path.isdir(os.path.join(cached_folder, name)): 180 | loaded_sub_model = load_method(os.path.join(cached_folder, name)) 181 | else: 182 | # else load from the root directory 183 | loaded_sub_model = load_method(cached_folder) 184 | 185 | init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) 186 | 187 | # 5. Instantiate the pipeline 188 | model = pipeline_class(**init_kwargs) 189 | return model 190 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/README.md: -------------------------------------------------------------------------------- 1 | # Pipelines 2 | 3 | - Pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box 4 | - Pipelines should stay as close as possible to their original implementation 5 | - Pipelines can include components of other library, such as text-encoders. 6 | 7 | ## API 8 | 9 | TODO(Patrick, Anton, Suraj) 10 | 11 | ## Examples 12 | 13 | - DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py). 14 | - DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py). 15 | - PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py). 16 | - Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py). 17 | - Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py). 18 | - BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py). 19 | - Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py). 
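A minimal usage sketch for the unconditional pipelines (the checkpoint name below is hypothetical; substitute any Hub repository that contains a saved `DDPMPipeline`):

```python
import torch

from diffusers import DDPMPipeline

# "org/ddpm-example" is a placeholder repository id, shown for illustration only
pipeline = DDPMPipeline.from_pretrained("org/ddpm-example")

generator = torch.manual_seed(0)
image = pipeline(batch_size=1, generator=generator)  # tensor of shape (batch, channels, resolution, resolution)
```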
20 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available 2 | from .bddm import BDDMPipeline 3 | from .ddim import DDIMPipeline 4 | from .ddpm import DDPMPipeline 5 | from .latent_diffusion_uncond import LatentDiffusionUncondPipeline 6 | from .pndm import PNDMPipeline 7 | from .score_sde_ve import ScoreSdeVePipeline 8 | from .score_sde_vp import ScoreSdeVpPipeline 9 | 10 | 11 | if is_transformers_available(): 12 | from .glide import GlidePipeline 13 | from .latent_diffusion import LatentDiffusionPipeline 14 | 15 | 16 | if is_transformers_available() and is_unidecode_available() and is_inflect_available(): 17 | from .grad_tts import GradTTSPipeline 18 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/bddm/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_bddm import BDDMPipeline, DiffWave 2 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/ddim/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_ddim import DDIMPipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/ddim/pipeline_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | 14 | # limitations under the License. 
15 | 16 | 17 | import torch 18 | 19 | import tqdm 20 | 21 | from ...pipeline_utils import DiffusionPipeline 22 | 23 | 24 | class DDIMPipeline(DiffusionPipeline): 25 | def __init__(self, unet, noise_scheduler): 26 | super().__init__() 27 | noise_scheduler = noise_scheduler.set_format("pt") 28 | self.register_modules(unet=unet, noise_scheduler=noise_scheduler) 29 | 30 | def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50): 31 | # eta corresponds to η in paper and should be between [0, 1] 32 | if torch_device is None: 33 | torch_device = "cuda" if torch.cuda.is_available() else "cpu" 34 | 35 | num_trained_timesteps = self.noise_scheduler.config.timesteps 36 | inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) 37 | 38 | self.unet.to(torch_device) 39 | 40 | # Sample gaussian noise to begin loop 41 | image = torch.randn( 42 | (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), 43 | generator=generator, 44 | ) 45 | image = image.to(torch_device) 46 | 47 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 48 | # Ideally, read DDIM paper in-detail understanding 49 | 50 | # Notation ( -> 51 | # - pred_noise_t -> e_theta(x_t, t) 52 | # - pred_original_image -> f_theta(x_t, t) or x_0 53 | # - std_dev_t -> sigma_t 54 | # - eta -> η 55 | # - pred_image_direction -> "direction pointingc to x_t" 56 | # - pred_prev_image -> "x_t-1" 57 | for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): 58 | # 1. predict noise residual 59 | with torch.no_grad(): 60 | residual = self.unet(image, inference_step_times[t]) 61 | 62 | # 2. predict previous mean of image x_t-1 63 | pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta) 64 | 65 | # 3. optionally sample variance 66 | variance = 0 67 | if eta > 0: 68 | noise = torch.randn(image.shape, generator=generator).to(image.device) 69 | variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise 70 | 71 | # 4. set current image to prev_image: x_t -> x_t-1 72 | image = pred_prev_image + variance 73 | 74 | return image 75 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_ddpm import DDPMPipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/ddpm/pipeline_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | 14 | # limitations under the License. 
15 | 16 | 17 | import torch 18 | 19 | import tqdm 20 | 21 | from ...pipeline_utils import DiffusionPipeline 22 | 23 | 24 | class DDPMPipeline(DiffusionPipeline): 25 | def __init__(self, unet, noise_scheduler): 26 | super().__init__() 27 | noise_scheduler = noise_scheduler.set_format("pt") 28 | self.register_modules(unet=unet, noise_scheduler=noise_scheduler) 29 | 30 | def __call__(self, batch_size=1, generator=None, torch_device=None): 31 | if torch_device is None: 32 | torch_device = "cuda" if torch.cuda.is_available() else "cpu" 33 | 34 | self.unet.to(torch_device) 35 | 36 | # Sample gaussian noise to begin loop 37 | image = torch.randn( 38 | (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), 39 | generator=generator, 40 | ) 41 | image = image.to(torch_device) 42 | 43 | num_prediction_steps = len(self.noise_scheduler) 44 | for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): 45 | # 1. predict noise residual 46 | with torch.no_grad(): 47 | residual = self.unet(image, t) 48 | 49 | # 2. predict previous mean of image x_t-1 50 | pred_prev_image = self.noise_scheduler.step(residual, image, t) 51 | 52 | # 3. optionally sample variance 53 | variance = 0 54 | if t > 0: 55 | noise = torch.randn(image.shape, generator=generator).to(image.device) 56 | variance = self.noise_scheduler.get_variance(t).sqrt() * noise 57 | 58 | # 4. set current image to prev_image: x_t -> x_t-1 59 | image = pred_prev_image + variance 60 | 61 | return image 62 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/glide/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_transformers_available 2 | 3 | 4 | if is_transformers_available(): 5 | from .pipeline_glide import CLIPTextModel, GlidePipeline 6 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/grad_tts/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_inflect_available, is_transformers_available, is_unidecode_available 2 | 3 | 4 | if is_transformers_available() and is_unidecode_available() and is_inflect_available(): 5 | from .grad_tts_utils import GradTTSTokenizer 6 | from .pipeline_grad_tts import GradTTSPipeline, TextEncoder 7 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_transformers_available 2 | 3 | 4 | if is_transformers_available(): 5 | from .pipeline_latent_diffusion import LatentDiffusionPipeline, LDMBertModel 6 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/latent_diffusion_uncond/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import tqdm 4 | 5 | from ...pipeline_utils import DiffusionPipeline 6 | 7 | 8 | class LatentDiffusionUncondPipeline(DiffusionPipeline): 9 | def __init__(self, vqvae, 
unet, noise_scheduler): 10 | super().__init__() 11 | noise_scheduler = noise_scheduler.set_format("pt") 12 | self.register_modules(vqvae=vqvae, unet=unet, noise_scheduler=noise_scheduler) 13 | 14 | @torch.no_grad() 15 | def __call__( 16 | self, 17 | batch_size=1, 18 | generator=None, 19 | torch_device=None, 20 | eta=0.0, 21 | num_inference_steps=50, 22 | ): 23 | # eta corresponds to η in paper and should be between [0, 1] 24 | 25 | if torch_device is None: 26 | torch_device = "cuda" if torch.cuda.is_available() else "cpu" 27 | 28 | self.unet.to(torch_device) 29 | self.vqvae.to(torch_device) 30 | 31 | num_trained_timesteps = self.noise_scheduler.config.timesteps 32 | inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) 33 | 34 | image = torch.randn( 35 | (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), 36 | generator=generator, 37 | ).to(torch_device) 38 | 39 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 40 | # Ideally, read DDIM paper in-detail understanding 41 | 42 | # Notation ( -> 43 | # - pred_noise_t -> e_theta(x_t, t) 44 | # - pred_original_image -> f_theta(x_t, t) or x_0 45 | # - std_dev_t -> sigma_t 46 | # - eta -> η 47 | # - pred_image_direction -> "direction pointingc to x_t" 48 | # - pred_prev_image -> "x_t-1" 49 | for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): 50 | # 1. predict noise residual 51 | timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device) 52 | pred_noise_t = self.unet(image, timesteps) 53 | 54 | # 2. predict previous mean of image x_t-1 55 | pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) 56 | 57 | # 3. optionally sample variance 58 | variance = 0 59 | if eta > 0: 60 | noise = torch.randn(image.shape, generator=generator).to(image.device) 61 | variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise 62 | 63 | # 4. set current image to prev_image: x_t -> x_t-1 64 | image = pred_prev_image + variance 65 | 66 | # decode image with vae 67 | image = self.vqvae.decode(image) 68 | return image 69 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/pndm/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_pndm import PNDMPipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/pndm/pipeline_pndm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | 14 | # limitations under the License. 
15 | 16 | 17 | import torch 18 | 19 | import tqdm 20 | 21 | from ...pipeline_utils import DiffusionPipeline 22 | 23 | 24 | class PNDMPipeline(DiffusionPipeline): 25 | def __init__(self, unet, noise_scheduler): 26 | super().__init__() 27 | noise_scheduler = noise_scheduler.set_format("pt") 28 | self.register_modules(unet=unet, noise_scheduler=noise_scheduler) 29 | 30 | def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): 31 | # For more information on the sampling method you can take a look at Algorithm 2 of 32 | # the official paper: https://arxiv.org/pdf/2202.09778.pdf 33 | if torch_device is None: 34 | torch_device = "cuda" if torch.cuda.is_available() else "cpu" 35 | 36 | self.unet.to(torch_device) 37 | 38 | # Sample gaussian noise to begin loop 39 | image = torch.randn( 40 | (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), 41 | generator=generator, 42 | ) 43 | image = image.to(torch_device) 44 | 45 | prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps) 46 | for t in tqdm.tqdm(range(len(prk_time_steps))): 47 | t_orig = prk_time_steps[t] 48 | residual = self.unet(image, t_orig) 49 | 50 | image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps) 51 | 52 | timesteps = self.noise_scheduler.get_time_steps(num_inference_steps) 53 | for t in tqdm.tqdm(range(len(timesteps))): 54 | t_orig = timesteps[t] 55 | residual = self.unet(image, t_orig) 56 | 57 | image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps) 58 | 59 | return image 60 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/score_sde_ve/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_score_sde_ve import ScoreSdeVePipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | 4 | from diffusers import DiffusionPipeline 5 | 6 | 7 | # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names 8 | class ScoreSdeVePipeline(DiffusionPipeline): 9 | def __init__(self, model, scheduler): 10 | super().__init__() 11 | self.register_modules(model=model, scheduler=scheduler) 12 | 13 | def __call__(self, num_inference_steps=2000, generator=None): 14 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 15 | 16 | img_size = self.model.config.image_size 17 | channels = self.model.config.num_channels 18 | shape = (1, channels, img_size, img_size) 19 | 20 | model = self.model.to(device) 21 | 22 | # TODO(Patrick) move to scheduler config 23 | n_steps = 1 24 | 25 | x = torch.randn(*shape) * self.scheduler.config.sigma_max 26 | x = x.to(device) 27 | 28 | self.scheduler.set_timesteps(num_inference_steps) 29 | self.scheduler.set_sigmas(num_inference_steps) 30 | 31 | for i, t in enumerate(self.scheduler.timesteps): 32 | sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) 33 | 34 | for _ in range(n_steps): 35 | with torch.no_grad(): 36 | result = self.model(x, sigma_t) 37 | x = self.scheduler.step_correct(result, x) 38 | 39 | with torch.no_grad(): 40 | result = model(x, sigma_t) 41 | 42 | x, x_mean = self.scheduler.step_pred(result, x, t) 43 | 44 | return x_mean 45 | 
-------------------------------------------------------------------------------- /src/diffusers/pipelines/score_sde_vp/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_score_sde_vp import ScoreSdeVpPipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | 4 | from diffusers import DiffusionPipeline 5 | 6 | 7 | # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names 8 | class ScoreSdeVpPipeline(DiffusionPipeline): 9 | def __init__(self, model, scheduler): 10 | super().__init__() 11 | self.register_modules(model=model, scheduler=scheduler) 12 | 13 | def __call__(self, num_inference_steps=1000, generator=None): 14 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 15 | 16 | img_size = self.model.config.image_size 17 | channels = self.model.config.num_channels 18 | shape = (1, channels, img_size, img_size) 19 | 20 | model = self.model.to(device) 21 | 22 | x = torch.randn(*shape).to(device) 23 | 24 | self.scheduler.set_timesteps(num_inference_steps) 25 | 26 | for t in self.scheduler.timesteps: 27 | t = t * torch.ones(shape[0], device=device) 28 | scaled_t = t * (num_inference_steps - 1) 29 | 30 | with torch.no_grad(): 31 | result = model(x, scaled_t) 32 | 33 | x, x_mean = self.scheduler.step_pred(result, x, t) 34 | 35 | x_mean = (x_mean + 1.0) / 2.0 36 | 37 | return x_mean 38 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/README.md: -------------------------------------------------------------------------------- 1 | # Schedulers 2 | 3 | - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps. 4 | - Schedulers can be used interchangable between diffusion models in inference to find the preferred tradef-off between speed and generation quality. 5 | - Schedulers are available in numpy, but can easily be transformed into PyTorch. 6 | 7 | ## API 8 | 9 | - Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during 10 | the forward pass. 11 | - Schedulers should be framework-agonstic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch 12 | with a `set_format(...)` method. 13 | 14 | ## Examples 15 | 16 | - The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py). 17 | - The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py). 
18 | - The PNMD scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py). 19 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2022 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from .scheduling_ddim import DDIMScheduler 20 | from .scheduling_ddpm import DDPMScheduler 21 | from .scheduling_grad_tts import GradTTSScheduler 22 | from .scheduling_pndm import PNDMScheduler 23 | from .scheduling_sde_ve import ScoreSdeVeScheduler 24 | from .scheduling_sde_vp import ScoreSdeVpScheduler 25 | from .scheduling_utils import SchedulerMixin 26 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/scheduling_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion 16 | # and https://github.com/hojonathanho/diffusion 17 | 18 | import math 19 | 20 | import numpy as np 21 | 22 | from ..configuration_utils import ConfigMixin 23 | from .scheduling_utils import SchedulerMixin 24 | 25 | 26 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): 27 | """ 28 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 29 | (1-beta) over time from t = [0,1]. 30 | 31 | :param num_diffusion_timesteps: the number of betas to produce. 
:param alpha_bar: a lambda that takes an argument t 32 | from 0 to 1 and 33 | produces the cumulative product of (1-beta) up to that part of the diffusion process. 34 | :param max_beta: the maximum beta to use; use values lower than 1 to 35 | prevent singularities. 36 | """ 37 | 38 | def alpha_bar(time_step): 39 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 40 | 41 | betas = [] 42 | for i in range(num_diffusion_timesteps): 43 | t1 = i / num_diffusion_timesteps 44 | t2 = (i + 1) / num_diffusion_timesteps 45 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 46 | return np.array(betas, dtype=np.float32) 47 | 48 | 49 | class DDIMScheduler(SchedulerMixin, ConfigMixin): 50 | def __init__( 51 | self, 52 | timesteps=1000, 53 | beta_start=0.0001, 54 | beta_end=0.02, 55 | beta_schedule="linear", 56 | trained_betas=None, 57 | timestep_values=None, 58 | clip_sample=True, 59 | tensor_format="np", 60 | ): 61 | super().__init__() 62 | self.register_to_config( 63 | timesteps=timesteps, 64 | beta_start=beta_start, 65 | beta_end=beta_end, 66 | beta_schedule=beta_schedule, 67 | trained_betas=trained_betas, 68 | timestep_values=timestep_values, 69 | clip_sample=clip_sample, 70 | ) 71 | 72 | if beta_schedule == "linear": 73 | self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) 74 | elif beta_schedule == "scaled_linear": 75 | # this schedule is very specific to the latent diffusion model. 76 | self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2 77 | elif beta_schedule == "squaredcos_cap_v2": 78 | # Glide cosine schedule 79 | self.betas = betas_for_alpha_bar(timesteps) 80 | else: 81 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 82 | 83 | self.alphas = 1.0 - self.betas 84 | self.alphas_cumprod = np.cumprod(self.alphas, axis=0) 85 | self.one = np.array(1.0) 86 | 87 | self.set_format(tensor_format=tensor_format) 88 | 89 | def get_variance(self, t, num_inference_steps): 90 | orig_t = self.config.timesteps // num_inference_steps * t 91 | orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1 92 | 93 | alpha_prod_t = self.alphas_cumprod[orig_t] 94 | alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one 95 | beta_prod_t = 1 - alpha_prod_t 96 | beta_prod_t_prev = 1 - alpha_prod_t_prev 97 | 98 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) 99 | 100 | return variance 101 | 102 | def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False): 103 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 104 | # Ideally, read DDIM paper in-detail understanding 105 | 106 | # Notation ( -> 107 | # - pred_noise_t -> e_theta(x_t, t) 108 | # - pred_original_sample -> f_theta(x_t, t) or x_0 109 | # - std_dev_t -> sigma_t 110 | # - eta -> η 111 | # - pred_sample_direction -> "direction pointingc to x_t" 112 | # - pred_prev_sample -> "x_t-1" 113 | 114 | # 1. get actual t and t-1 115 | orig_t = self.config.timesteps // num_inference_steps * t 116 | orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1 117 | 118 | # 2. compute alphas, betas 119 | alpha_prod_t = self.alphas_cumprod[orig_t] 120 | alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one 121 | beta_prod_t = 1 - alpha_prod_t 122 | 123 | # 3. 
compute predicted original sample from predicted noise also called 124 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 125 | pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5) 126 | 127 | # 4. Clip "predicted x_0" 128 | if self.config.clip_sample: 129 | pred_original_sample = self.clip(pred_original_sample, -1, 1) 130 | 131 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 132 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 133 | variance = self.get_variance(t, num_inference_steps) 134 | std_dev_t = eta * variance ** (0.5) 135 | 136 | if use_clipped_residual: 137 | # the residual is always re-derived from the clipped x_0 in Glide 138 | residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 139 | 140 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 141 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual 142 | 143 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 144 | pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 145 | 146 | return pred_prev_sample 147 | 148 | def add_noise(self, original_samples, noise, timesteps): 149 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 150 | sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) 151 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 152 | sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) 153 | 154 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 155 | return noisy_samples 156 | 157 | def __len__(self): 158 | return self.config.timesteps 159 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/scheduling_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim 16 | 17 | import math 18 | 19 | import numpy as np 20 | 21 | from ..configuration_utils import ConfigMixin 22 | from .scheduling_utils import SchedulerMixin 23 | 24 | 25 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): 26 | """ 27 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 28 | (1-beta) over time from t = [0,1]. 29 | 30 | :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t 31 | from 0 to 1 and 32 | produces the cumulative product of (1-beta) up to that part of the diffusion process. 
33 | :param max_beta: the maximum beta to use; use values lower than 1 to 34 | prevent singularities. 35 | """ 36 | 37 | def alpha_bar(time_step): 38 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 39 | 40 | betas = [] 41 | for i in range(num_diffusion_timesteps): 42 | t1 = i / num_diffusion_timesteps 43 | t2 = (i + 1) / num_diffusion_timesteps 44 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 45 | return np.array(betas, dtype=np.float32) 46 | 47 | 48 | class DDPMScheduler(SchedulerMixin, ConfigMixin): 49 | def __init__( 50 | self, 51 | timesteps=1000, 52 | beta_start=0.0001, 53 | beta_end=0.02, 54 | beta_schedule="linear", 55 | trained_betas=None, 56 | timestep_values=None, 57 | variance_type="fixed_small", 58 | clip_sample=True, 59 | tensor_format="np", 60 | ): 61 | super().__init__() 62 | self.register_to_config( 63 | timesteps=timesteps, 64 | beta_start=beta_start, 65 | beta_end=beta_end, 66 | beta_schedule=beta_schedule, 67 | trained_betas=trained_betas, 68 | timestep_values=timestep_values, 69 | variance_type=variance_type, 70 | clip_sample=clip_sample, 71 | ) 72 | 73 | if trained_betas is not None: 74 | self.betas = np.asarray(trained_betas) 75 | elif beta_schedule == "linear": 76 | self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) 77 | elif beta_schedule == "squaredcos_cap_v2": 78 | # Glide cosine schedule 79 | self.betas = betas_for_alpha_bar(timesteps) 80 | else: 81 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 82 | 83 | self.alphas = 1.0 - self.betas 84 | self.alphas_cumprod = np.cumprod(self.alphas, axis=0) 85 | self.one = np.array(1.0) 86 | 87 | self.set_format(tensor_format=tensor_format) 88 | 89 | def get_variance(self, t, variance_type=None): 90 | alpha_prod_t = self.alphas_cumprod[t] 91 | alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one 92 | 93 | # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) 94 | # and sample from it to get previous sample 95 | # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample 96 | variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] 97 | 98 | if variance_type is None: 99 | variance_type = self.config.variance_type 100 | 101 | # hacks - were probs added for training stability 102 | if variance_type == "fixed_small": 103 | variance = self.clip(variance, min_value=1e-20) 104 | # for rl-diffuser https://arxiv.org/abs/2205.09991 105 | elif variance_type == "fixed_small_log": 106 | variance = self.log(self.clip(variance, min_value=1e-20)) 107 | elif variance_type == "fixed_large": 108 | variance = self.betas[t] 109 | elif variance_type == "fixed_large_log": 110 | # Glide max_log 111 | variance = self.log(self.betas[t]) 112 | 113 | return variance 114 | 115 | def step(self, residual, sample, t, predict_epsilon=True): 116 | # 1. compute alphas, betas 117 | alpha_prod_t = self.alphas_cumprod[t] 118 | alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one 119 | beta_prod_t = 1 - alpha_prod_t 120 | beta_prod_t_prev = 1 - alpha_prod_t_prev 121 | 122 | # 2. compute predicted original sample from predicted noise also called 123 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf 124 | if predict_epsilon: 125 | pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5) 126 | else: 127 | pred_original_sample = residual 128 | 129 | # 3. 
Clip "predicted x_0" 130 | if self.config.clip_sample: 131 | pred_original_sample = self.clip(pred_original_sample, -1, 1) 132 | 133 | # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t 134 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 135 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t 136 | current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t 137 | 138 | # 5. Compute predicted previous sample µ_t 139 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 140 | pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample 141 | 142 | return pred_prev_sample 143 | 144 | def add_noise(self, original_samples, noise, timesteps): 145 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 146 | sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) 147 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 148 | sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) 149 | 150 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 151 | return noisy_samples 152 | 153 | def __len__(self): 154 | return self.config.timesteps 155 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/scheduling_grad_tts.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | 17 | from ..configuration_utils import ConfigMixin 18 | from .scheduling_utils import SchedulerMixin 19 | 20 | 21 | class GradTTSScheduler(SchedulerMixin, ConfigMixin): 22 | def __init__( 23 | self, 24 | beta_start=0.05, 25 | beta_end=20, 26 | tensor_format="np", 27 | ): 28 | super().__init__() 29 | self.register_to_config( 30 | beta_start=beta_start, 31 | beta_end=beta_end, 32 | ) 33 | self.set_format(tensor_format=tensor_format) 34 | self.betas = None 35 | 36 | def get_timesteps(self, num_inference_steps): 37 | return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)]) 38 | 39 | def set_betas(self, num_inference_steps): 40 | timesteps = self.get_timesteps(num_inference_steps) 41 | self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps]) 42 | 43 | def step(self, residual, sample, t, num_inference_steps): 44 | # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix) 45 | if self.betas is None: 46 | self.set_betas(num_inference_steps) 47 | 48 | beta_t = self.betas[t] 49 | beta_t_deriv = beta_t / num_inference_steps 50 | 51 | sample_deriv = residual * beta_t_deriv / 2 52 | 53 | sample = sample + sample_deriv 54 | return sample 55 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/scheduling_pndm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim 16 | 17 | import math 18 | 19 | import numpy as np 20 | 21 | from ..configuration_utils import ConfigMixin 22 | from .scheduling_utils import SchedulerMixin 23 | 24 | 25 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): 26 | """ 27 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 28 | (1-beta) over time from t = [0,1]. 29 | 30 | :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t 31 | from 0 to 1 and 32 | produces the cumulative product of (1-beta) up to that part of the diffusion process. 33 | :param max_beta: the maximum beta to use; use values lower than 1 to 34 | prevent singularities. 
35 | """ 36 | 37 | def alpha_bar(time_step): 38 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 39 | 40 | betas = [] 41 | for i in range(num_diffusion_timesteps): 42 | t1 = i / num_diffusion_timesteps 43 | t2 = (i + 1) / num_diffusion_timesteps 44 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 45 | return np.array(betas, dtype=np.float32) 46 | 47 | 48 | class PNDMScheduler(SchedulerMixin, ConfigMixin): 49 | def __init__( 50 | self, 51 | timesteps=1000, 52 | beta_start=0.0001, 53 | beta_end=0.02, 54 | beta_schedule="linear", 55 | tensor_format="np", 56 | ): 57 | super().__init__() 58 | self.register_to_config( 59 | timesteps=timesteps, 60 | beta_start=beta_start, 61 | beta_end=beta_end, 62 | beta_schedule=beta_schedule, 63 | ) 64 | 65 | if beta_schedule == "linear": 66 | self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) 67 | elif beta_schedule == "squaredcos_cap_v2": 68 | # Glide cosine schedule 69 | self.betas = betas_for_alpha_bar(timesteps) 70 | else: 71 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 72 | 73 | self.alphas = 1.0 - self.betas 74 | self.alphas_cumprod = np.cumprod(self.alphas, axis=0) 75 | 76 | self.one = np.array(1.0) 77 | 78 | self.set_format(tensor_format=tensor_format) 79 | 80 | # For now we only support F-PNDM, i.e. the runge-kutta method 81 | # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf 82 | # mainly at formula (9), (12), (13) and the Algorithm 2. 83 | self.pndm_order = 4 84 | 85 | # running values 86 | self.cur_residual = 0 87 | self.cur_sample = None 88 | self.ets = [] 89 | self.prk_time_steps = {} 90 | self.time_steps = {} 91 | self.set_prk_mode() 92 | 93 | def get_prk_time_steps(self, num_inference_steps): 94 | if num_inference_steps in self.prk_time_steps: 95 | return self.prk_time_steps[num_inference_steps] 96 | 97 | inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps)) 98 | 99 | prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile( 100 | np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order 101 | ) 102 | self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1])) 103 | 104 | return self.prk_time_steps[num_inference_steps] 105 | 106 | def get_time_steps(self, num_inference_steps): 107 | if num_inference_steps in self.time_steps: 108 | return self.time_steps[num_inference_steps] 109 | 110 | inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps)) 111 | self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3])) 112 | 113 | return self.time_steps[num_inference_steps] 114 | 115 | def set_prk_mode(self): 116 | self.mode = "prk" 117 | 118 | def set_plms_mode(self): 119 | self.mode = "plms" 120 | 121 | def step(self, *args, **kwargs): 122 | if self.mode == "prk": 123 | return self.step_prk(*args, **kwargs) 124 | if self.mode == "plms": 125 | return self.step_plms(*args, **kwargs) 126 | 127 | raise ValueError(f"mode {self.mode} does not exist.") 128 | 129 | def step_prk(self, residual, sample, t, num_inference_steps): 130 | prk_time_steps = self.get_prk_time_steps(num_inference_steps) 131 | 132 | t_orig = prk_time_steps[t // 4 * 4] 133 | t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)] 134 | 135 | if t % 4 == 0: 136 | self.cur_residual += 1 / 6 * residual 137 | 
self.ets.append(residual) 138 | self.cur_sample = sample 139 | elif (t - 1) % 4 == 0: 140 | self.cur_residual += 1 / 3 * residual 141 | elif (t - 2) % 4 == 0: 142 | self.cur_residual += 1 / 3 * residual 143 | elif (t - 3) % 4 == 0: 144 | residual = self.cur_residual + 1 / 6 * residual 145 | self.cur_residual = 0 146 | 147 | # cur_sample should not be `None` 148 | cur_sample = self.cur_sample if self.cur_sample is not None else sample 149 | 150 | return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual) 151 | 152 | def step_plms(self, residual, sample, t, num_inference_steps): 153 | if len(self.ets) < 3: 154 | raise ValueError( 155 | f"{self.__class__} can only be run AFTER scheduler has been run " 156 | "in 'prk' mode for at least 12 iterations " 157 | "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " 158 | "for more information." 159 | ) 160 | 161 | timesteps = self.get_time_steps(num_inference_steps) 162 | 163 | t_orig = timesteps[t] 164 | t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)] 165 | self.ets.append(residual) 166 | 167 | residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) 168 | 169 | return self.get_prev_sample(sample, t_orig, t_orig_prev, residual) 170 | 171 | def get_prev_sample(self, sample, t_orig, t_orig_prev, residual): 172 | # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf 173 | # this function computes x_(t−δ) using the formula of (9) 174 | # Note that x_t needs to be added to both sides of the equation 175 | 176 | # Notation ( -> 177 | # alpha_prod_t -> α_t 178 | # alpha_prod_t_prev -> α_(t−δ) 179 | # beta_prod_t -> (1 - α_t) 180 | # beta_prod_t_prev -> (1 - α_(t−δ)) 181 | # sample -> x_t 182 | # residual -> e_θ(x_t, t) 183 | # prev_sample -> x_(t−δ) 184 | alpha_prod_t = self.alphas_cumprod[t_orig + 1] 185 | alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1] 186 | beta_prod_t = 1 - alpha_prod_t 187 | beta_prod_t_prev = 1 - alpha_prod_t_prev 188 | 189 | # corresponds to (α_(t−δ) - α_t) divided by 190 | # denominator of x_t in formula (9) and plus 1 191 | # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = 192 | # sqrt(α_(t−δ)) / sqrt(α_t)) 193 | sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) 194 | 195 | # corresponds to denominator of e_θ(x_t, t) in formula (9) 196 | residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( 197 | alpha_prod_t * beta_prod_t * alpha_prod_t_prev 198 | ) ** (0.5) 199 | 200 | # full formula (9) 201 | prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff 202 | 203 | return prev_sample 204 | 205 | def __len__(self): 206 | return self.config.timesteps 207 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/scheduling_sde_ve.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch 16 | 17 | # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from ..configuration_utils import ConfigMixin 23 | from .scheduling_utils import SchedulerMixin 24 | 25 | 26 | class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): 27 | def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"): 28 | super().__init__() 29 | self.register_to_config( 30 | snr=snr, 31 | sigma_min=sigma_min, 32 | sigma_max=sigma_max, 33 | sampling_eps=sampling_eps, 34 | ) 35 | 36 | self.sigmas = None 37 | self.discrete_sigmas = None 38 | self.timesteps = None 39 | 40 | def set_timesteps(self, num_inference_steps): 41 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) 42 | 43 | def set_sigmas(self, num_inference_steps): 44 | if self.timesteps is None: 45 | self.set_timesteps(num_inference_steps) 46 | 47 | self.discrete_sigmas = torch.exp( 48 | torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) 49 | ) 50 | self.sigmas = torch.tensor( 51 | [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] 52 | ) 53 | 54 | def step_pred(self, result, x, t): 55 | # TODO(Patrick) better comments + non-PyTorch 56 | t = t * torch.ones(x.shape[0], device=x.device) 57 | timestep = (t * (len(self.timesteps) - 1)).long() 58 | 59 | sigma = self.discrete_sigmas.to(t.device)[timestep] 60 | adjacent_sigma = torch.where( 61 | timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device) 62 | ) 63 | f = torch.zeros_like(x) 64 | G = torch.sqrt(sigma**2 - adjacent_sigma**2) 65 | 66 | f = f - G[:, None, None, None] ** 2 * result 67 | 68 | z = torch.randn_like(x) 69 | x_mean = x - f 70 | x = x_mean + G[:, None, None, None] * z 71 | return x, x_mean 72 | 73 | def step_correct(self, result, x): 74 | # TODO(Patrick) better comments + non-PyTorch 75 | noise = torch.randn_like(x) 76 | grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() 77 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() 78 | step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 79 | step_size = step_size * torch.ones(x.shape[0], device=x.device) 80 | x_mean = x + step_size[:, None, None, None] * result 81 | 82 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/scheduling_sde_vp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch 16 | 17 | # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from ..configuration_utils import ConfigMixin 23 | from .scheduling_utils import SchedulerMixin 24 | 25 | 26 | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): 27 | def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): 28 | super().__init__() 29 | self.register_to_config( 30 | beta_min=beta_min, 31 | beta_max=beta_max, 32 | sampling_eps=sampling_eps, 33 | ) 34 | 35 | self.sigmas = None 36 | self.discrete_sigmas = None 37 | self.timesteps = None 38 | 39 | def set_timesteps(self, num_inference_steps): 40 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) 41 | 42 | def step_pred(self, result, x, t): 43 | # TODO(Patrick) better comments + non-PyTorch 44 | # postprocess model result 45 | log_mean_coeff = ( 46 | -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min 47 | ) 48 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) 49 | result = -result / std[:, None, None, None] 50 | 51 | # compute 52 | dt = -1.0 / len(self.timesteps) 53 | 54 | beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) 55 | drift = -0.5 * beta_t[:, None, None, None] * x 56 | diffusion = torch.sqrt(beta_t) 57 | drift = drift - diffusion[:, None, None, None] ** 2 * result 58 | x_mean = x + drift * dt 59 | 60 | # add noise 61 | z = torch.randn_like(x) 62 | x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z 63 | 64 | return x, x_mean 65 | -------------------------------------------------------------------------------- /src/diffusers/schedulers/scheduling_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Union 15 | 16 | import numpy as np 17 | import torch 18 | 19 | 20 | SCHEDULER_CONFIG_NAME = "scheduler_config.json" 21 | 22 | 23 | class SchedulerMixin: 24 | 25 | config_name = SCHEDULER_CONFIG_NAME 26 | 27 | def set_format(self, tensor_format="pt"): 28 | self.tensor_format = tensor_format 29 | if tensor_format == "pt": 30 | for key, value in vars(self).items(): 31 | if isinstance(value, np.ndarray): 32 | setattr(self, key, torch.from_numpy(value)) 33 | 34 | return self 35 | 36 | def clip(self, tensor, min_value=None, max_value=None): 37 | tensor_format = getattr(self, "tensor_format", "pt") 38 | 39 | if tensor_format == "np": 40 | return np.clip(tensor, min_value, max_value) 41 | elif tensor_format == "pt": 42 | return torch.clamp(tensor, min_value, max_value) 43 | 44 | raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") 45 | 46 | def log(self, tensor): 47 | tensor_format = getattr(self, "tensor_format", "pt") 48 | 49 | if tensor_format == "np": 50 | return np.log(tensor) 51 | elif tensor_format == "pt": 52 | return torch.log(tensor) 53 | 54 | raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") 55 | 56 | def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]): 57 | """ 58 | Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. 59 | 60 | Args: 61 | timesteps: an array or tensor of values to extract. 62 | broadcast_array: an array with a larger shape of K dimensions with the batch 63 | dimension equal to the length of timesteps. 64 | Returns: 65 | a tensor of shape [batch_size, 1, ...] where the shape has K dims. 66 | """ 67 | 68 | tensor_format = getattr(self, "tensor_format", "pt") 69 | values = values.flatten() 70 | 71 | while len(values.shape) < len(broadcast_array.shape): 72 | values = values[..., None] 73 | if tensor_format == "pt": 74 | values = values.to(broadcast_array.device) 75 | 76 | return values 77 | -------------------------------------------------------------------------------- /src/diffusers/testing_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import unittest 4 | from distutils.util import strtobool 5 | 6 | import torch 7 | 8 | 9 | global_rng = random.Random() 10 | torch_device = "cuda" if torch.cuda.is_available() else "cpu" 11 | 12 | 13 | def parse_flag_from_env(key, default=False): 14 | try: 15 | value = os.environ[key] 16 | except KeyError: 17 | # KEY isn't set, default to `default`. 18 | _value = default 19 | else: 20 | # KEY is set, convert it to True or False. 21 | try: 22 | _value = strtobool(value) 23 | except ValueError: 24 | # More values are supported, but let's keep the message simple. 25 | raise ValueError(f"If set, {key} must be yes or no.") 26 | return _value 27 | 28 | 29 | _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) 30 | 31 | 32 | def floats_tensor(shape, scale=1.0, rng=None, name=None): 33 | """Creates a random float32 tensor""" 34 | if rng is None: 35 | rng = global_rng 36 | 37 | total_dims = 1 38 | for dim in shape: 39 | total_dims *= dim 40 | 41 | values = [] 42 | for _ in range(total_dims): 43 | values.append(rng.random() * scale) 44 | 45 | return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() 46 | 47 | 48 | def slow(test_case): 49 | """ 50 | Decorator marking a test as slow. 51 | 52 | Slow tests are skipped by default. 
Set the RUN_SLOW environment variable to a truthy value to run them. 53 | 54 | """ 55 | return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) 56 | -------------------------------------------------------------------------------- /src/diffusers/training_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | 6 | class EMAModel: 7 | """ 8 | Exponential Moving Average of models weights 9 | """ 10 | 11 | def __init__( 12 | self, 13 | model, 14 | update_after_step=0, 15 | inv_gamma=1.0, 16 | power=2 / 3, 17 | min_value=0.0, 18 | max_value=0.9999, 19 | device=None, 20 | ): 21 | """ 22 | @crowsonkb's notes on EMA Warmup: 23 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 24 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 25 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 26 | at 215.4k steps). 27 | Args: 28 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 29 | power (float): Exponential factor of EMA warmup. Default: 2/3. 30 | min_value (float): The minimum EMA decay rate. Default: 0. 31 | """ 32 | 33 | self.averaged_model = copy.deepcopy(model).eval() 34 | self.averaged_model.requires_grad_(False) 35 | 36 | self.update_after_step = update_after_step 37 | self.inv_gamma = inv_gamma 38 | self.power = power 39 | self.min_value = min_value 40 | self.max_value = max_value 41 | 42 | if device is not None: 43 | self.averaged_model = self.averaged_model.to(device=device) 44 | 45 | self.decay = 0.0 46 | self.optimization_step = 0 47 | 48 | def get_decay(self, optimization_step): 49 | """ 50 | Compute the decay factor for the exponential moving average. 51 | """ 52 | step = max(0, optimization_step - self.update_after_step - 1) 53 | value = 1 - (1 + step / self.inv_gamma) ** -self.power 54 | 55 | if step <= 0: 56 | return 0.0 57 | 58 | return max(self.min_value, min(value, self.max_value)) 59 | 60 | @torch.no_grad() 61 | def step(self, new_model): 62 | ema_state_dict = {} 63 | ema_params = self.averaged_model.state_dict() 64 | 65 | self.decay = self.get_decay(self.optimization_step) 66 | 67 | for key, param in new_model.named_parameters(): 68 | if isinstance(param, dict): 69 | continue 70 | try: 71 | ema_param = ema_params[key] 72 | except KeyError: 73 | ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) 74 | ema_params[key] = ema_param 75 | 76 | if not param.requires_grad: 77 | ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) 78 | ema_param = ema_params[key] 79 | else: 80 | ema_param.mul_(self.decay) 81 | ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) 82 | 83 | ema_state_dict[key] = ema_param 84 | 85 | for key, param in new_model.named_buffers(): 86 | ema_state_dict[key] = param 87 | 88 | self.averaged_model.load_state_dict(ema_state_dict, strict=False) 89 | self.optimization_step += 1 90 | -------------------------------------------------------------------------------- /src/diffusers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import importlib 15 | import os 16 | from collections import OrderedDict 17 | 18 | import importlib_metadata 19 | from requests.exceptions import HTTPError 20 | 21 | from .logging import get_logger 22 | 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | hf_cache_home = os.path.expanduser( 28 | os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) 29 | ) 30 | default_cache_path = os.path.join(hf_cache_home, "diffusers") 31 | 32 | 33 | CONFIG_NAME = "config.json" 34 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" 35 | DIFFUSERS_CACHE = default_cache_path 36 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" 37 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) 38 | 39 | 40 | _transformers_available = importlib.util.find_spec("transformers") is not None 41 | try: 42 | _transformers_version = importlib_metadata.version("transformers") 43 | logger.debug(f"Successfully imported transformers version {_transformers_version}") 44 | except importlib_metadata.PackageNotFoundError: 45 | _transformers_available = False 46 | 47 | 48 | _inflect_available = importlib.util.find_spec("inflect") is not None 49 | try: 50 | _inflect_version = importlib_metadata.version("inflect") 51 | logger.debug(f"Successfully imported inflect version {_inflect_version}") 52 | except importlib_metadata.PackageNotFoundError: 53 | _inflect_available = False 54 | 55 | 56 | _unidecode_available = importlib.util.find_spec("unidecode") is not None 57 | try: 58 | _unidecode_version = importlib_metadata.version("unidecode") 59 | logger.debug(f"Successfully imported unidecode version {_unidecode_version}") 60 | except importlib_metadata.PackageNotFoundError: 61 | _unidecode_available = False 62 | 63 | 64 | def is_transformers_available(): 65 | return _transformers_available 66 | 67 | 68 | def is_inflect_available(): 69 | return _inflect_available 70 | 71 | 72 | def is_unidecode_available(): 73 | return _unidecode_available 74 | 75 | 76 | class RepositoryNotFoundError(HTTPError): 77 | """ 78 | Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does 79 | not have access to. 80 | """ 81 | 82 | 83 | class EntryNotFoundError(HTTPError): 84 | """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename.""" 85 | 86 | 87 | class RevisionNotFoundError(HTTPError): 88 | """Raised when trying to access a hf.co URL with a valid repository but an invalid revision.""" 89 | 90 | 91 | TRANSFORMERS_IMPORT_ERROR = """ 92 | {0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip 93 | install transformers` 94 | """ 95 | 96 | 97 | UNIDECODE_IMPORT_ERROR = """ 98 | {0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install 99 | Unidecode` 100 | """ 101 | 102 | 103 | INFLECT_IMPORT_ERROR = """ 104 | {0} requires the inflect library but it was not found in your environment. 
You can install it with pip: `pip install 105 | inflect` 106 | """ 107 | 108 | 109 | BACKENDS_MAPPING = OrderedDict( 110 | [ 111 | ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), 112 | ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), 113 | ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), 114 | ] 115 | ) 116 | 117 | 118 | def requires_backends(obj, backends): 119 | if not isinstance(backends, (list, tuple)): 120 | backends = [backends] 121 | 122 | name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ 123 | checks = (BACKENDS_MAPPING[backend] for backend in backends) 124 | failed = [msg.format(name) for available, msg in checks if not available()] 125 | if failed: 126 | raise ImportError("".join(failed)) 127 | 128 | 129 | class DummyObject(type): 130 | """ 131 | Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by 132 | `requires_backend` each time a user tries to access any method of that class. 133 | """ 134 | 135 | def __getattr__(cls, key): 136 | if key.startswith("_"): 137 | return super().__getattr__(cls, key) 138 | requires_backends(cls, cls._backends) 139 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | # flake8: noqa 3 | from ..utils import DummyObject, requires_backends 4 | 5 | 6 | class GradTTSPipeline(metaclass=DummyObject): 7 | _backends = ["transformers", "inflect", "unidecode"] 8 | 9 | def __init__(self, *args, **kwargs): 10 | requires_backends(self, ["transformers", "inflect", "unidecode"]) 11 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_transformers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | # flake8: noqa 3 | from ..utils import DummyObject, requires_backends 4 | 5 | 6 | class GlideSuperResUNetModel(metaclass=DummyObject): 7 | _backends = ["transformers"] 8 | 9 | def __init__(self, *args, **kwargs): 10 | requires_backends(self, ["transformers"]) 11 | 12 | 13 | class GlideTextToImageUNetModel(metaclass=DummyObject): 14 | _backends = ["transformers"] 15 | 16 | def __init__(self, *args, **kwargs): 17 | requires_backends(self, ["transformers"]) 18 | 19 | 20 | class GlideUNetModel(metaclass=DummyObject): 21 | _backends = ["transformers"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["transformers"]) 25 | 26 | 27 | class UNetGradTTSModel(metaclass=DummyObject): 28 | _backends = ["transformers"] 29 | 30 | def __init__(self, *args, **kwargs): 31 | requires_backends(self, ["transformers"]) 32 | 33 | 34 | class GlidePipeline(metaclass=DummyObject): 35 | _backends = ["transformers"] 36 | 37 | def __init__(self, *args, **kwargs): 38 | requires_backends(self, ["transformers"]) 39 | 40 | 41 | class LatentDiffusionPipeline(metaclass=DummyObject): 42 | _backends = ["transformers"] 43 | 44 | def __init__(self, *args, **kwargs): 45 | requires_backends(self, ["transformers"]) 46 | -------------------------------------------------------------------------------- /src/diffusers/utils/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logging utilities.""" 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from logging import CRITICAL # NOQA 22 | from logging import DEBUG # NOQA 23 | from logging import ERROR # NOQA 24 | from logging import FATAL # NOQA 25 | from logging import INFO # NOQA 26 | from logging import NOTSET # NOQA 27 | from logging import WARN # NOQA 28 | from logging import WARNING # NOQA 29 | from typing import Optional 30 | 31 | from tqdm import auto as tqdm_lib 32 | 33 | 34 | _lock = threading.Lock() 35 | _default_handler: Optional[logging.Handler] = None 36 | 37 | log_levels = { 38 | "debug": logging.DEBUG, 39 | "info": logging.INFO, 40 | "warning": logging.WARNING, 41 | "error": logging.ERROR, 42 | "critical": logging.CRITICAL, 43 | } 44 | 45 | _default_log_level = logging.WARNING 46 | 47 | _tqdm_active = True 48 | 49 | 50 | def _get_default_logging_level(): 51 | """ 52 | If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. 
If it is 53 | not - fall back to `_default_log_level` 54 | """ 55 | env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) 56 | if env_level_str: 57 | if env_level_str in log_levels: 58 | return log_levels[env_level_str] 59 | else: 60 | logging.getLogger().warning( 61 | f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " 62 | f"has to be one of: { ', '.join(log_levels.keys()) }" 63 | ) 64 | return _default_log_level 65 | 66 | 67 | def _get_library_name() -> str: 68 | 69 | return __name__.split(".")[0] 70 | 71 | 72 | def _get_library_root_logger() -> logging.Logger: 73 | 74 | return logging.getLogger(_get_library_name()) 75 | 76 | 77 | def _configure_library_root_logger() -> None: 78 | 79 | global _default_handler 80 | 81 | with _lock: 82 | if _default_handler: 83 | # This library has already configured the library root logger. 84 | return 85 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 86 | _default_handler.flush = sys.stderr.flush 87 | 88 | # Apply our default configuration to the library root logger. 89 | library_root_logger = _get_library_root_logger() 90 | library_root_logger.addHandler(_default_handler) 91 | library_root_logger.setLevel(_get_default_logging_level()) 92 | library_root_logger.propagate = False 93 | 94 | 95 | def _reset_library_root_logger() -> None: 96 | 97 | global _default_handler 98 | 99 | with _lock: 100 | if not _default_handler: 101 | return 102 | 103 | library_root_logger = _get_library_root_logger() 104 | library_root_logger.removeHandler(_default_handler) 105 | library_root_logger.setLevel(logging.NOTSET) 106 | _default_handler = None 107 | 108 | 109 | def get_log_levels_dict(): 110 | return log_levels 111 | 112 | 113 | def get_logger(name: Optional[str] = None) -> logging.Logger: 114 | """ 115 | Return a logger with the specified name. 116 | 117 | This function is not supposed to be directly accessed unless you are writing a custom diffusers module. 118 | """ 119 | 120 | if name is None: 121 | name = _get_library_name() 122 | 123 | _configure_library_root_logger() 124 | return logging.getLogger(name) 125 | 126 | 127 | def get_verbosity() -> int: 128 | """ 129 | Return the current level for the 🤗 Diffusers' root logger as an int. 130 | 131 | Returns: 132 | `int`: The logging level. 133 | 134 | 135 | 136 | 🤗 Diffusers has following logging levels: 137 | 138 | - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` 139 | - 40: `diffusers.logging.ERROR` 140 | - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN` 141 | - 20: `diffusers.logging.INFO` 142 | - 10: `diffusers.logging.DEBUG` 143 | 144 | """ 145 | 146 | _configure_library_root_logger() 147 | return _get_library_root_logger().getEffectiveLevel() 148 | 149 | 150 | def set_verbosity(verbosity: int) -> None: 151 | """ 152 | Set the verbosity level for the 🤗 Diffusers' root logger. 
153 | 154 | Args: 155 | verbosity (`int`): 156 | Logging level, e.g., one of: 157 | 158 | - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` 159 | - `diffusers.logging.ERROR` 160 | - `diffusers.logging.WARNING` or `diffusers.logging.WARN` 161 | - `diffusers.logging.INFO` 162 | - `diffusers.logging.DEBUG` 163 | """ 164 | 165 | _configure_library_root_logger() 166 | _get_library_root_logger().setLevel(verbosity) 167 | 168 | 169 | def set_verbosity_info(): 170 | """Set the verbosity to the `INFO` level.""" 171 | return set_verbosity(INFO) 172 | 173 | 174 | def set_verbosity_warning(): 175 | """Set the verbosity to the `WARNING` level.""" 176 | return set_verbosity(WARNING) 177 | 178 | 179 | def set_verbosity_debug(): 180 | """Set the verbosity to the `DEBUG` level.""" 181 | return set_verbosity(DEBUG) 182 | 183 | 184 | def set_verbosity_error(): 185 | """Set the verbosity to the `ERROR` level.""" 186 | return set_verbosity(ERROR) 187 | 188 | 189 | def disable_default_handler() -> None: 190 | """Disable the default handler of the HuggingFace Diffusers' root logger.""" 191 | 192 | _configure_library_root_logger() 193 | 194 | assert _default_handler is not None 195 | _get_library_root_logger().removeHandler(_default_handler) 196 | 197 | 198 | def enable_default_handler() -> None: 199 | """Enable the default handler of the HuggingFace Diffusers' root logger.""" 200 | 201 | _configure_library_root_logger() 202 | 203 | assert _default_handler is not None 204 | _get_library_root_logger().addHandler(_default_handler) 205 | 206 | 207 | def add_handler(handler: logging.Handler) -> None: 208 | """adds a handler to the HuggingFace Diffusers' root logger.""" 209 | 210 | _configure_library_root_logger() 211 | 212 | assert handler is not None 213 | _get_library_root_logger().addHandler(handler) 214 | 215 | 216 | def remove_handler(handler: logging.Handler) -> None: 217 | """removes given handler from the HuggingFace Diffusers' root logger.""" 218 | 219 | _configure_library_root_logger() 220 | 221 | assert handler is not None and handler not in _get_library_root_logger().handlers 222 | _get_library_root_logger().removeHandler(handler) 223 | 224 | 225 | def disable_propagation() -> None: 226 | """ 227 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 228 | """ 229 | 230 | _configure_library_root_logger() 231 | _get_library_root_logger().propagate = False 232 | 233 | 234 | def enable_propagation() -> None: 235 | """ 236 | Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent 237 | double logging if the root logger has been configured. 238 | """ 239 | 240 | _configure_library_root_logger() 241 | _get_library_root_logger().propagate = True 242 | 243 | 244 | def enable_explicit_format() -> None: 245 | """ 246 | Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows: 247 | ``` 248 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 249 | ``` 250 | All handlers currently bound to the root logger are affected by this method. 251 | """ 252 | handlers = _get_library_root_logger().handlers 253 | 254 | for handler in handlers: 255 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") 256 | handler.setFormatter(formatter) 257 | 258 | 259 | def reset_format() -> None: 260 | """ 261 | Resets the formatting for HuggingFace Diffusers' loggers. 
262 | 263 | All handlers currently bound to the root logger are affected by this method. 264 | """ 265 | handlers = _get_library_root_logger().handlers 266 | 267 | for handler in handlers: 268 | handler.setFormatter(None) 269 | 270 | 271 | def warning_advice(self, *args, **kwargs): 272 | """ 273 | This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this 274 | warning will not be printed 275 | """ 276 | no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) 277 | if no_advisory_warnings: 278 | return 279 | self.warning(*args, **kwargs) 280 | 281 | 282 | logging.Logger.warning_advice = warning_advice 283 | 284 | 285 | class EmptyTqdm: 286 | """Dummy tqdm which doesn't do anything.""" 287 | 288 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument 289 | self._iterator = args[0] if args else None 290 | 291 | def __iter__(self): 292 | return iter(self._iterator) 293 | 294 | def __getattr__(self, _): 295 | """Return empty function.""" 296 | 297 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument 298 | return 299 | 300 | return empty_fn 301 | 302 | def __enter__(self): 303 | return self 304 | 305 | def __exit__(self, type_, value, traceback): 306 | return 307 | 308 | 309 | class _tqdm_cls: 310 | def __call__(self, *args, **kwargs): 311 | if _tqdm_active: 312 | return tqdm_lib.tqdm(*args, **kwargs) 313 | else: 314 | return EmptyTqdm(*args, **kwargs) 315 | 316 | def set_lock(self, *args, **kwargs): 317 | self._lock = None 318 | if _tqdm_active: 319 | return tqdm_lib.tqdm.set_lock(*args, **kwargs) 320 | 321 | def get_lock(self): 322 | if _tqdm_active: 323 | return tqdm_lib.tqdm.get_lock() 324 | 325 | 326 | tqdm = _tqdm_cls() 327 | 328 | 329 | def is_progress_bar_enabled() -> bool: 330 | """Return a boolean indicating whether tqdm progress bars are enabled.""" 331 | global _tqdm_active 332 | return bool(_tqdm_active) 333 | 334 | 335 | def enable_progress_bar(): 336 | """Enable tqdm progress bar.""" 337 | global _tqdm_active 338 | _tqdm_active = True 339 | 340 | 341 | def disable_progress_bar(): 342 | """Disable tqdm progress bar.""" 343 | global _tqdm_active 344 | _tqdm_active = False 345 | -------------------------------------------------------------------------------- /src/diffusers/utils/model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | 7 | 8 | # {{ model_name | default("Diffusion Model") }} 9 | 10 | ## Model description 11 | 12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library 13 | on the `{{ dataset_name }}` dataset.
14 | 15 | ## Intended uses & limitations 16 | 17 | #### How to use 18 | 19 | ```python 20 | # TODO: add an example code snippet for running this diffusion pipeline 21 | ``` 22 | 23 | #### Limitations and bias 24 | 25 | [TODO: provide examples of latent issues and potential remediations] 26 | 27 | ## Training data 28 | 29 | [TODO: describe the data used to train the model] 30 | 31 | ### Training hyperparameters 32 | 33 | The following hyperparameters were used during training: 34 | - learning_rate: {{ learning_rate }} 35 | - train_batch_size: {{ train_batch_size }} 36 | - eval_batch_size: {{ eval_batch_size }} 37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }} 38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} 39 | - lr_scheduler: {{ lr_scheduler }} 40 | - lr_warmup_steps: {{ lr_warmup_steps }} 41 | - ema_inv_gamma: {{ ema_inv_gamma }} 42 | - ema_power: {{ ema_power }} 43 | - ema_max_decay: {{ ema_max_decay }} 44 | - mixed_precision: {{ mixed_precision }} 45 | 46 | ### Training results 47 | 48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) 49 | 50 | 51 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/diffusers_all/c8c0c0e846c8afc07602c44180278a2f7f15331d/tests/__init__.py -------------------------------------------------------------------------------- /utils/check_config_docstrings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import importlib 17 | import inspect 18 | import os 19 | import re 20 | 21 | 22 | # All paths are set with the intent you should run this script from the root of the repo with the command 23 | # python utils/check_config_docstrings.py 24 | PATH_TO_TRANSFORMERS = "src/transformers" 25 | 26 | 27 | # This is to make sure the transformers module imported is the one in the repo. 28 | spec = importlib.util.spec_from_file_location( 29 | "transformers", 30 | os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), 31 | submodule_search_locations=[PATH_TO_TRANSFORMERS], 32 | ) 33 | transformers = spec.loader.load_module() 34 | 35 | CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING 36 | 37 | # Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
38 | # For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)` 39 | _re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)") 40 | 41 | 42 | CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = { 43 | "CLIPConfigMixin", 44 | "DecisionTransformerConfigMixin", 45 | "EncoderDecoderConfigMixin", 46 | "RagConfigMixin", 47 | "SpeechEncoderDecoderConfigMixin", 48 | "VisionEncoderDecoderConfigMixin", 49 | "VisionTextDualEncoderConfigMixin", 50 | } 51 | 52 | 53 | def check_config_docstrings_have_checkpoints(): 54 | configs_without_checkpoint = [] 55 | 56 | for config_class in list(CONFIG_MAPPING.values()): 57 | checkpoint_found = False 58 | 59 | # source code of `config_class` 60 | config_source = inspect.getsource(config_class) 61 | checkpoints = _re_checkpoint.findall(config_source) 62 | 63 | for checkpoint in checkpoints: 64 | # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link. 65 | # For example, `('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')` 66 | ckpt_name, ckpt_link = checkpoint 67 | 68 | # verify the checkpoint name corresponds to the checkpoint link 69 | ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}" 70 | if ckpt_link == ckpt_link_from_name: 71 | checkpoint_found = True 72 | break 73 | 74 | name = config_class.__name__ 75 | if not checkpoint_found and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK: 76 | configs_without_checkpoint.append(name) 77 | 78 | if len(configs_without_checkpoint) > 0: 79 | message = "\n".join(sorted(configs_without_checkpoint)) 80 | raise ValueError(f"The following configurations don't contain any valid checkpoint:\n{message}") 81 | 82 | 83 | if __name__ == "__main__": 84 | check_config_docstrings_have_checkpoints() 85 | -------------------------------------------------------------------------------- /utils/check_dummies.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
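For context on what `check_dummies.py` below keeps in sync: a minimal, hypothetical sketch of how one of the generated dummy classes (such as those shown in `dummy_transformers_objects.py` above) behaves at runtime. It assumes `diffusers` is importable while the `transformers` backend is missing; with `transformers` installed, the constructor simply succeeds and nothing is printed.

```python
# Sketch only: mirrors the autogenerated dummy pattern shown above.
from diffusers.utils import DummyObject, requires_backends


class GlidePipeline(metaclass=DummyObject):
    _backends = ["transformers"]

    def __init__(self, *args, **kwargs):
        # Raises ImportError (with an installation hint) if `transformers` is absent.
        requires_backends(self, ["transformers"])


try:
    GlidePipeline()
except ImportError as err:
    print(err)  # e.g. a message naming GlidePipeline and the missing `transformers` backend
```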
15 | 16 | import argparse 17 | import os 18 | import re 19 | 20 | 21 | # All paths are set with the intent you should run this script from the root of the repo with the command 22 | # python utils/check_dummies.py 23 | PATH_TO_DIFFUSERS = "src/diffusers" 24 | 25 | # Matches is_xxx_available() 26 | _re_backend = re.compile(r"is\_([a-z_]*)_available\(\)") 27 | # Matches from xxx import bla 28 | _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") 29 | 30 | 31 | DUMMY_CONSTANT = """ 32 | {0} = None 33 | """ 34 | 35 | DUMMY_CLASS = """ 36 | class {0}(metaclass=DummyObject): 37 | _backends = {1} 38 | 39 | def __init__(self, *args, **kwargs): 40 | requires_backends(self, {1}) 41 | """ 42 | 43 | 44 | DUMMY_FUNCTION = """ 45 | def {0}(*args, **kwargs): 46 | requires_backends({0}, {1}) 47 | """ 48 | 49 | 50 | def find_backend(line): 51 | """Find one (or multiple) backend in a code line of the init.""" 52 | backends = _re_backend.findall(line) 53 | if len(backends) == 0: 54 | return None 55 | 56 | return "_and_".join(backends) 57 | 58 | 59 | def read_init(): 60 | """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects.""" 61 | with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: 62 | lines = f.readlines() 63 | 64 | # Get to the point we do the actual imports for type checking 65 | line_index = 0 66 | backend_specific_objects = {} 67 | # Go through the end of the file 68 | while line_index < len(lines): 69 | # If the line is an if is_backend_available, we grab all objects associated. 70 | backend = find_backend(lines[line_index]) 71 | if backend is not None: 72 | objects = [] 73 | line_index += 1 74 | # Until we unindent, add backend objects to the list 75 | while not lines[line_index].startswith("else:"): 76 | line = lines[line_index] 77 | single_line_import_search = _re_single_line_import.search(line) 78 | if single_line_import_search is not None: 79 | objects.extend(single_line_import_search.groups()[0].split(", ")) 80 | elif line.startswith(" " * 12): 81 | objects.append(line[12:-2]) 82 | line_index += 1 83 | 84 | backend_specific_objects[backend] = objects 85 | else: 86 | line_index += 1 87 | 88 | return backend_specific_objects 89 | 90 | 91 | def create_dummy_object(name, backend_name): 92 | """Create the code for the dummy object corresponding to `name`.""" 93 | if name.isupper(): 94 | return DUMMY_CONSTANT.format(name) 95 | elif name.islower(): 96 | return DUMMY_FUNCTION.format(name, backend_name) 97 | else: 98 | return DUMMY_CLASS.format(name, backend_name) 99 | 100 | 101 | def create_dummy_files(): 102 | """Create the content of the dummy files.""" 103 | backend_specific_objects = read_init() 104 | # For special correspondence backend to module name as used in the function requires_modulename 105 | dummy_files = {} 106 | 107 | for backend, objects in backend_specific_objects.items(): 108 | backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]" 109 | dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" 110 | dummy_file += "# flake8: noqa\n" 111 | dummy_file += "from ..utils import DummyObject, requires_backends\n\n" 112 | dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) 113 | dummy_files[backend] = dummy_file 114 | 115 | return dummy_files 116 | 117 | 118 | def check_dummies(overwrite=False): 119 | """Check if the dummy files are up to date and maybe `overwrite` with the right content.""" 120 
| dummy_files = create_dummy_files() 121 | # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py 122 | short_names = {"torch": "pt"} 123 | 124 | # Locate actual dummy modules and read their content. 125 | path = os.path.join(PATH_TO_DIFFUSERS, "utils") 126 | dummy_file_paths = { 127 | backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") 128 | for backend in dummy_files.keys() 129 | } 130 | 131 | actual_dummies = {} 132 | for backend, file_path in dummy_file_paths.items(): 133 | if os.path.isfile(file_path): 134 | with open(file_path, "r", encoding="utf-8", newline="\n") as f: 135 | actual_dummies[backend] = f.read() 136 | else: 137 | actual_dummies[backend] = "" 138 | 139 | for backend in dummy_files.keys(): 140 | if dummy_files[backend] != actual_dummies[backend]: 141 | if overwrite: 142 | print( 143 | f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main " 144 | "__init__ has new objects." 145 | ) 146 | with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f: 147 | f.write(dummy_files[backend]) 148 | else: 149 | raise ValueError( 150 | "The main __init__ has objects that are not present in " 151 | f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` " 152 | "to fix this." 153 | ) 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") 159 | args = parser.parse_args() 160 | 161 | check_dummies(args.fix_and_overwrite) 162 | -------------------------------------------------------------------------------- /utils/check_table.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import collections 18 | import importlib.util 19 | import os 20 | import re 21 | 22 | 23 | # All paths are set with the intent you should run this script from the root of the repo with the command 24 | # python utils/check_table.py 25 | TRANSFORMERS_PATH = "src/diffusers" 26 | PATH_TO_DOCS = "docs/source/en" 27 | REPO_PATH = "." 28 | 29 | 30 | def _find_text_in_file(filename, start_prompt, end_prompt): 31 | """ 32 | Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty 33 | lines. 34 | """ 35 | with open(filename, "r", encoding="utf-8", newline="\n") as f: 36 | lines = f.readlines() 37 | # Find the start prompt. 
38 | start_index = 0 39 | while not lines[start_index].startswith(start_prompt): 40 | start_index += 1 41 | start_index += 1 42 | 43 | end_index = start_index 44 | while not lines[end_index].startswith(end_prompt): 45 | end_index += 1 46 | end_index -= 1 47 | 48 | while len(lines[start_index]) <= 1: 49 | start_index += 1 50 | while len(lines[end_index]) <= 1: 51 | end_index -= 1 52 | end_index += 1 53 | return "".join(lines[start_index:end_index]), start_index, end_index, lines 54 | 55 | 56 | # Add here suffixes that are used to identify models, separated by | 57 | ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration" 58 | # Regexes that match TF/Flax/PT model names. 59 | _re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") 60 | _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") 61 | # Will match any TF or Flax model too so need to be in an else branch after the two previous regexes. 62 | _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") 63 | 64 | 65 | # This is to make sure the diffusers module imported is the one in the repo. 66 | spec = importlib.util.spec_from_file_location( 67 | "diffusers", 68 | os.path.join(TRANSFORMERS_PATH, "__init__.py"), 69 | submodule_search_locations=[TRANSFORMERS_PATH], 70 | ) 71 | diffusers_module = spec.loader.load_module() 72 | 73 | 74 | # Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python 75 | def camel_case_split(identifier): 76 | "Split a camelcased `identifier` into words." 77 | matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier) 78 | return [m.group(0) for m in matches] 79 | 80 | 81 | def _center_text(text, width): 82 | text_length = 2 if text == "✅" or text == "❌" else len(text) 83 | left_indent = (width - text_length) // 2 84 | right_indent = width - text_length - left_indent 85 | return " " * left_indent + text + " " * right_indent 86 | 87 | 88 | def get_model_table_from_auto_modules(): 89 | """Generates an up-to-date model table from the content of the auto modules.""" 90 | # Dictionary model names to config. 91 | config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES 92 | model_name_to_config = { 93 | name: config_maping_names[code] 94 | for code, name in diffusers_module.MODEL_NAMES_MAPPING.items() 95 | if code in config_maping_names 96 | } 97 | model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()} 98 | 99 | # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax. 100 | slow_tokenizers = collections.defaultdict(bool) 101 | fast_tokenizers = collections.defaultdict(bool) 102 | pt_models = collections.defaultdict(bool) 103 | tf_models = collections.defaultdict(bool) 104 | flax_models = collections.defaultdict(bool) 105 | 106 | # Let's look through all diffusers objects (once).
107 | for attr_name in dir(diffusers_module): 108 | lookup_dict = None 109 | if attr_name.endswith("Tokenizer"): 110 | lookup_dict = slow_tokenizers 111 | attr_name = attr_name[:-9] 112 | elif attr_name.endswith("TokenizerFast"): 113 | lookup_dict = fast_tokenizers 114 | attr_name = attr_name[:-13] 115 | elif _re_tf_models.match(attr_name) is not None: 116 | lookup_dict = tf_models 117 | attr_name = _re_tf_models.match(attr_name).groups()[0] 118 | elif _re_flax_models.match(attr_name) is not None: 119 | lookup_dict = flax_models 120 | attr_name = _re_flax_models.match(attr_name).groups()[0] 121 | elif _re_pt_models.match(attr_name) is not None: 122 | lookup_dict = pt_models 123 | attr_name = _re_pt_models.match(attr_name).groups()[0] 124 | 125 | if lookup_dict is not None: 126 | while len(attr_name) > 0: 127 | if attr_name in model_name_to_prefix.values(): 128 | lookup_dict[attr_name] = True 129 | break 130 | # Try again after removing the last word in the name 131 | attr_name = "".join(camel_case_split(attr_name)[:-1]) 132 | 133 | # Let's build that table! 134 | model_names = list(model_name_to_config.keys()) 135 | model_names.sort(key=str.lower) 136 | columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"] 137 | # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side). 138 | widths = [len(c) + 2 for c in columns] 139 | widths[0] = max([len(name) for name in model_names]) + 2 140 | 141 | # Build the table per se 142 | table = "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n" 143 | # Use ":-----:" format to center-aligned table cell texts 144 | table += "|" + "|".join([":" + "-" * (w - 2) + ":" for w in widths]) + "|\n" 145 | 146 | check = {True: "✅", False: "❌"} 147 | for name in model_names: 148 | prefix = model_name_to_prefix[name] 149 | line = [ 150 | name, 151 | check[slow_tokenizers[prefix]], 152 | check[fast_tokenizers[prefix]], 153 | check[pt_models[prefix]], 154 | check[tf_models[prefix]], 155 | check[flax_models[prefix]], 156 | ] 157 | table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n" 158 | return table 159 | 160 | 161 | def check_model_table(overwrite=False): 162 | """Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`.""" 163 | current_table, start_index, end_index, lines = _find_text_in_file( 164 | filename=os.path.join(PATH_TO_DOCS, "index.mdx"), 165 | start_prompt="", 167 | ) 168 | new_table = get_model_table_from_auto_modules() 169 | 170 | if current_table != new_table: 171 | if overwrite: 172 | with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f: 173 | f.writelines(lines[:start_index] + [new_table] + lines[end_index:]) 174 | else: 175 | raise ValueError( 176 | "The model table in the `index.mdx` has not been updated. Run `make fix-copies` to fix this." 177 | ) 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") 183 | args = parser.parse_args() 184 | 185 | check_model_table(args.fix_and_overwrite) 186 | -------------------------------------------------------------------------------- /utils/check_tf_ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import json 18 | import os 19 | 20 | from tensorflow.core.protobuf.saved_model_pb2 import SavedModel 21 | 22 | 23 | # All paths are set with the intent you should run this script from the root of the repo with the command 24 | # python utils/check_tf_ops.py 25 | REPO_PATH = "." 26 | 27 | # Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model) 28 | INTERNAL_OPS = [ 29 | "Assert", 30 | "AssignVariableOp", 31 | "EmptyTensorList", 32 | "MergeV2Checkpoints", 33 | "ReadVariableOp", 34 | "ResourceGather", 35 | "RestoreV2", 36 | "SaveV2", 37 | "ShardedFilename", 38 | "StatefulPartitionedCall", 39 | "StaticRegexFullMatch", 40 | "VarHandleOp", 41 | ] 42 | 43 | 44 | def onnx_compliancy(saved_model_path, strict, opset): 45 | saved_model = SavedModel() 46 | onnx_ops = [] 47 | 48 | with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f: 49 | onnx_opsets = json.load(f)["opsets"] 50 | 51 | for i in range(1, opset + 1): 52 | onnx_ops.extend(onnx_opsets[str(i)]) 53 | 54 | with open(saved_model_path, "rb") as f: 55 | saved_model.ParseFromString(f.read()) 56 | 57 | model_op_names = set() 58 | 59 | # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs) 60 | for meta_graph in saved_model.meta_graphs: 61 | # Add operations in the graph definition 62 | model_op_names.update(node.op for node in meta_graph.graph_def.node) 63 | 64 | # Go through the functions in the graph definition 65 | for func in meta_graph.graph_def.library.function: 66 | # Add operations in each function 67 | model_op_names.update(node.op for node in func.node_def) 68 | 69 | # Convert to list, sorted if you want 70 | model_op_names = sorted(model_op_names) 71 | incompatible_ops = [] 72 | 73 | for op in model_op_names: 74 | if op not in onnx_ops and op not in INTERNAL_OPS: 75 | incompatible_ops.append(op) 76 | 77 | if strict and len(incompatible_ops) > 0: 78 | raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + "\n".join(incompatible_ops)) 79 | elif len(incompatible_ops) > 0: 80 | print(f"Found the following incompatible ops for the opset {opset}:") 81 | print(*incompatible_ops, sep="\n") 82 | else: 83 | print(f"The saved model {saved_model_path} can properly be converted with ONNX.") 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).") 89 | parser.add_argument( 90 | "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested." 91 | ) 92 | parser.add_argument( 93 | "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
94 | ) 95 | parser.add_argument( 96 | "--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)" 97 | ) 98 | args = parser.parse_args() 99 | 100 | if args.framework == "onnx": 101 | onnx_compliancy(args.saved_model_path, args.strict, args.opset) 102 | -------------------------------------------------------------------------------- /utils/custom_init_isort.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import os 18 | import re 19 | 20 | 21 | PATH_TO_TRANSFORMERS = "src/diffusers" 22 | 23 | # Pattern that looks at the indentation in a line. 24 | _re_indent = re.compile(r"^(\s*)\S") 25 | # Pattern that matches `"key":" and puts `key` in group 0. 26 | _re_direct_key = re.compile(r'^\s*"([^"]+)":') 27 | # Pattern that matches `_import_structure["key"]` and puts `key` in group 0. 28 | _re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]') 29 | # Pattern that matches `"key",` and puts `key` in group 0. 30 | _re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$') 31 | # Pattern that matches any `[stuff]` and puts `stuff` in group 0. 32 | _re_bracket_content = re.compile(r"\[([^\]]+)\]") 33 | 34 | 35 | def get_indent(line): 36 | """Returns the indent in `line`.""" 37 | search = _re_indent.search(line) 38 | return "" if search is None else search.groups()[0] 39 | 40 | 41 | def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None): 42 | """ 43 | Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after 44 | `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's 45 | after `end_prompt` as a last block, so `code` is always the same as joining the result of this function). 46 | """ 47 | # Let's split the code into lines and move to start_index. 48 | index = 0 49 | lines = code.split("\n") 50 | if start_prompt is not None: 51 | while not lines[index].startswith(start_prompt): 52 | index += 1 53 | blocks = ["\n".join(lines[:index])] 54 | else: 55 | blocks = [] 56 | 57 | # We split into blocks until we get to the `end_prompt` (or the end of the block). 
58 | current_block = [lines[index]] 59 | index += 1 60 | while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)): 61 | if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level: 62 | if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "): 63 | current_block.append(lines[index]) 64 | blocks.append("\n".join(current_block)) 65 | if index < len(lines) - 1: 66 | current_block = [lines[index + 1]] 67 | index += 1 68 | else: 69 | current_block = [] 70 | else: 71 | blocks.append("\n".join(current_block)) 72 | current_block = [lines[index]] 73 | else: 74 | current_block.append(lines[index]) 75 | index += 1 76 | 77 | # Adds current block if it's nonempty. 78 | if len(current_block) > 0: 79 | blocks.append("\n".join(current_block)) 80 | 81 | # Add final block after end_prompt if provided. 82 | if end_prompt is not None and index < len(lines): 83 | blocks.append("\n".join(lines[index:])) 84 | 85 | return blocks 86 | 87 | 88 | def ignore_underscore(key): 89 | "Wraps a `key` (that maps an object to string) to lower case and remove underscores." 90 | 91 | def _inner(x): 92 | return key(x).lower().replace("_", "") 93 | 94 | return _inner 95 | 96 | 97 | def sort_objects(objects, key=None): 98 | "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str." 99 | # If no key is provided, we use a noop. 100 | def noop(x): 101 | return x 102 | 103 | if key is None: 104 | key = noop 105 | # Constants are all uppercase, they go first. 106 | constants = [obj for obj in objects if key(obj).isupper()] 107 | # Classes are not all uppercase but start with a capital, they go second. 108 | classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()] 109 | # Functions begin with a lowercase, they go last. 110 | functions = [obj for obj in objects if not key(obj)[0].isupper()] 111 | 112 | key1 = ignore_underscore(key) 113 | return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1) 114 | 115 | 116 | def sort_objects_in_import(import_statement): 117 | """ 118 | Return the same `import_statement` but with objects properly sorted. 119 | """ 120 | # This inner function sort imports between [ ]. 121 | def _replace(match): 122 | imports = match.groups()[0] 123 | if "," not in imports: 124 | return f"[{imports}]" 125 | keys = [part.strip().replace('"', "") for part in imports.split(",")] 126 | # We will have a final empty element if the line finished with a comma. 127 | if len(keys[-1]) == 0: 128 | keys = keys[:-1] 129 | return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]" 130 | 131 | lines = import_statement.split("\n") 132 | if len(lines) > 3: 133 | # Here we have to sort internal imports that are on several lines (one per name): 134 | # key: [ 135 | # "object1", 136 | # "object2", 137 | # ... 138 | # ] 139 | 140 | # We may have to ignore one or two lines on each side. 141 | idx = 2 if lines[1].strip() == "[" else 1 142 | keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])] 143 | sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1]) 144 | sorted_lines = [lines[x[0] + idx] for x in sorted_indices] 145 | return "\n".join(lines[:idx] + sorted_lines + lines[-idx:]) 146 | elif len(lines) == 3: 147 | # Here we have to sort internal imports that are on one separate line: 148 | # key: [ 149 | # "object1", "object2", ... 
150 | # ] 151 | if _re_bracket_content.search(lines[1]) is not None: 152 | lines[1] = _re_bracket_content.sub(_replace, lines[1]) 153 | else: 154 | keys = [part.strip().replace('"', "") for part in lines[1].split(",")] 155 | # We will have a final empty element if the line finished with a comma. 156 | if len(keys[-1]) == 0: 157 | keys = keys[:-1] 158 | lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)]) 159 | return "\n".join(lines) 160 | else: 161 | # Finally we have to deal with imports fitting on one line 162 | import_statement = _re_bracket_content.sub(_replace, import_statement) 163 | return import_statement 164 | 165 | 166 | def sort_imports(file, check_only=True): 167 | """ 168 | Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite. 169 | """ 170 | with open(file, "r") as f: 171 | code = f.read() 172 | 173 | if "_import_structure" not in code: 174 | return 175 | 176 | # Blocks of indent level 0 177 | main_blocks = split_code_in_indented_blocks( 178 | code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:" 179 | ) 180 | 181 | # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt). 182 | for block_idx in range(1, len(main_blocks) - 1): 183 | # Check if the block contains some `_import_structure` entries to sort. 184 | block = main_blocks[block_idx] 185 | block_lines = block.split("\n") 186 | 187 | # Get to the start of the imports. 188 | line_idx = 0 189 | while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]: 190 | # Skip dummy import blocks 191 | if "import dummy" in block_lines[line_idx]: 192 | line_idx = len(block_lines) 193 | else: 194 | line_idx += 1 195 | if line_idx >= len(block_lines): 196 | continue 197 | 198 | # Ignore beginning and last line: they don't contain anything. 199 | internal_block_code = "\n".join(block_lines[line_idx:-1]) 200 | indent = get_indent(block_lines[1]) 201 | # Split the internal block into blocks of indent level 1. 202 | internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent) 203 | # We have two categories of import key: list or _import_structure[key].append/extend 204 | pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key 205 | # Grab the keys, but there is a trap: some lines are empty or just comments. 206 | keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks] 207 | # We only sort the lines with a key. 208 | keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None] 209 | sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])] 210 | 211 | # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest. 212 | count = 0 213 | reorderded_blocks = [] 214 | for i in range(len(internal_blocks)): 215 | if keys[i] is None: 216 | reorderded_blocks.append(internal_blocks[i]) 217 | else: 218 | block = sort_objects_in_import(internal_blocks[sorted_indices[count]]) 219 | reorderded_blocks.append(block) 220 | count += 1 221 | 222 | # And we put our main block back together with its first and last line.
223 | main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]]) 224 | 225 | if code != "\n".join(main_blocks): 226 | if check_only: 227 | return True 228 | else: 229 | print(f"Overwriting {file}.") 230 | with open(file, "w") as f: 231 | f.write("\n".join(main_blocks)) 232 | 233 | 234 | def sort_imports_in_all_inits(check_only=True): 235 | failures = [] 236 | for root, _, files in os.walk(PATH_TO_TRANSFORMERS): 237 | if "__init__.py" in files: 238 | result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only) 239 | if result: 240 | failures.append(os.path.join(root, "__init__.py")) 241 | if len(failures) > 0: 242 | raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.") 243 | 244 | 245 | if __name__ == "__main__": 246 | parser = argparse.ArgumentParser() 247 | parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.") 248 | args = parser.parse_args() 249 | 250 | sort_imports_in_all_inits(check_only=args.check_only) 251 | --------------------------------------------------------------------------------
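For completeness, a brief sketch (not part of the repository itself) of how the two entry points in `utils/custom_init_isort.py` above are typically driven, assuming the repository root is the current working directory so that `utils/` is importable; this mirrors what the `make style` / `make quality` targets do.

```python
# Hypothetical driver for the isort-like __init__ check above.
from utils.custom_init_isort import sort_imports, sort_imports_in_all_inits

# Check a single __init__.py without touching it; returns True if it would be rewritten.
needs_sorting = sort_imports("src/diffusers/__init__.py", check_only=True)

# Check every __init__.py under src/diffusers; raises ValueError when `make style` is needed.
sort_imports_in_all_inits(check_only=True)

# Rewrite any offending __init__.py files in place (what `make style` ultimately triggers).
sort_imports_in_all_inits(check_only=False)
```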
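A comparable usage sketch for the verbosity and progress-bar helpers defined in `src/diffusers/utils/logging.py` earlier in this section, assuming `diffusers` is installed; the log output shown in comments is illustrative only.

```python
from diffusers.utils import logging

logging.set_verbosity_info()        # the root "diffusers" logger now emits INFO and above
logger = logging.get_logger()       # defaults to the library root logger, "diffusers"
logger.info("current verbosity: %s", logging.get_verbosity())  # e.g. 20

logging.disable_progress_bar()      # the module-level `tqdm` wrapper now returns EmptyTqdm
assert not logging.is_progress_bar_enabled()
logging.enable_progress_bar()
```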