├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── docs
│   └── source
│       └── imgs
│           └── diffusers_library.jpg
├── examples
│   ├── README.md
│   ├── experimental
│   │   ├── train_glide_text_to_image.py
│   │   └── train_latent_text_to_image.py
│   └── train_unconditional.py
├── pyproject.toml
├── run.py
├── scripts
│   ├── conversion_bddm.py
│   ├── conversion_glide.py
│   └── conversion_ldm_uncond.py
├── setup.cfg
├── setup.py
├── src
│   └── diffusers
│       ├── __init__.py
│       ├── configuration_utils.py
│       ├── dependency_versions_check.py
│       ├── dependency_versions_table.py
│       ├── dynamic_modules_utils.py
│       ├── hub_utils.py
│       ├── modeling_utils.py
│       ├── models
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── attention.py
│       │   ├── embeddings.py
│       │   ├── resnet.py
│       │   ├── unet.py
│       │   ├── unet_glide.py
│       │   ├── unet_grad_tts.py
│       │   ├── unet_ldm.py
│       │   ├── unet_new.py
│       │   ├── unet_rl.py
│       │   ├── unet_sde_score_estimation.py
│       │   ├── unet_unconditional.py
│       │   └── vae.py
│       ├── optimization.py
│       ├── pipeline_utils.py
│       ├── pipelines
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── bddm
│       │   │   ├── __init__.py
│       │   │   └── pipeline_bddm.py
│       │   ├── ddim
│       │   │   ├── __init__.py
│       │   │   └── pipeline_ddim.py
│       │   ├── ddpm
│       │   │   ├── __init__.py
│       │   │   └── pipeline_ddpm.py
│       │   ├── glide
│       │   │   ├── __init__.py
│       │   │   └── pipeline_glide.py
│       │   ├── grad_tts
│       │   │   ├── __init__.py
│       │   │   ├── grad_tts_utils.py
│       │   │   └── pipeline_grad_tts.py
│       │   ├── latent_diffusion
│       │   │   ├── __init__.py
│       │   │   └── pipeline_latent_diffusion.py
│       │   ├── latent_diffusion_uncond
│       │   │   ├── __init__.py
│       │   │   └── pipeline_latent_diffusion_uncond.py
│       │   ├── pndm
│       │   │   ├── __init__.py
│       │   │   └── pipeline_pndm.py
│       │   ├── score_sde_ve
│       │   │   ├── __init__.py
│       │   │   └── pipeline_score_sde_ve.py
│       │   └── score_sde_vp
│       │       ├── __init__.py
│       │       └── pipeline_score_sde_vp.py
│       ├── schedulers
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── scheduling_ddim.py
│       │   ├── scheduling_ddpm.py
│       │   ├── scheduling_grad_tts.py
│       │   ├── scheduling_pndm.py
│       │   ├── scheduling_sde_ve.py
│       │   ├── scheduling_sde_vp.py
│       │   └── scheduling_utils.py
│       ├── testing_utils.py
│       ├── training_utils.py
│       └── utils
│           ├── __init__.py
│           ├── dummy_transformers_and_inflect_and_unidecode_objects.py
│           ├── dummy_transformers_objects.py
│           ├── logging.py
│           └── model_card_template.md
├── tests
│   ├── __init__.py
│   ├── test_layers_utils.py
│   ├── test_modeling_utils.py
│   └── test_scheduler.py
└── utils
    ├── check_config_docstrings.py
    ├── check_copies.py
    ├── check_dummies.py
    ├── check_inits.py
    ├── check_repo.py
    ├── check_table.py
    ├── check_tf_ops.py
    └── custom_init_isort.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # tests and logs
12 | tests/fixtures/cached_*_text.txt
13 | logs/
14 | lightning_logs/
15 | lang_code_data/
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 | .dmypy.json
119 | dmypy.json
120 |
121 | # Pyre type checker
122 | .pyre/
123 |
124 | # vscode
125 | .vs
126 | .vscode
127 |
128 | # Pycharm
129 | .idea
130 |
131 | # TF code
132 | tensorflow_code
133 |
134 | # Models
135 | proc_data
136 |
137 | # examples
138 | runs
139 | /runs_old
140 | /wandb
141 | /examples/runs
142 | /examples/**/*.args
143 | /examples/rag/sweep
144 |
145 | # data
146 | /data
147 | serialization_dir
148 |
149 | # emacs
150 | *.*~
151 | debug.env
152 |
153 | # vim
154 | .*.swp
155 |
156 | #ctags
157 | tags
158 |
159 | # pre-commit
160 | .pre-commit*
161 |
162 | # .lock
163 | *.lock
164 |
165 | # DS_Store (MacOS)
166 | .DS_Store
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := examples tests src utils
7 |
8 | modified_only_fixup:
9 | $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
10 | @if test -n "$(modified_py_files)"; then \
11 | echo "Checking/fixing $(modified_py_files)"; \
12 | black --preview $(modified_py_files); \
13 | isort $(modified_py_files); \
14 | flake8 $(modified_py_files); \
15 | else \
16 | echo "No library .py files were modified"; \
17 | fi
18 |
19 | # Update src/diffusers/dependency_versions_table.py
20 |
21 | deps_table_update:
22 | @python setup.py deps_table_update
23 |
24 | deps_table_check_updated:
25 | @md5sum src/diffusers/dependency_versions_table.py > md5sum.saved
26 | @python setup.py deps_table_update
27 | @md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1)
28 | @rm md5sum.saved
29 |
30 | # autogenerating code
31 |
32 | autogenerate_code: deps_table_update
33 |
34 | # Check that the repo is in a good state
35 |
36 | repo-consistency:
37 | python utils/check_dummies.py
38 | python utils/check_repo.py
39 | python utils/check_inits.py
40 |
41 | # this target runs checks on all files
42 |
43 | quality:
44 | black --check --preview $(check_dirs)
45 | isort --check-only $(check_dirs)
46 | flake8 $(check_dirs)
47 | doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
48 |
49 | # Format source code automatically and check if there are any problems left that need manual fixing
50 |
51 | extra_style_checks:
52 | python utils/custom_init_isort.py
53 | doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source
54 |
55 | # this target runs checks on all files and potentially modifies some of them
56 |
57 | style:
58 | black --preview $(check_dirs)
59 | isort $(check_dirs)
60 | ${MAKE} autogenerate_code
61 | ${MAKE} extra_style_checks
62 |
63 | # Super fast fix and check target that only works on relevant modified files since the branch was made
64 |
65 | fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
66 |
67 | # Make marked copies of snippets of code conform to the original
68 |
69 | fix-copies:
70 | python utils/check_dummies.py --fix_and_overwrite
71 |
72 | # Run tests for the library
73 |
74 | test:
75 | python -m pytest -n auto --dist=loadfile -s -v ./tests/
76 |
77 | # Run tests for examples
78 |
79 | test-examples:
80 | python -m pytest -n auto --dist=loadfile -s -v ./examples/
81 |
82 | # Run tests for SageMaker DLC release
83 |
84 | test-sagemaker: # install sagemaker dependencies in advance with pip install .[sagemaker]
85 | TEST_SAGEMAKER=True python -m pytest -n auto -s -v ./tests/sagemaker
86 |
87 |
88 | # Release stuff
89 |
90 | pre-release:
91 | python utils/release.py
92 |
93 | pre-patch:
94 | python utils/release.py --patch
95 |
96 | post-release:
97 | python utils/release.py --post_release
98 |
99 | post-patch:
100 | python utils/release.py --post_release --patch
101 |
--------------------------------------------------------------------------------
/docs/source/imgs/diffusers_library.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/diffusers_all/c8c0c0e846c8afc07602c44180278a2f7f15331d/docs/source/imgs/diffusers_library.jpg
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | ## Training examples
2 |
3 | ### Unconditional Flowers
4 |
5 | The command to train a DDPM UNet model on the Oxford Flowers dataset:
6 |
7 | ```bash
8 | python -m torch.distributed.launch \
9 | --nproc_per_node 4 \
10 | train_unconditional.py \
11 | --dataset="huggan/flowers-102-categories" \
12 | --resolution=64 \
13 | --output_dir="flowers-ddpm" \
14 | --train_batch_size=16 \
15 | --num_epochs=100 \
16 | --gradient_accumulation_steps=1 \
17 | --learning_rate=1e-4 \
18 | --lr_warmup_steps=500 \
19 | --mixed_precision=no
20 | ```
21 |
22 | A full training run takes 2 hours on 4xV100 GPUs.
23 |
24 |
25 |
26 |
27 | ### Unconditional Pokemon
28 |
29 | The command to train a DDPM UNet model on the Pokemon dataset:
30 |
31 | ```bash
32 | python -m torch.distributed.launch \
33 | --nproc_per_node 4 \
34 | train_unconditional.py \
35 | --dataset="huggan/pokemon" \
36 | --resolution=64 \
37 | --output_dir="pokemon-ddpm" \
38 | --train_batch_size=16 \
39 | --num_epochs=100 \
40 | --gradient_accumulation_steps=1 \
41 | --learning_rate=1e-4 \
42 | --lr_warmup_steps=500 \
43 | --mixed_precision=no
44 | ```
45 |
46 | A full training run takes 2 hours on 4xV100 GPUs.
47 |
48 |
49 |
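50 | ### Sampling from a trained model
51 | 
52 | Once training finishes, `train_unconditional.py` saves the pipeline to `--output_dir`. The snippet below is a minimal sketch of how to reload that pipeline and generate images; it mirrors the sampling code inside the training script, and the output directory name and batch size are placeholders:
53 | 
54 | ```python
55 | import PIL.Image
56 | import torch
57 | 
58 | from diffusers import DDIMPipeline
59 | 
60 | 
61 | # load the pipeline saved via --output_dir
62 | pipeline = DDIMPipeline.from_pretrained("flowers-ddpm")
63 | 
64 | # sample random noise and denoise it for 50 steps
65 | generator = torch.manual_seed(0)
66 | images = pipeline(generator=generator, batch_size=4, num_inference_steps=50)
67 | 
68 | # denormalize from [-1, 1] to [0, 255] and save the first sample
69 | images = ((images.cpu() + 1.0) * 127.5).clamp(0, 255).to(torch.uint8)
70 | images = images.permute(0, 2, 3, 1).numpy()
71 | PIL.Image.fromarray(images[0]).save("sample.png")
72 | ```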
--------------------------------------------------------------------------------
/examples/experimental/train_glide_text_to_image.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | import bitsandbytes as bnb
8 | import PIL.Image
9 | from accelerate import Accelerator
10 | from datasets import load_dataset
11 | from diffusers import DDPMScheduler, Glide, GlideUNetModel
12 | from diffusers.hub_utils import init_git_repo, push_to_hub
13 | from diffusers.optimization import get_scheduler
14 | from diffusers.utils import logging
15 | from torchvision.transforms import (
16 | CenterCrop,
17 | Compose,
18 | InterpolationMode,
19 | Normalize,
20 | RandomHorizontalFlip,
21 | Resize,
22 | ToTensor,
23 | )
24 | from tqdm.auto import tqdm
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | def main(args):
31 | accelerator = Accelerator(mixed_precision=args.mixed_precision)
32 |
33 | pipeline = Glide.from_pretrained("fusing/glide-base")
34 | model = pipeline.text_unet
35 | noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
36 | optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)
37 |
38 | augmentations = Compose(
39 | [
40 | Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
41 | CenterCrop(args.resolution),
42 | RandomHorizontalFlip(),
43 | ToTensor(),
44 | Normalize([0.5], [0.5]),
45 | ]
46 | )
47 | dataset = load_dataset(args.dataset, split="train")
48 |
49 | text_encoder = pipeline.text_encoder.eval()
50 |
51 | def transforms(examples):
52 | images = [augmentations(image.convert("RGB")) for image in examples["image"]]
53 | text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
54 | text_inputs = text_inputs.input_ids.to(accelerator.device)
55 | with torch.no_grad():
56 | text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
57 | return {"images": images, "text_embeddings": text_embeddings}
58 |
59 | dataset.set_transform(transforms)
60 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
61 |
62 | lr_scheduler = get_scheduler(
63 | "linear",
64 | optimizer=optimizer,
65 | num_warmup_steps=args.warmup_steps,
66 | num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
67 | )
68 |
69 | model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
70 | model, text_encoder, optimizer, train_dataloader, lr_scheduler
71 | )
72 |
73 | if args.push_to_hub:
74 | repo = init_git_repo(args, at_init=True)
75 |
76 | # Train!
77 | is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
78 | world_size = torch.distributed.get_world_size() if is_distributed else 1
79 | total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
80 | max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
81 | logger.info("***** Running training *****")
82 | logger.info(f" Num examples = {len(train_dataloader.dataset)}")
83 | logger.info(f" Num Epochs = {args.num_epochs}")
84 | logger.info(f" Instantaneous batch size per device = {args.batch_size}")
85 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
86 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
87 | logger.info(f" Total optimization steps = {max_steps}")
88 |
89 | for epoch in range(args.num_epochs):
90 | model.train()
91 | with tqdm(total=len(train_dataloader), unit="ba") as pbar:
92 | pbar.set_description(f"Epoch {epoch}")
93 | for step, batch in enumerate(train_dataloader):
94 | clean_images = batch["images"]
95 | batch_size, n_channels, height, width = clean_images.shape
96 | noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
97 | timesteps = torch.randint(
98 | 0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
99 | ).long()
100 |
101 | # add noise onto the clean images according to the noise magnitude at each timestep
102 | # (this is the forward diffusion process)
103 | noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
104 |
105 | if step % args.gradient_accumulation_steps != 0:
106 | with accelerator.no_sync(model):
107 | model_output = model(noisy_images, timesteps, batch["text_embeddings"])
108 | model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
109 | # Learn the variance using the variational bound, but don't let
110 | # it affect our mean prediction.
111 | frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
112 |
113 | # predict the noise residual
114 | loss = F.mse_loss(model_output, noise_samples)
115 |
116 | loss = loss / args.gradient_accumulation_steps
117 |
118 | accelerator.backward(loss)
119 | # no optimizer.step() here: gradients accumulate across the no_sync micro-batches
120 | else:
121 | model_output = model(noisy_images, timesteps, batch["text_embeddings"])
122 | model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
123 | # Learn the variance using the variational bound, but don't let
124 | # it affect our mean prediction.
125 | frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
126 |
127 | # predict the noise residual
128 | loss = F.mse_loss(model_output, noise_samples)
129 | loss = loss / args.gradient_accumulation_steps
130 | accelerator.backward(loss)
131 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
132 | optimizer.step()
133 | lr_scheduler.step()
134 | optimizer.zero_grad()
135 | pbar.update(1)
136 | pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
137 |
138 | accelerator.wait_for_everyone()
139 |
140 | # Generate a sample image for visual inspection
141 | if accelerator.is_main_process:
142 | model.eval()
143 | with torch.no_grad():
144 | pipeline.unet = accelerator.unwrap_model(model)
145 |
146 | generator = torch.manual_seed(0)
147 | # run pipeline in inference (sample random noise and denoise)
148 | image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)
149 |
150 | # process image to PIL
151 | image_processed = image.squeeze(0)
152 | image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
153 | image_pil = PIL.Image.fromarray(image_processed)
154 |
155 | # save image
156 | test_dir = os.path.join(args.output_dir, "test_samples")
157 | os.makedirs(test_dir, exist_ok=True)
158 | image_pil.save(f"{test_dir}/{epoch:04d}.png")
159 |
160 | # save the model
161 | if args.push_to_hub:
162 | push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
163 | else:
164 | pipeline.save_pretrained(args.output_dir)
165 | accelerator.wait_for_everyone()
166 |
167 |
168 | if __name__ == "__main__":
169 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
170 | parser.add_argument("--local_rank", type=int, default=-1)
171 | parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
172 | parser.add_argument("--output_dir", type=str, default="glide-text2image")
173 | parser.add_argument("--overwrite_output_dir", action="store_true")
174 | parser.add_argument("--resolution", type=int, default=64)
175 | parser.add_argument("--batch_size", type=int, default=4)
176 | parser.add_argument("--num_epochs", type=int, default=100)
177 | parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
178 | parser.add_argument("--lr", type=float, default=1e-4)
179 | parser.add_argument("--warmup_steps", type=int, default=500)
180 | parser.add_argument("--push_to_hub", action="store_true")
181 | parser.add_argument("--hub_token", type=str, default=None)
182 | parser.add_argument("--hub_model_id", type=str, default=None)
183 | parser.add_argument("--hub_private_repo", action="store_true")
184 | parser.add_argument(
185 | "--mixed_precision",
186 | type=str,
187 | default="no",
188 | choices=["no", "fp16", "bf16"],
189 | help=(
190 | "Whether to use mixed precision. Choose "
191 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
192 | "and an Nvidia Ampere GPU."
193 | ),
194 | )
195 |
196 | args = parser.parse_args()
197 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
198 | if env_local_rank != -1 and env_local_rank != args.local_rank:
199 | args.local_rank = env_local_rank
200 |
201 | main(args)
202 |
--------------------------------------------------------------------------------
/examples/experimental/train_latent_text_to_image.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | import bitsandbytes as bnb
8 | import PIL.Image
9 | from accelerate import Accelerator
10 | from datasets import load_dataset
11 | from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel
12 | from diffusers.hub_utils import init_git_repo, push_to_hub
13 | from diffusers.optimization import get_scheduler
14 | from diffusers.utils import logging
15 | from torchvision.transforms import (
16 | CenterCrop,
17 | Compose,
18 | InterpolationMode,
19 | Normalize,
20 | RandomHorizontalFlip,
21 | Resize,
22 | ToTensor,
23 | )
24 | from tqdm.auto import tqdm
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | def main(args):
31 | accelerator = Accelerator(mixed_precision=args.mixed_precision)
32 |
33 | pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")
34 | pipeline.unet = None # this model will be trained from scratch now
35 | model = UNetLDMModel(
36 | attention_resolutions=[4, 2, 1],
37 | channel_mult=[1, 2, 4, 4],
38 | context_dim=1280,
39 | conv_resample=True,
40 | dims=2,
41 | dropout=0,
42 | image_size=8,
43 | in_channels=4,
44 | model_channels=320,
45 | num_heads=8,
46 | num_res_blocks=2,
47 | out_channels=4,
48 | resblock_updown=False,
49 | transformer_depth=1,
50 | use_new_attention_order=False,
51 | use_scale_shift_norm=False,
52 | use_spatial_transformer=True,
53 | legacy=False,
54 | )
55 | noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
56 | optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)
57 |
58 | augmentations = Compose(
59 | [
60 | Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
61 | CenterCrop(args.resolution),
62 | RandomHorizontalFlip(),
63 | ToTensor(),
64 | Normalize([0.5], [0.5]),
65 | ]
66 | )
67 | dataset = load_dataset(args.dataset, split="train")
68 |
69 | text_encoder = pipeline.bert.eval()
70 | vqvae = pipeline.vqvae.eval()
71 |
72 | def transforms(examples):
73 | images = [augmentations(image.convert("RGB")) for image in examples["image"]]
74 | text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
75 | with torch.no_grad():
76 | text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state
77 | images = 1 / 0.18215 * torch.stack(images, dim=0)
78 | latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode()
79 | return {"images": images, "text_embeddings": text_embeddings, "latents": latents}
80 |
81 | dataset.set_transform(transforms)
82 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
83 |
84 | lr_scheduler = get_scheduler(
85 | "linear",
86 | optimizer=optimizer,
87 | num_warmup_steps=args.warmup_steps,
88 | num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
89 | )
90 |
91 | model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
92 | model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler
93 | )
94 | text_encoder = text_encoder.cpu()
95 | vqvae = vqvae.cpu()
96 |
97 | if args.push_to_hub:
98 | repo = init_git_repo(args, at_init=True)
99 |
100 | # Train!
101 | is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
102 | world_size = torch.distributed.get_world_size() if is_distributed else 1
103 | total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
104 | max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
105 | logger.info("***** Running training *****")
106 | logger.info(f" Num examples = {len(train_dataloader.dataset)}")
107 | logger.info(f" Num Epochs = {args.num_epochs}")
108 | logger.info(f" Instantaneous batch size per device = {args.batch_size}")
109 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
110 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
111 | logger.info(f" Total optimization steps = {max_steps}")
112 |
113 | global_step = 0
114 | for epoch in range(args.num_epochs):
115 | model.train()
116 | with tqdm(total=len(train_dataloader), unit="ba") as pbar:
117 | pbar.set_description(f"Epoch {epoch}")
118 | for step, batch in enumerate(train_dataloader):
119 | clean_latents = batch["latents"]
120 | noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device)
121 | bsz = clean_latents.shape[0]
122 | timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long()
123 |
124 | # add noise onto the clean latents according to the noise magnitude at each timestep
125 | # (this is the forward diffusion process)
126 | noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps)
127 |
128 | if step % args.gradient_accumulation_steps != 0:
129 | with accelerator.no_sync(model):
130 | output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
131 | # predict the noise residual
132 | loss = F.mse_loss(output, noise_samples)
133 | loss = loss / args.gradient_accumulation_steps
134 | accelerator.backward(loss)
135 | # no optimizer.step() here: gradients accumulate across the no_sync micro-batches
136 | else:
137 | output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
138 | # predict the noise residual
139 | loss = F.mse_loss(output, noise_samples)
140 | loss = loss / args.gradient_accumulation_steps
141 | accelerator.backward(loss)
142 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
143 | optimizer.step()
144 | lr_scheduler.step()
145 | optimizer.zero_grad()
146 | pbar.update(1)
147 | pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
148 | global_step += 1
149 |
150 | accelerator.wait_for_everyone()
151 |
152 | # Generate a sample image for visual inspection
153 | if accelerator.is_main_process:
154 | model.eval()
155 | with torch.no_grad():
156 | pipeline.unet = accelerator.unwrap_model(model)
157 |
158 | generator = torch.manual_seed(0)
159 | # run pipeline in inference (sample random noise and denoise)
160 | image = pipeline(
161 | ["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50
162 | )
163 |
164 | # process image to PIL
165 | image_processed = image.cpu().permute(0, 2, 3, 1)
166 | image_processed = image_processed * 255.0
167 | image_processed = image_processed.type(torch.uint8).numpy()
168 | image_pil = PIL.Image.fromarray(image_processed[0])
169 |
170 | # save image
171 | test_dir = os.path.join(args.output_dir, "test_samples")
172 | os.makedirs(test_dir, exist_ok=True)
173 | image_pil.save(f"{test_dir}/{epoch:04d}.png")
174 |
175 | # save the model
176 | if args.push_to_hub:
177 | push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
178 | else:
179 | pipeline.save_pretrained(args.output_dir)
180 | accelerator.wait_for_everyone()
181 |
182 |
183 | if __name__ == "__main__":
184 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
185 | parser.add_argument("--local_rank", type=int, default=-1)
186 | parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
187 | parser.add_argument("--output_dir", type=str, default="ldm-text2image")
188 | parser.add_argument("--overwrite_output_dir", action="store_true")
189 | parser.add_argument("--resolution", type=int, default=128)
190 | parser.add_argument("--batch_size", type=int, default=1)
191 | parser.add_argument("--num_epochs", type=int, default=100)
192 | parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
193 | parser.add_argument("--lr", type=float, default=1e-4)
194 | parser.add_argument("--warmup_steps", type=int, default=500)
195 | parser.add_argument("--push_to_hub", action="store_true")
196 | parser.add_argument("--hub_token", type=str, default=None)
197 | parser.add_argument("--hub_model_id", type=str, default=None)
198 | parser.add_argument("--hub_private_repo", action="store_true")
199 | parser.add_argument(
200 | "--mixed_precision",
201 | type=str,
202 | default="no",
203 | choices=["no", "fp16", "bf16"],
204 | help=(
205 | "Whether to use mixed precision. Choose "
206 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
207 | "and an Nvidia Ampere GPU."
208 | ),
209 | )
210 |
211 | args = parser.parse_args()
212 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
213 | if env_local_rank != -1 and env_local_rank != args.local_rank:
214 | args.local_rank = env_local_rank
215 |
216 | main(args)
217 |
--------------------------------------------------------------------------------
/examples/train_unconditional.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from accelerate import Accelerator, DistributedDataParallelKwargs
8 | from accelerate.logging import get_logger
9 | from datasets import load_dataset
10 | from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
11 | from diffusers.hub_utils import init_git_repo, push_to_hub
12 | from diffusers.optimization import get_scheduler
13 | from diffusers.training_utils import EMAModel
14 | from torchvision.transforms import (
15 | CenterCrop,
16 | Compose,
17 | InterpolationMode,
18 | Normalize,
19 | RandomHorizontalFlip,
20 | Resize,
21 | ToTensor,
22 | )
23 | from tqdm.auto import tqdm
24 |
25 |
26 | logger = get_logger(__name__)
27 |
28 |
29 | def main(args):
30 | ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
31 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
32 | accelerator = Accelerator(
33 | mixed_precision=args.mixed_precision,
34 | log_with="tensorboard",
35 | logging_dir=logging_dir,
36 | kwargs_handlers=[ddp_unused_params],
37 | )
38 |
39 | model = UNetModel(
40 | attn_resolutions=(16,),
41 | ch=128,
42 | ch_mult=(1, 2, 4, 8),
43 | dropout=0.0,
44 | num_res_blocks=2,
45 | resamp_with_conv=True,
46 | resolution=args.resolution,
47 | )
48 | noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
49 | optimizer = torch.optim.AdamW(
50 | model.parameters(),
51 | lr=args.learning_rate,
52 | betas=(args.adam_beta1, args.adam_beta2),
53 | weight_decay=args.adam_weight_decay,
54 | eps=args.adam_epsilon,
55 | )
56 |
57 | augmentations = Compose(
58 | [
59 | Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
60 | CenterCrop(args.resolution),
61 | RandomHorizontalFlip(),
62 | ToTensor(),
63 | Normalize([0.5], [0.5]),
64 | ]
65 | )
66 | dataset = load_dataset(args.dataset, split="train")
67 |
68 | def transforms(examples):
69 | images = [augmentations(image.convert("RGB")) for image in examples["image"]]
70 | return {"input": images}
71 |
72 | dataset.set_transform(transforms)
73 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
74 |
75 | lr_scheduler = get_scheduler(
76 | args.lr_scheduler,
77 | optimizer=optimizer,
78 | num_warmup_steps=args.lr_warmup_steps,
79 | num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
80 | )
81 |
82 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
83 | model, optimizer, train_dataloader, lr_scheduler
84 | )
85 |
86 | ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
87 |
88 | if args.push_to_hub:
89 | repo = init_git_repo(args, at_init=True)
90 |
91 | if accelerator.is_main_process:
92 | run = os.path.split(__file__)[-1].split(".")[0]
93 | accelerator.init_trackers(run)
94 |
95 | # Train!
96 | is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
97 | world_size = torch.distributed.get_world_size() if is_distributed else 1
98 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
99 | max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
100 | logger.info("***** Running training *****")
101 | logger.info(f" Num examples = {len(train_dataloader.dataset)}")
102 | logger.info(f" Num Epochs = {args.num_epochs}")
103 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
104 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
105 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
106 | logger.info(f" Total optimization steps = {max_steps}")
107 |
108 | global_step = 0
109 | for epoch in range(args.num_epochs):
110 | model.train()
111 | progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
112 | progress_bar.set_description(f"Epoch {epoch}")
113 | for step, batch in enumerate(train_dataloader):
114 | clean_images = batch["input"]
115 | noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
116 | bsz = clean_images.shape[0]
117 | timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
118 |
119 | # add noise onto the clean images according to the noise magnitude at each timestep
120 | # (this is the forward diffusion process)
121 | noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
122 |
123 | if step % args.gradient_accumulation_steps != 0:
124 | with accelerator.no_sync(model):
125 | output = model(noisy_images, timesteps)
126 | # predict the noise residual
127 | loss = F.mse_loss(output, noise_samples)
128 | loss = loss / args.gradient_accumulation_steps
129 | accelerator.backward(loss)
130 | else:
131 | output = model(noisy_images, timesteps)
132 | # predict the noise residual
133 | loss = F.mse_loss(output, noise_samples)
134 | loss = loss / args.gradient_accumulation_steps
135 | accelerator.backward(loss)
136 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
137 | optimizer.step()
138 | lr_scheduler.step()
139 | ema_model.step(model)
140 | optimizer.zero_grad()
141 | progress_bar.update(1)
142 | progress_bar.set_postfix(
143 | loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
144 | )
145 | accelerator.log(
146 | {
147 | "train_loss": loss.detach().item(),
148 | "epoch": epoch,
149 | "ema_decay": ema_model.decay,
150 | "step": global_step,
151 | },
152 | step=global_step,
153 | )
154 | global_step += 1
155 | progress_bar.close()
156 |
157 | accelerator.wait_for_everyone()
158 |
159 | # Generate a sample image for visual inspection
160 | if accelerator.is_main_process:
161 | with torch.no_grad():
162 | pipeline = DDIMPipeline(
163 | unet=accelerator.unwrap_model(ema_model.averaged_model),
164 | noise_scheduler=noise_scheduler,
165 | )
166 |
167 | generator = torch.manual_seed(0)
168 | # run pipeline in inference (sample random noise and denoise)
169 | images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
170 |
171 | # denormalize the images and save to tensorboard
172 | images_processed = (images.cpu() + 1.0) * 127.5
173 | images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()
174 |
175 | accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
176 |
177 | # save the model
178 | if args.push_to_hub:
179 | push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
180 | else:
181 | pipeline.save_pretrained(args.output_dir)
182 | accelerator.wait_for_everyone()
183 |
184 | accelerator.end_training()
185 |
186 |
187 | if __name__ == "__main__":
188 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
189 | parser.add_argument("--local_rank", type=int, default=-1)
190 | parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
191 | parser.add_argument("--output_dir", type=str, default="ddpm-model")
192 | parser.add_argument("--overwrite_output_dir", action="store_true")
193 | parser.add_argument("--resolution", type=int, default=64)
194 | parser.add_argument("--train_batch_size", type=int, default=16)
195 | parser.add_argument("--eval_batch_size", type=int, default=16)
196 | parser.add_argument("--num_epochs", type=int, default=100)
197 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
198 | parser.add_argument("--learning_rate", type=float, default=1e-4)
199 | parser.add_argument("--lr_scheduler", type=str, default="cosine")
200 | parser.add_argument("--lr_warmup_steps", type=int, default=500)
201 | parser.add_argument("--adam_beta1", type=float, default=0.95)
202 | parser.add_argument("--adam_beta2", type=float, default=0.999)
203 | parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
204 | parser.add_argument("--adam_epsilon", type=float, default=1e-3)
205 | parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
206 | parser.add_argument("--ema_power", type=float, default=3 / 4)
207 | parser.add_argument("--ema_max_decay", type=float, default=0.9999)
208 | parser.add_argument("--push_to_hub", action="store_true")
209 | parser.add_argument("--hub_token", type=str, default=None)
210 | parser.add_argument("--hub_model_id", type=str, default=None)
211 | parser.add_argument("--hub_private_repo", action="store_true")
212 | parser.add_argument("--logging_dir", type=str, default="logs")
213 | parser.add_argument(
214 | "--mixed_precision",
215 | type=str,
216 | default="no",
217 | choices=["no", "fp16", "bf16"],
218 | help=(
219 | "Whether to use mixed precision. Choose "
220 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
221 | "and an Nvidia Ampere GPU."
222 | ),
223 | )
224 |
225 | args = parser.parse_args()
226 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
227 | if env_local_rank != -1 and env_local_rank != args.local_rank:
228 | args.local_rank = env_local_rank
229 |
230 | main(args)
231 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 | target-version = ['py36']
4 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import numpy as np
3 | import PIL
4 | import torch
5 | #from configs.ve import ffhq_ncsnpp_continuous as configs
6 | # from configs.ve import cifar10_ncsnpp_continuous as configs
7 |
8 |
9 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
10 |
11 | torch.backends.cuda.matmul.allow_tf32 = False
12 | torch.manual_seed(0)
13 |
14 |
15 | class NewReverseDiffusionPredictor:
16 | def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
17 | super().__init__()
18 | self.sigma_min = sigma_min
19 | self.sigma_max = sigma_max
20 | self.N = N
21 | self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
22 |
23 | self.probability_flow = probability_flow
24 | self.score_fn = score_fn
25 |
26 | def discretize(self, x, t):
27 | timestep = (t * (self.N - 1)).long()
28 | sigma = self.discrete_sigmas.to(t.device)[timestep]
29 | adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
30 | self.discrete_sigmas[timestep - 1].to(t.device))
31 | f = torch.zeros_like(x)
32 | G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
33 |
34 | labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
35 | result = self.score_fn(x, labels)
36 |
37 | rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
38 | rev_G = torch.zeros_like(G) if self.probability_flow else G
39 | return rev_f, rev_G
40 |
41 | def update_fn(self, x, t):
42 | f, G = self.discretize(x, t)
43 | z = torch.randn_like(x)
44 | x_mean = x - f
45 | x = x_mean + G[:, None, None, None] * z
46 | return x, x_mean
47 |
48 |
49 | class NewLangevinCorrector:
50 | def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
51 | super().__init__()
52 | self.score_fn = score_fn
53 | self.snr = snr
54 | self.n_steps = n_steps
55 |
56 | self.sigma_min = sigma_min
57 | self.sigma_max = sigma_max
58 |
59 | def update_fn(self, x, t):
60 | score_fn = self.score_fn
61 | n_steps = self.n_steps
62 | target_snr = self.snr
63 | # if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
64 | # timestep = (t * (sde.N - 1) / sde.T).long()
65 | # alpha = sde.alphas.to(t.device)[timestep]
66 | # else:
67 | alpha = torch.ones_like(t)
68 |
69 | for i in range(n_steps):
70 | labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
71 | grad = score_fn(x, labels)
72 | noise = torch.randn_like(x)
73 | grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
74 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
75 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
76 | x_mean = x + step_size[:, None, None, None] * grad
77 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
78 |
79 | return x, x_mean
80 |
81 |
82 |
83 | def save_image(x):
84 | image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
85 | image_pil = PIL.Image.fromarray(image_processed[0])
86 | image_pil.save("../images/hey.png")
87 |
88 |
89 | # ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
90 | #ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
91 | # Note usually we need to restore ema etc...
92 | # ema restored checkpoint used from below
93 |
94 | N = 2
95 | sigma_min = 0.01
96 | sigma_max = 1348
97 | sampling_eps = 1e-5
98 | batch_size = 1
99 | centered = False
100 |
101 | from diffusers import NCSNpp
102 |
103 | model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
104 | model = torch.nn.DataParallel(model)
105 |
106 | img_size = model.module.config.image_size
107 | channels = model.module.config.num_channels
108 | shape = (batch_size, channels, img_size, img_size)
109 | probability_flow = False
110 | snr = 0.15
111 | n_steps = 1
112 |
113 |
114 | new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
115 | new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
116 |
117 | with torch.no_grad():
118 | # Initial sample
119 | x = torch.randn(*shape) * sigma_max
120 | x = x.to(device)
121 | timesteps = torch.linspace(1, sampling_eps, N, device=device)
122 |
123 | for i in range(N):
124 | t = timesteps[i]
125 | vec_t = torch.ones(shape[0], device=t.device) * t
126 | x, x_mean = new_corrector.update_fn(x, vec_t)
127 | x, x_mean = new_predictor.update_fn(x, vec_t)
128 |
129 | x = x_mean
130 | if centered:
131 | x = (x + 1.) / 2.
132 |
133 |
134 | # save_image(x)
135 |
136 | # for 5 cifar10
137 | x_sum = 106071.9922
138 | x_mean = 34.52864456176758
139 |
140 | # for 1000 cifar10
141 | x_sum = 461.9700
142 | x_mean = 0.1504
143 |
144 | # for 2 for 1024
145 | x_sum = 3382810112.0
146 | x_mean = 1075.366455078125
147 |
148 | def check_x_sum_x_mean(x, x_sum, x_mean):
149 | assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
150 | assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
151 |
152 |
153 | check_x_sum_x_mean(x, x_sum, x_mean)
154 |
--------------------------------------------------------------------------------
/scripts/conversion_bddm.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import torch
4 |
5 | from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
6 | from diffusers import DDPMScheduler
7 |
8 |
9 | def convert_bddm_original(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
10 | sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
11 | noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")
12 |
13 | model = DiffWave()
14 | model.load_state_dict(sd, strict=False)
15 |
16 | ts, _, betas, _ = noise_scheduler_sd
17 | ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())
18 |
19 | noise_scheduler = DDPMScheduler(
20 | timesteps=12,
21 | trained_betas=betas,
22 | timestep_values=ts,
23 | clip_sample=False,
24 | tensor_format="np",
25 | )
26 |
27 | pipeline = BDDMPipeline(model, noise_scheduler)
28 | pipeline.save_pretrained(output_path)
29 |
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--checkpoint_path", type=str, required=True)
34 | parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
35 | parser.add_argument("--output_path", type=str, required=True)
36 | args = parser.parse_args()
37 |
38 | convert_bddm_original(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)
39 |
40 |
41 |
--------------------------------------------------------------------------------
/scripts/conversion_glide.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
5 | from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
6 | from transformers import CLIPTextConfig, GPT2Tokenizer
7 |
8 |
9 | # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
10 | state_dict = torch.load("base.pt", map_location="cpu")
11 | state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
12 |
13 | ### Convert the text encoder
14 |
15 | config = CLIPTextConfig(
16 | vocab_size=50257,
17 | max_position_embeddings=128,
18 | hidden_size=512,
19 | intermediate_size=2048,
20 | num_hidden_layers=16,
21 | num_attention_heads=8,
22 | use_padding_embeddings=True,
23 | )
24 | model = CLIPTextModel(config).eval()
25 | tokenizer = GPT2Tokenizer(
26 | "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
27 | )
28 |
29 | hf_encoder = model.text_model
30 |
31 | hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
32 | hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
33 | hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]
34 |
35 | hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
36 | hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
37 |
38 | for layer_idx in range(config.num_hidden_layers):
39 | hf_layer = hf_encoder.encoder.layers[layer_idx]
40 | hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
41 | hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]
42 |
43 | hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
44 | hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
45 |
46 | hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
47 | hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
48 | hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
49 | hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]
50 |
51 | hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
52 | hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
53 | hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
54 | hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
55 |
56 | ### Convert the Text-to-Image UNet
57 |
58 | text2im_model = GlideTextToImageUNetModel(
59 | in_channels=3,
60 | model_channels=192,
61 | out_channels=6,
62 | num_res_blocks=3,
63 | attention_resolutions=(2, 4, 8),
64 | dropout=0.1,
65 | channel_mult=(1, 2, 3, 4),
66 | num_heads=1,
67 | num_head_channels=64,
68 | num_heads_upsample=1,
69 | use_scale_shift_norm=True,
70 | resblock_updown=True,
71 | transformer_dim=512,
72 | )
73 |
74 | text2im_model.load_state_dict(state_dict, strict=False)
75 |
76 | text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
77 |
78 | ### Convert the Super-Resolution UNet
79 |
80 | # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
81 | ups_state_dict = torch.load("upsample.pt", map_location="cpu")
82 |
83 | superres_model = GlideSuperResUNetModel(
84 | in_channels=6,
85 | model_channels=192,
86 | out_channels=6,
87 | num_res_blocks=2,
88 | attention_resolutions=(8, 16, 32),
89 | dropout=0.1,
90 | channel_mult=(1, 1, 2, 2, 4, 4),
91 | num_heads=1,
92 | num_head_channels=64,
93 | num_heads_upsample=1,
94 | use_scale_shift_norm=True,
95 | resblock_updown=True,
96 | )
97 |
98 | superres_model.load_state_dict(ups_state_dict, strict=False)
99 |
100 | upscale_scheduler = DDIMScheduler(
101 | timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
102 | )
103 |
104 | glide = Glide(
105 | text_unet=text2im_model,
106 | text_noise_scheduler=text_scheduler,
107 | text_encoder=model,
108 | tokenizer=tokenizer,
109 | upscale_unet=superres_model,
110 | upscale_noise_scheduler=upscale_scheduler,
111 | )
112 |
113 | glide.save_pretrained("./glide-base")
114 |
--------------------------------------------------------------------------------
/scripts/conversion_ldm_uncond.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from omegaconf import OmegaConf
4 | import torch
5 |
6 | from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler
7 |
8 | def convert_ldm_original(checkpoint_path, config_path, output_path):
9 | config = OmegaConf.load(config_path)
10 | state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
11 | keys = list(state_dict.keys())
12 |
13 | # extract state_dict for VQVAE
14 | first_stage_dict = {}
15 | first_stage_key = "first_stage_model."
16 | for key in keys:
17 | if key.startswith(first_stage_key):
18 | first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
19 |
20 | # extract state_dict for UNetLDM
21 | unet_state_dict = {}
22 | unet_key = "model.diffusion_model."
23 | for key in keys:
24 | if key.startswith(unet_key):
25 | unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
26 |
27 | vqvae_init_args = config.model.params.first_stage_config.params
28 | unet_init_args = config.model.params.unet_config.params
29 |
30 | vqvae = VQModel(**vqvae_init_args).eval()
31 | vqvae.load_state_dict(first_stage_dict)
32 |
33 | unet = UNetLDMModel(**unet_init_args).eval()
34 | unet.load_state_dict(unet_state_dict)
35 |
36 | noise_scheduler = DDIMScheduler(
37 | timesteps=config.model.params.timesteps,
38 | beta_schedule="scaled_linear",
39 | beta_start=config.model.params.linear_start,
40 | beta_end=config.model.params.linear_end,
41 | clip_sample=False,
42 | )
43 |
44 | pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
45 | pipeline.save_pretrained(output_path)
46 |
47 |
48 | if __name__ == "__main__":
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument("--checkpoint_path", type=str, required=True)
51 | parser.add_argument("--config_path", type=str, required=True)
52 | parser.add_argument("--output_path", type=str, required=True)
53 | args = parser.parse_args()
54 |
55 | convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
56 |
57 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = diffusers
7 | known_third_party =
8 | numpy
9 | torch
10 | torch_xla
11 |
12 | line_length = 119
13 | lines_after_imports = 2
14 | multi_line_output = 3
15 | use_parentheses = True
16 |
17 | [flake8]
18 | ignore = E203, E722, E501, E741, W503, W605
19 | max-line-length = 119
20 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py
17 |
18 | To create the package for pypi.
19 |
20 | 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
21 | documentation.
22 |
23 | If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
24 | for the post-release and run `make fix-copies` on the main branch as well.
25 |
26 | 2. Run the tests for Amazon SageMaker. The documentation is located in `./tests/sagemaker/README.md`; otherwise, ask @philschmid.
27 |
28 | 3. Unpin specific versions from setup.py that use a git install.
29 |
30 | 4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
31 | message: "Release: <RELEASE>" and push.
32 |
33 | 5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
34 |
35 | 6. Add a tag in git to mark the release: "git tag v<RELEASE> -m 'Adds tag v<RELEASE> for pypi' "
36 | Push the tag to git: git push --tags origin v<RELEASE>-release
37 |
38 | 7. Build both the sources and the wheel. Do not change anything in setup.py between
39 | creating the wheel and the source distribution (obviously).
40 |
41 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
42 | (this will build a wheel for the python version you use to build it).
43 |
44 | For the sources, run: "python setup.py sdist"
45 | You should now have a /dist directory with both .whl and .tar.gz source versions.
46 |
47 | 8. Check that everything looks correct by uploading the package to the pypi test server:
48 |
49 | twine upload dist/* -r pypitest
50 | (pypi suggest using twine as other methods upload files via plaintext.)
51 | You may have to specify the repository url, use the following command then:
52 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
53 |
54 | Check that you can install it in a virtualenv by running:
55 | pip install -i https://testpypi.python.org/pypi diffusers
56 |
57 | Check you can run the following commands:
58 | python -c "from diffusers import DDPMScheduler; scheduler = DDPMScheduler(timesteps=1000, tensor_format='pt'); print(scheduler)"
59 | python -c "from diffusers import *"
60 |
61 | 9. Upload the final version to actual pypi:
62 | twine upload dist/* -r pypi
63 |
64 | 10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
65 |
66 | 11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
67 | you need to go back to main before executing this.
68 | """
69 |
70 | import re
71 | from distutils.core import Command
72 |
73 | from setuptools import find_packages, setup
74 |
75 | # IMPORTANT:
76 | # 1. all dependencies should be listed here with their version requirements if any
77 | # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
78 | _deps = [
79 | "Pillow",
80 | "black~=22.0,>=22.3",
81 | "filelock",
82 | "flake8>=3.8.3",
83 | "huggingface-hub",
84 | "isort>=5.5.4",
85 | "numpy",
86 | "pytest",
87 | "regex!=2019.12.17",
88 | "requests",
89 | "torch>=1.4",
90 | "tensorboard",
91 | "modelcards==0.1.4"
92 | ]
93 |
94 | # this is a lookup table with items like:
95 | #
96 | # torch: "torch>=1.4"
97 | # black: "black~=22.0,>=22.3"
98 | #
99 | # some of the values are versioned whereas others aren't.
100 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
101 |
102 | # since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from
103 | # anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
104 | #
105 | # python -c 'import sys; from diffusers.dependency_versions_table import deps; \
106 | # print(" ".join([ deps[x] for x in sys.argv[1:]]))' torch numpy
107 | #
108 | # Just pass the desired package names to that script as it's shown with 2 packages above.
109 | #
110 | # If diffusers is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above
111 | #
112 | # You can then feed this for example to `pip`:
113 | #
114 | # pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \
115 | # print(" ".join([ deps[x] for x in sys.argv[1:]]))' torch numpy)
116 | #
117 |
118 |
119 | def deps_list(*pkgs):
120 | return [deps[pkg] for pkg in pkgs]
121 |
122 |
123 | class DepsTableUpdateCommand(Command):
124 | """
125 | A custom distutils command that updates the dependency table.
126 | usage: python setup.py deps_table_update
127 | """
128 |
129 | description = "build runtime dependency table"
130 | user_options = [
131 | # format: (long option, short option, description).
132 | ("dep-table-update", None, "updates src/diffusers/dependency_versions_table.py"),
133 | ]
134 |
135 | def initialize_options(self):
136 | pass
137 |
138 | def finalize_options(self):
139 | pass
140 |
141 | def run(self):
142 | entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()])
143 | content = [
144 | "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
145 | "# 1. modify the `_deps` dict in setup.py",
146 | "# 2. run `make deps_table_update`",
147 | "deps = {",
148 | entries,
149 | "}",
150 | "",
151 | ]
152 | target = "src/diffusers/dependency_versions_table.py"
153 | print(f"updating {target}")
154 | with open(target, "w", encoding="utf-8", newline="\n") as f:
155 | f.write("\n".join(content))
156 |
157 |
158 |
159 |
160 |
161 | extras = {}
162 | extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
163 | extras["docs"] = []
164 | extras["test"] = [
165 | "pytest",
166 | ]
167 | extras["dev"] = extras["quality"] + extras["test"]
168 |
169 | install_requires = [
170 | deps["filelock"],
171 | deps["huggingface-hub"],
172 | deps["numpy"],
173 | deps["regex"],
174 | deps["requests"],
175 | deps["torch"],
176 | deps["Pillow"],
177 | deps["tensorboard"],
178 | deps["modelcards"],
179 | ]
180 |
181 | setup(
182 | name="diffusers",
183 | version="0.0.4",
184 | description="Diffusers",
185 | long_description=open("README.md", "r", encoding="utf-8").read(),
186 | long_description_content_type="text/markdown",
187 | keywords="deep learning",
188 | license="Apache",
189 | author="The HuggingFace team",
190 | author_email="patrick@huggingface.co",
191 | url="https://github.com/huggingface/diffusers",
192 | package_dir={"": "src"},
193 | packages=find_packages("src"),
194 | python_requires=">=3.6.0",
195 | install_requires=install_requires,
196 | extras_require=extras,
197 | classifiers=[
198 | "Development Status :: 5 - Production/Stable",
199 | "Intended Audience :: Developers",
200 | "Intended Audience :: Education",
201 | "Intended Audience :: Science/Research",
202 | "License :: OSI Approved :: Apache Software License",
203 | "Operating System :: OS Independent",
204 | "Programming Language :: Python :: 3",
205 | "Programming Language :: Python :: 3.6",
206 | "Programming Language :: Python :: 3.7",
207 | "Programming Language :: Python :: 3.8",
208 | "Programming Language :: Python :: 3.9",
209 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
210 | ],
211 | cmdclass={"deps_table_update": DepsTableUpdateCommand},
212 | )
213 |
214 | # Release checklist
215 | # 1. Change the version in __init__.py and setup.py.
216 | # 2. Commit these changes with the message: "Release: VERSION"
217 | # 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for pypi' "
218 | # Push the tag to git: git push --tags origin main
219 | # 4. Run the following commands in the top-level directory:
220 | # python setup.py bdist_wheel
221 | # python setup.py sdist
222 | # 5. Upload the package to the pypi test server first:
223 | # twine upload dist/* -r pypitest
224 | # twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
225 | # 6. Check that you can install it in a virtualenv by running:
226 | # pip install -i https://testpypi.python.org/pypi diffusers
227 | #    python -c "from diffusers import __version__; print(__version__)"
228 | #    python -c "from diffusers import *"
229 | # 7. Upload the final version to actual pypi:
230 | # twine upload dist/* -r pypi
231 | # 8. Add release notes to the tag in github once everything is looking hunky-dory.
232 | # 9. Update the version in __init__.py and setup.py to the new "-dev" version and push to main
233 |
--------------------------------------------------------------------------------
/src/diffusers/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving other warnings, so don't check this module at all.
4 | from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
5 |
6 |
7 | __version__ = "0.0.4"
8 |
9 | from .modeling_utils import ModelMixin
10 | from .models import AutoencoderKL, NCSNpp, TemporalUNet, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
11 | from .pipeline_utils import DiffusionPipeline
12 | from .pipelines import (
13 | BDDMPipeline,
14 | DDIMPipeline,
15 | DDPMPipeline,
16 | LatentDiffusionUncondPipeline,
17 | PNDMPipeline,
18 | ScoreSdeVePipeline,
19 | ScoreSdeVpPipeline,
20 | )
21 | from .schedulers import (
22 | DDIMScheduler,
23 | DDPMScheduler,
24 | GradTTSScheduler,
25 | PNDMScheduler,
26 | SchedulerMixin,
27 | ScoreSdeVeScheduler,
28 | ScoreSdeVpScheduler,
29 | )
30 |
31 |
32 | if is_transformers_available():
33 | from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
34 | from .models.unet_grad_tts import UNetGradTTSModel
35 | from .pipelines import GlidePipeline, LatentDiffusionPipeline
36 | else:
37 | from .utils.dummy_transformers_objects import *
38 |
39 |
40 | if is_transformers_available() and is_inflect_available() and is_unidecode_available():
41 | from .pipelines import GradTTSPipeline
42 | else:
43 | from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
44 |
--------------------------------------------------------------------------------
/src/diffusers/dependency_versions_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .dependency_versions_table import deps
17 | from .utils.versions import require_version, require_version_core
18 |
19 |
20 | # define which module versions we always want to check at run time
21 | # (usually the ones defined in `install_requires` in setup.py)
22 | #
23 | # note: only packages that are listed in `deps` (i.e. in `_deps` in setup.py) can be
24 | # checked here; any other name makes the loop below raise a ValueError.
25 | #
26 | # Backports such as `dataclasses` (Python < 3.7) and `importlib_metadata` (Python < 3.8)
27 | # are not pinned in `deps`, so they are intentionally not checked at runtime.
28 |
29 | pkgs_to_check_at_runtime = "regex requests filelock numpy torch".split()
30 |
31 |
32 | for pkg in pkgs_to_check_at_runtime:
33 | if pkg in deps:
34 | if pkg == "tokenizers":
35 | # must be loaded here, or else tqdm check may fail
36 | from .utils import is_tokenizers_available
37 |
38 | if not is_tokenizers_available():
39 | continue # not required, check version only if installed
40 |
41 | require_version_core(deps[pkg])
42 | else:
43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44 |
45 |
46 | def dep_version_check(pkg, hint=None):
47 | require_version(deps[pkg], hint)
48 |
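# Example usage (a sketch, not an existing call site in this repo): other modules can
# verify a pinned dependency at import time with
#
#     from .dependency_versions_check import dep_version_check
#     dep_version_check("torch")  # raises if the installed torch does not satisfy deps["torch"]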
--------------------------------------------------------------------------------
/src/diffusers/dependency_versions_table.py:
--------------------------------------------------------------------------------
1 | # THIS FILE HAS BEEN AUTOGENERATED. To update:
2 | # 1. modify the `_deps` dict in setup.py
3 | # 2. run `make deps_table_update`
4 | deps = {
5 | "Pillow": "Pillow",
6 | "black": "black~=22.0,>=22.3",
7 | "filelock": "filelock",
8 | "flake8": "flake8>=3.8.3",
9 | "huggingface-hub": "huggingface-hub",
10 | "isort": "isort>=5.5.4",
11 | "numpy": "numpy",
12 | "pytest": "pytest",
13 | "regex": "regex!=2019.12.17",
14 | "requests": "requests",
15 | "torch": "torch>=1.4",
16 | "tensorboard": "tensorboard",
17 | "modelcards": "modelcards==0.1.4",
18 | }
19 |
--------------------------------------------------------------------------------
/src/diffusers/hub_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import shutil
19 | from pathlib import Path
20 | from typing import Optional
21 |
22 | from diffusers import DiffusionPipeline
23 | from huggingface_hub import HfFolder, Repository, whoami
24 | from modelcards import CardData, ModelCard
25 |
26 | from .utils import logging
27 |
28 |
29 | logger = logging.get_logger(__name__)
30 |
31 |
32 | MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
33 |
34 |
35 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
36 | if token is None:
37 | token = HfFolder.get_token()
38 | if organization is None:
39 | username = whoami(token)["name"]
40 | return f"{username}/{model_id}"
41 | else:
42 | return f"{organization}/{model_id}"
43 |
44 |
45 | def init_git_repo(args, at_init: bool = False):
46 | """
47 | Initializes a git repo in `args.hub_model_id`.
48 | Args:
49 | at_init (`bool`, *optional*, defaults to `False`):
50 | Whether this function is called before any training or not. If `args.overwrite_output_dir` is `True`
51 | and `at_init` is `True`, the path to the repo (which is `args.output_dir`) might be wiped out.
52 | """
53 | if args.local_rank not in [-1, 0]:
54 | return
55 | use_auth_token = True if args.hub_token is None else args.hub_token
56 | if args.hub_model_id is None:
57 | repo_name = Path(args.output_dir).absolute().name
58 | else:
59 | repo_name = args.hub_model_id
60 | if "/" not in repo_name:
61 | repo_name = get_full_repo_name(repo_name, token=args.hub_token)
62 |
63 | try:
64 | repo = Repository(
65 | args.output_dir,
66 | clone_from=repo_name,
67 | use_auth_token=use_auth_token,
68 | private=args.hub_private_repo,
69 | )
70 | except EnvironmentError:
71 | if args.overwrite_output_dir and at_init:
72 | # Try again after wiping output_dir
73 | shutil.rmtree(args.output_dir)
74 | repo = Repository(
75 | args.output_dir,
76 | clone_from=repo_name,
77 | use_auth_token=use_auth_token,
78 | )
79 | else:
80 | raise
81 |
82 | repo.git_pull()
83 |
84 | # By default, ignore the checkpoint folders
85 | if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
86 | with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
87 | writer.writelines(["checkpoint-*/"])
88 |
89 | return repo
90 |
91 |
92 | def push_to_hub(
93 | args,
94 | pipeline: DiffusionPipeline,
95 | repo: Repository,
96 | commit_message: Optional[str] = "End of training",
97 | blocking: bool = True,
98 | **kwargs,
99 | ) -> str:
100 | """
101 | Uploads the `pipeline` to the 🤗 model hub on the repo `args.hub_model_id`.
102 | Parameters:
103 | commit_message (`str`, *optional*, defaults to `"End of training"`):
104 | Message to commit while pushing.
105 | blocking (`bool`, *optional*, defaults to `True`):
106 | Whether the function should return only when the `git push` has finished.
107 | kwargs:
108 | Additional keyword arguments passed along to [`create_model_card`].
109 | Returns:
110 | The url of the commit of your model in the given repository if `blocking=True`, or a tuple with the url of the
111 | commit and an object to track the progress of the push if `blocking=False`.
112 | """
113 |
114 | if args.hub_model_id is None:
115 | model_name = Path(args.output_dir).name
116 | else:
117 | model_name = args.hub_model_id.split("/")[-1]
118 |
119 | output_dir = args.output_dir
120 | os.makedirs(output_dir, exist_ok=True)
121 | logger.info(f"Saving pipeline checkpoint to {output_dir}")
122 | pipeline.save_pretrained(output_dir)
123 |
124 | # Only push from one node.
125 | if args.local_rank not in [-1, 0]:
126 | return
127 |
128 | # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
129 | if (
130 | blocking
131 | and len(repo.command_queue) > 0
132 | and repo.command_queue[-1] is not None
133 | and not repo.command_queue[-1].is_done
134 | ):
135 | repo.command_queue[-1]._process.kill()
136 |
137 | git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
138 | # push separately the model card to be independent from the rest of the model
139 | create_model_card(args, model_name=model_name)
140 | try:
141 | repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
142 | except EnvironmentError as exc:
143 | logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")
144 |
145 | return git_head_commit_url
146 |
147 |
148 | def create_model_card(args, model_name):
149 | if args.local_rank not in [-1, 0]:
150 | return
151 |
152 | repo_name = get_full_repo_name(model_name, token=args.hub_token)
153 |
154 | model_card = ModelCard.from_template(
155 | card_data=CardData( # Card metadata object that will be converted to YAML block
156 | language="en",
157 | license="apache-2.0",
158 | library_name="diffusers",
159 | tags=[],
160 | datasets=args.dataset,
161 | metrics=[],
162 | ),
163 | template_path=MODEL_CARD_TEMPLATE_PATH,
164 | model_name=model_name,
165 | repo_name=repo_name,
166 | dataset_name=args.dataset,
167 | learning_rate=args.learning_rate,
168 | train_batch_size=args.train_batch_size,
169 | eval_batch_size=args.eval_batch_size,
170 | gradient_accumulation_steps=args.gradient_accumulation_steps,
171 | adam_beta1=args.adam_beta1,
172 | adam_beta2=args.adam_beta2,
173 | adam_weight_decay=args.adam_weight_decay,
174 | adam_epsilon=args.adam_epsilon,
175 | lr_scheduler=args.lr_scheduler,
176 | lr_warmup_steps=args.lr_warmup_steps,
177 | ema_inv_gamma=args.ema_inv_gamma,
178 | ema_power=args.ema_power,
179 | ema_max_decay=args.ema_max_decay,
180 | mixed_precision=args.mixed_precision,
181 | )
182 |
183 | card_path = os.path.join(args.output_dir, "README.md")
184 | model_card.save(card_path)
185 |
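A small usage sketch of `get_full_repo_name` (the model id and organization below are hypothetical; the no-organization case resolves the username from the locally cached Hub token):

```python
from diffusers.hub_utils import get_full_repo_name

# Without an organization, the namespace comes from the token owner, e.g. "my-username/ddpm-butterflies-128".
print(get_full_repo_name("ddpm-butterflies-128"))

# With an explicit organization, the namespace is the organization itself: "my-org/ddpm-butterflies-128".
print(get_full_repo_name("ddpm-butterflies-128", organization="my-org"))
```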
--------------------------------------------------------------------------------
/src/diffusers/models/README.md:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | - Models: Neural networks that model $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ and are trained end-to-end to denoise a noisy input into an image. Examples: UNet, Conditioned UNet, 3D UNet, Transformer UNet
4 |
5 | ## API
6 |
7 | TODO(Suraj, Patrick)
8 |
9 | ## Examples
10 |
11 | TODO(Suraj, Patrick)
12 |
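While the API and examples above are still TODO, a minimal unconditional sketch with the `UNetModel` defined in `models/unet.py` looks roughly like this (the hyperparameters are illustrative, not a recommended configuration):

```python
import torch

from diffusers import UNetModel

# A deliberately tiny configuration, purely for illustration.
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=(8,), in_channels=3, out_ch=3, resolution=32)

noisy_images = torch.randn(1, 3, 32, 32)          # a batch of noisy 32x32 RGB images
timesteps = torch.tensor([10], dtype=torch.long)  # one diffusion timestep per batch element

with torch.no_grad():
    noise_prediction = model(noisy_images, timesteps)  # same shape as the input: (1, 3, 32, 32)
```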
--------------------------------------------------------------------------------
/src/diffusers/models/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving other warnings, so don't check this module at all.
4 |
5 | # Copyright 2022 The HuggingFace Team. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | from .unet import UNetModel
20 | from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
21 | from .unet_grad_tts import UNetGradTTSModel
22 | from .unet_ldm import UNetLDMModel
23 | from .unet_rl import TemporalUNet
24 | from .unet_sde_score_estimation import NCSNpp
25 | from .unet_unconditional import UNetUnconditionalModel
26 | from .vae import AutoencoderKL, VQModel
27 |
--------------------------------------------------------------------------------
/src/diffusers/models/embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 |
16 | import numpy as np
17 | import torch
18 | from torch import nn
19 |
20 |
21 | def get_timestep_embedding(
22 | timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
23 | ):
24 | """
25 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
26 |
27 | :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
28 | :param embedding_dim: the dimension of the output.
29 | :param max_period: controls the minimum frequency of the embeddings.
30 | :return: an [N x dim] Tensor of positional embeddings.
31 | """
32 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
33 |
34 | half_dim = embedding_dim // 2
35 |
36 | emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
37 | emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
38 | emb = torch.exp(emb * emb_coeff)
39 | emb = timesteps[:, None].float() * emb[None, :]
40 |
41 | # scale embeddings
42 | emb = scale * emb
43 |
44 | # concat sine and cosine embeddings
45 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
46 |
47 | # flip sine and cosine embeddings
48 | if flip_sin_to_cos:
49 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
50 |
51 | # zero pad
52 | if embedding_dim % 2 == 1:
53 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
54 | return emb
55 |
56 |
57 | # unet_sde_score_estimation.py
58 | class GaussianFourierProjection(nn.Module):
59 | """Gaussian Fourier embeddings for noise levels."""
60 |
61 | def __init__(self, embedding_size=256, scale=1.0):
62 | super().__init__()
63 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
64 |
65 | def forward(self, x):
66 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
67 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
68 |
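A quick shape-level sketch of how these two helpers behave (the values are illustrative):

```python
import torch

from diffusers.models.embeddings import GaussianFourierProjection, get_timestep_embedding

timesteps = torch.tensor([0.0, 10.0, 999.0])

# Sinusoidal embeddings: sin and cos halves are concatenated, so the output is (N, embedding_dim).
sinusoidal = get_timestep_embedding(timesteps, embedding_dim=32)
assert sinusoidal.shape == (3, 32)

# Gaussian Fourier features: the random projection is doubled by sin/cos, giving (N, 2 * embedding_size).
fourier = GaussianFourierProjection(embedding_size=16)(timesteps)
assert fourier.shape == (3, 32)
```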
--------------------------------------------------------------------------------
/src/diffusers/models/unet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # helper functions
17 |
18 | import torch
19 | from torch import nn
20 |
21 | from ..configuration_utils import ConfigMixin
22 | from ..modeling_utils import ModelMixin
23 | from .attention import AttentionBlock
24 | from .embeddings import get_timestep_embedding
25 | from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
26 | from .unet_new import UNetMidBlock2D
27 |
28 |
29 | def nonlinearity(x):
30 | # swish
31 | return x * torch.sigmoid(x)
32 |
33 |
34 | def Normalize(in_channels):
35 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36 |
37 |
38 | class UNetModel(ModelMixin, ConfigMixin):
39 | def __init__(
40 | self,
41 | ch=128,
42 | out_ch=3,
43 | ch_mult=(1, 1, 2, 2, 4, 4),
44 | num_res_blocks=2,
45 | attn_resolutions=(16,),
46 | dropout=0.0,
47 | resamp_with_conv=True,
48 | in_channels=3,
49 | resolution=256,
50 | ):
51 | super().__init__()
52 | self.register_to_config(
53 | ch=ch,
54 | out_ch=out_ch,
55 | ch_mult=ch_mult,
56 | num_res_blocks=num_res_blocks,
57 | attn_resolutions=attn_resolutions,
58 | dropout=dropout,
59 | resamp_with_conv=resamp_with_conv,
60 | in_channels=in_channels,
61 | resolution=resolution,
62 | )
63 | ch_mult = tuple(ch_mult)
64 | self.ch = ch
65 | self.temb_ch = self.ch * 4
66 | self.num_resolutions = len(ch_mult)
67 | self.num_res_blocks = num_res_blocks
68 | self.resolution = resolution
69 | self.in_channels = in_channels
70 |
71 | # timestep embedding
72 | self.temb = nn.Module()
73 | self.temb.dense = nn.ModuleList(
74 | [
75 | torch.nn.Linear(self.ch, self.temb_ch),
76 | torch.nn.Linear(self.temb_ch, self.temb_ch),
77 | ]
78 | )
79 |
80 | # downsampling
81 | self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
82 |
83 | curr_res = resolution
84 | in_ch_mult = (1,) + ch_mult
85 | self.down = nn.ModuleList()
86 | for i_level in range(self.num_resolutions):
87 | block = nn.ModuleList()
88 | attn = nn.ModuleList()
89 | block_in = ch * in_ch_mult[i_level]
90 | block_out = ch * ch_mult[i_level]
91 | for i_block in range(self.num_res_blocks):
92 | block.append(
93 | ResnetBlock2D(
94 | in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
95 | )
96 | )
97 | block_in = block_out
98 | if curr_res in attn_resolutions:
99 | attn.append(AttentionBlock(block_in, overwrite_qkv=True))
100 | down = nn.Module()
101 | down.block = block
102 | down.attn = attn
103 | if i_level != self.num_resolutions - 1:
104 | down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
105 | curr_res = curr_res // 2
106 | self.down.append(down)
107 |
108 | # middle
109 | self.mid = nn.Module()
110 | self.mid.block_1 = ResnetBlock2D(
111 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
112 | )
113 | self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
114 | self.mid.block_2 = ResnetBlock2D(
115 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
116 | )
117 | self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
118 | self.mid_new.resnets[0] = self.mid.block_1
119 | self.mid_new.attentions[0] = self.mid.attn_1
120 | self.mid_new.resnets[1] = self.mid.block_2
121 |
122 | # upsampling
123 | self.up = nn.ModuleList()
124 | for i_level in reversed(range(self.num_resolutions)):
125 | block = nn.ModuleList()
126 | attn = nn.ModuleList()
127 | block_out = ch * ch_mult[i_level]
128 | skip_in = ch * ch_mult[i_level]
129 | for i_block in range(self.num_res_blocks + 1):
130 | if i_block == self.num_res_blocks:
131 | skip_in = ch * in_ch_mult[i_level]
132 | block.append(
133 | ResnetBlock2D(
134 | in_channels=block_in + skip_in,
135 | out_channels=block_out,
136 | temb_channels=self.temb_ch,
137 | dropout=dropout,
138 | )
139 | )
140 | block_in = block_out
141 | if curr_res in attn_resolutions:
142 | attn.append(AttentionBlock(block_in, overwrite_qkv=True))
143 | up = nn.Module()
144 | up.block = block
145 | up.attn = attn
146 | if i_level != 0:
147 | up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
148 | curr_res = curr_res * 2
149 | self.up.insert(0, up) # prepend to get consistent order
150 |
151 | # end
152 | self.norm_out = Normalize(block_in)
153 | self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
154 |
155 | def forward(self, sample, timesteps):
156 | x = sample
157 | assert x.shape[2] == x.shape[3] == self.resolution
158 |
159 | if not torch.is_tensor(timesteps):
160 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
161 |
162 | # timestep embedding
163 | temb = get_timestep_embedding(timesteps, self.ch)
164 | temb = self.temb.dense[0](temb)
165 | temb = nonlinearity(temb)
166 | temb = self.temb.dense[1](temb)
167 |
168 | # downsampling
169 | hs = [self.conv_in(x)]
170 | for i_level in range(self.num_resolutions):
171 | for i_block in range(self.num_res_blocks):
172 | h = self.down[i_level].block[i_block](hs[-1], temb)
173 | if len(self.down[i_level].attn) > 0:
174 | h = self.down[i_level].attn[i_block](h)
175 | hs.append(h)
176 | if i_level != self.num_resolutions - 1:
177 | hs.append(self.down[i_level].downsample(hs[-1]))
178 |
179 | # middle
180 | h = self.mid_new(hs[-1], temb)
181 |
182 | # upsampling
183 | for i_level in reversed(range(self.num_resolutions)):
184 | for i_block in range(self.num_res_blocks + 1):
185 | h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
186 | if len(self.up[i_level].attn) > 0:
187 | h = self.up[i_level].attn[i_block](h)
188 | if i_level != 0:
189 | h = self.up[i_level].upsample(h)
190 |
191 | # end
192 | h = self.norm_out(h)
193 | h = nonlinearity(h)
194 | h = self.conv_out(h)
195 | return h
196 |
--------------------------------------------------------------------------------
/src/diffusers/models/unet_grad_tts.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..configuration_utils import ConfigMixin
4 | from ..modeling_utils import ModelMixin
5 | from .attention import LinearAttention
6 | from .embeddings import get_timestep_embedding
7 | from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
8 | from .unet_new import UNetMidBlock2D
9 |
10 |
11 | class Mish(torch.nn.Module):
12 | def forward(self, x):
13 | return x * torch.tanh(torch.nn.functional.softplus(x))
14 |
15 |
16 | class Rezero(torch.nn.Module):
17 | def __init__(self, fn):
18 | super(Rezero, self).__init__()
19 | self.fn = fn
20 | self.g = torch.nn.Parameter(torch.zeros(1))
21 |
22 | def forward(self, x, encoder_out=None):
23 | return self.fn(x, encoder_out) * self.g
24 |
25 |
26 | class Block(torch.nn.Module):
27 | def __init__(self, dim, dim_out, groups=8):
28 | super(Block, self).__init__()
29 | self.block = torch.nn.Sequential(
30 | torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
31 | )
32 |
33 | def forward(self, x, mask):
34 | output = self.block(x * mask)
35 | return output * mask
36 |
37 |
38 | class Residual(torch.nn.Module):
39 | def __init__(self, fn):
40 | super(Residual, self).__init__()
41 | self.fn = fn
42 |
43 | def forward(self, x, *args, **kwargs):
44 | output = self.fn(x, *args, **kwargs) + x
45 | return output
46 |
47 |
48 | class UNetGradTTSModel(ModelMixin, ConfigMixin):
49 | def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
50 | super(UNetGradTTSModel, self).__init__()
51 |
52 | self.register_to_config(
53 | dim=dim,
54 | dim_mults=dim_mults,
55 | groups=groups,
56 | n_spks=n_spks,
57 | spk_emb_dim=spk_emb_dim,
58 | n_feats=n_feats,
59 | pe_scale=pe_scale,
60 | )
61 |
62 | self.dim = dim
63 | self.dim_mults = dim_mults
64 | self.groups = groups
65 | self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
66 | self.spk_emb_dim = spk_emb_dim
67 | self.pe_scale = pe_scale
68 |
69 | if self.n_spks > 1:
70 | self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
71 | self.spk_mlp = torch.nn.Sequential(
72 | torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
73 | )
74 |
75 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
76 |
77 | dims = [2 + (1 if self.n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
78 | in_out = list(zip(dims[:-1], dims[1:]))
79 | self.downs = torch.nn.ModuleList([])
80 | self.ups = torch.nn.ModuleList([])
81 | num_resolutions = len(in_out)
82 |
83 | for ind, (dim_in, dim_out) in enumerate(in_out):
84 | is_last = ind >= (num_resolutions - 1)
85 | self.downs.append(
86 | torch.nn.ModuleList(
87 | [
88 | ResnetBlock2D(
89 | in_channels=dim_in,
90 | out_channels=dim_out,
91 | temb_channels=dim,
92 | groups=8,
93 | pre_norm=False,
94 | eps=1e-5,
95 | non_linearity="mish",
96 | overwrite_for_grad_tts=True,
97 | ),
98 | ResnetBlock2D(
99 | in_channels=dim_out,
100 | out_channels=dim_out,
101 | temb_channels=dim,
102 | groups=8,
103 | pre_norm=False,
104 | eps=1e-5,
105 | non_linearity="mish",
106 | overwrite_for_grad_tts=True,
107 | ),
108 | Residual(Rezero(LinearAttention(dim_out))),
109 | Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
110 | ]
111 | )
112 | )
113 |
114 | mid_dim = dims[-1]
115 |
116 | self.mid = UNetMidBlock2D(
117 | in_channels=mid_dim,
118 | temb_channels=dim,
119 | resnet_groups=8,
120 | resnet_pre_norm=False,
121 | resnet_eps=1e-5,
122 | resnet_act_fn="mish",
123 | attention_layer_type="linear",
124 | )
125 |
126 | self.mid_block1 = ResnetBlock2D(
127 | in_channels=mid_dim,
128 | out_channels=mid_dim,
129 | temb_channels=dim,
130 | groups=8,
131 | pre_norm=False,
132 | eps=1e-5,
133 | non_linearity="mish",
134 | overwrite_for_grad_tts=True,
135 | )
136 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
137 | self.mid_block2 = ResnetBlock2D(
138 | in_channels=mid_dim,
139 | out_channels=mid_dim,
140 | temb_channels=dim,
141 | groups=8,
142 | pre_norm=False,
143 | eps=1e-5,
144 | non_linearity="mish",
145 | overwrite_for_grad_tts=True,
146 | )
147 | self.mid.resnets[0] = self.mid_block1
148 | self.mid.attentions[0] = self.mid_attn
149 | self.mid.resnets[1] = self.mid_block2
150 |
151 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
152 | self.ups.append(
153 | torch.nn.ModuleList(
154 | [
155 | ResnetBlock2D(
156 | in_channels=dim_out * 2,
157 | out_channels=dim_in,
158 | temb_channels=dim,
159 | groups=8,
160 | pre_norm=False,
161 | eps=1e-5,
162 | non_linearity="mish",
163 | overwrite_for_grad_tts=True,
164 | ),
165 | ResnetBlock2D(
166 | in_channels=dim_in,
167 | out_channels=dim_in,
168 | temb_channels=dim,
169 | groups=8,
170 | pre_norm=False,
171 | eps=1e-5,
172 | non_linearity="mish",
173 | overwrite_for_grad_tts=True,
174 | ),
175 | Residual(Rezero(LinearAttention(dim_in))),
176 | Upsample2D(dim_in, use_conv_transpose=True),
177 | ]
178 | )
179 | )
180 | self.final_block = Block(dim, dim)
181 | self.final_conv = torch.nn.Conv2d(dim, 1, 1)
182 |
183 | def forward(self, sample, timesteps, mu, mask, spk=None):
184 | x = sample
185 | if self.n_spks > 1:
186 | # Get speaker embedding
187 | spk = self.spk_emb(spk)
188 |
189 | if not isinstance(spk, type(None)):
190 | s = self.spk_mlp(spk)
191 |
192 | t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
193 | t = self.mlp(t)
194 |
195 | if self.n_spks < 2:
196 | x = torch.stack([mu, x], 1)
197 | else:
198 | s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
199 | x = torch.stack([mu, x, s], 1)
200 | mask = mask.unsqueeze(1)
201 |
202 | hiddens = []
203 | masks = [mask]
204 | for resnet1, resnet2, attn, downsample in self.downs:
205 | mask_down = masks[-1]
206 | x = resnet1(x, t, mask_down)
207 | x = resnet2(x, t, mask_down)
208 | x = attn(x)
209 | hiddens.append(x)
210 | x = downsample(x * mask_down)
211 | masks.append(mask_down[:, :, :, ::2])
212 |
213 | masks = masks[:-1]
214 | mask_mid = masks[-1]
215 |
216 | x = self.mid(x, t, mask=mask_mid)
217 |
218 | for resnet1, resnet2, attn, upsample in self.ups:
219 | mask_up = masks.pop()
220 | x = torch.cat((x, hiddens.pop()), dim=1)
221 | x = resnet1(x, t, mask_up)
222 | x = resnet2(x, t, mask_up)
223 | x = attn(x)
224 | x = upsample(x * mask_up)
225 |
226 | x = self.final_block(x, mask)
227 | output = self.final_conv(x * mask)
228 |
229 | return (output * mask).squeeze(1)
230 |
--------------------------------------------------------------------------------
/src/diffusers/models/unet_rl.py:
--------------------------------------------------------------------------------
1 | # model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..configuration_utils import ConfigMixin
7 | from ..modeling_utils import ModelMixin
8 | from .embeddings import get_timestep_embedding
9 | from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
10 |
11 |
12 | class SinusoidalPosEmb(nn.Module):
13 | def __init__(self, dim):
14 | super().__init__()
15 | self.dim = dim
16 |
17 | def forward(self, x):
18 | return get_timestep_embedding(x, self.dim)
19 |
20 |
21 | class RearrangeDim(nn.Module):
22 | def __init__(self):
23 | super().__init__()
24 |
25 | def forward(self, tensor):
26 | if len(tensor.shape) == 2:
27 | return tensor[:, :, None]
28 | if len(tensor.shape) == 3:
29 | return tensor[:, :, None, :]
30 | elif len(tensor.shape) == 4:
31 | return tensor[:, :, 0, :]
32 | else:
33 | raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
34 |
35 |
36 | class Conv1dBlock(nn.Module):
37 | """
38 | Conv1d --> GroupNorm --> Mish
39 | """
40 |
41 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
42 | super().__init__()
43 |
44 | self.block = nn.Sequential(
45 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
46 | RearrangeDim(),
47 | # Rearrange("batch channels horizon -> batch channels 1 horizon"),
48 | nn.GroupNorm(n_groups, out_channels),
49 | RearrangeDim(),
50 | # Rearrange("batch channels 1 horizon -> batch channels horizon"),
51 | nn.Mish(),
52 | )
53 |
54 | def forward(self, x):
55 | return self.block(x)
56 |
57 |
58 | class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
59 | def __init__(
60 | self,
61 | training_horizon=128,
62 | transition_dim=14,
63 | cond_dim=3,
64 | predict_epsilon=False,
65 | clip_denoised=True,
66 | dim=32,
67 | dim_mults=(1, 4, 8),
68 | ):
69 | super().__init__()
70 |
71 | self.transition_dim = transition_dim
72 | self.cond_dim = cond_dim
73 | self.predict_epsilon = predict_epsilon
74 | self.clip_denoised = clip_denoised
75 |
76 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
77 | in_out = list(zip(dims[:-1], dims[1:]))
78 |
79 | time_dim = dim
80 | self.time_mlp = nn.Sequential(
81 | SinusoidalPosEmb(dim),
82 | nn.Linear(dim, dim * 4),
83 | nn.Mish(),
84 | nn.Linear(dim * 4, dim),
85 | )
86 |
87 | self.downs = nn.ModuleList([])
88 | self.ups = nn.ModuleList([])
89 | num_resolutions = len(in_out)
90 |
91 | for ind, (dim_in, dim_out) in enumerate(in_out):
92 | is_last = ind >= (num_resolutions - 1)
93 |
94 | self.downs.append(
95 | nn.ModuleList(
96 | [
97 | ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
98 | ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
99 | Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
100 | ]
101 | )
102 | )
103 |
104 | if not is_last:
105 | training_horizon = training_horizon // 2
106 |
107 | mid_dim = dims[-1]
108 | self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
109 | self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
110 |
111 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
112 | is_last = ind >= (num_resolutions - 1)
113 |
114 | self.ups.append(
115 | nn.ModuleList(
116 | [
117 | ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
118 | ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
119 | Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
120 | ]
121 | )
122 | )
123 |
124 | if not is_last:
125 | training_horizon = training_horizon * 2
126 |
127 | self.final_conv = nn.Sequential(
128 | Conv1dBlock(dim, dim, kernel_size=5),
129 | nn.Conv1d(dim, transition_dim, 1),
130 | )
131 |
132 | def forward(self, sample, timesteps):
133 | """
134 | x : [ batch x horizon x transition ]
135 | """
136 | x = sample
137 |
138 | x = x.permute(0, 2, 1)
139 |
140 | t = self.time_mlp(timesteps)
141 | h = []
142 |
143 | for resnet, resnet2, downsample in self.downs:
144 | x = resnet(x, t)
145 | x = resnet2(x, t)
146 | h.append(x)
147 | x = downsample(x)
148 |
149 | x = self.mid_block1(x, t)
150 | x = self.mid_block2(x, t)
151 |
152 | for resnet, resnet2, upsample in self.ups:
153 | x = torch.cat((x, h.pop()), dim=1)
154 | x = resnet(x, t)
155 | x = resnet2(x, t)
156 | x = upsample(x)
157 |
158 | x = self.final_conv(x)
159 |
160 | x = x.permute(0, 2, 1)
161 | return x
162 |
163 |
164 | class TemporalValue(nn.Module):
165 | def __init__(
166 | self,
167 | horizon,
168 | transition_dim,
169 | cond_dim,
170 | dim=32,
171 | time_dim=None,
172 | out_dim=1,
173 | dim_mults=(1, 2, 4, 8),
174 | ):
175 | super().__init__()
176 |
177 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
178 | in_out = list(zip(dims[:-1], dims[1:]))
179 |
180 | time_dim = time_dim or dim
181 | self.time_mlp = nn.Sequential(
182 | SinusoidalPosEmb(dim),
183 | nn.Linear(dim, dim * 4),
184 | nn.Mish(),
185 | nn.Linear(dim * 4, dim),
186 | )
187 |
188 | self.blocks = nn.ModuleList([])
189 |
190 | print(in_out)
191 | for dim_in, dim_out in in_out:
192 | self.blocks.append(
193 | nn.ModuleList(
194 | [
195 | ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
196 | ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
197 | Downsample1D(dim_out, use_conv=True),
198 | ]
199 | )
200 | )
201 |
202 | horizon = horizon // 2
203 |
204 | fc_dim = dims[-1] * max(horizon, 1)
205 |
206 | self.final_block = nn.Sequential(
207 | nn.Linear(fc_dim + time_dim, fc_dim // 2),
208 | nn.Mish(),
209 | nn.Linear(fc_dim // 2, out_dim),
210 | )
211 |
212 | def forward(self, x, cond, time, *args):
213 | """
214 | x : [ batch x horizon x transition ]
215 | """
216 | x = x.permute(0, 2, 1)
217 |
218 | t = self.time_mlp(time)
219 |
220 | for resnet, resnet2, downsample in self.blocks:
221 | x = resnet(x, t)
222 | x = resnet2(x, t)
223 | x = downsample(x)
224 |
225 | x = x.view(len(x), -1)
226 | out = self.final_block(torch.cat([x, t], dim=-1))
227 | return out
228 |
--------------------------------------------------------------------------------
/src/diffusers/pipeline_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The HuggingFace Inc. team.
3 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import importlib
18 | import os
19 | from typing import Optional, Union
20 |
21 | from huggingface_hub import snapshot_download
22 |
23 | from .configuration_utils import ConfigMixin
24 | from .utils import DIFFUSERS_CACHE, logging
25 |
26 |
27 | INDEX_FILE = "diffusion_model.pt"
28 |
29 |
30 | logger = logging.get_logger(__name__)
31 |
32 |
33 | LOADABLE_CLASSES = {
34 | "diffusers": {
35 | "ModelMixin": ["save_pretrained", "from_pretrained"],
36 | "SchedulerMixin": ["save_config", "from_config"],
37 | "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
38 | },
39 | "transformers": {
40 | "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
41 | "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
42 | "PreTrainedModel": ["save_pretrained", "from_pretrained"],
43 | },
44 | }
45 |
46 | ALL_IMPORTABLE_CLASSES = {}
47 | for library in LOADABLE_CLASSES:
48 | ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
49 |
50 |
51 | class DiffusionPipeline(ConfigMixin):
52 |
53 | config_name = "model_index.json"
54 |
55 | def register_modules(self, **kwargs):
56 | # import it here to avoid circular import
57 | from diffusers import pipelines
58 |
59 | for name, module in kwargs.items():
60 | # retrieve library
61 | library = module.__module__.split(".")[0]
62 |
63 | # check if the module is a pipeline module
64 | pipeline_file = module.__module__.split(".")[-1]
65 | pipeline_dir = module.__module__.split(".")[-2]
66 | is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir)
67 |
68 | # if library is not in LOADABLE_CLASSES, then it is a custom module.
69 | # Or if it's a pipeline module, then the module is inside the pipeline
70 | # folder so we set the library to module name.
71 | if library not in LOADABLE_CLASSES or is_pipeline_module:
72 | library = pipeline_dir
73 |
74 | # retrieve class_name
75 | class_name = module.__class__.__name__
76 |
77 | register_dict = {name: (library, class_name)}
78 |
79 | # save model index config
80 | self.register_to_config(**register_dict)
81 |
82 | # set models
83 | setattr(self, name, module)
84 |
85 | def save_pretrained(self, save_directory: Union[str, os.PathLike]):
86 | self.save_config(save_directory)
87 |
88 | model_index_dict = dict(self.config)
89 | model_index_dict.pop("_class_name")
90 | model_index_dict.pop("_diffusers_version")
91 | model_index_dict.pop("_module", None)
92 |
93 | for pipeline_component_name in model_index_dict.keys():
94 | sub_model = getattr(self, pipeline_component_name)
95 | model_cls = sub_model.__class__
96 |
97 | save_method_name = None
98 | # search for the model's base class in LOADABLE_CLASSES
99 | for library_name, library_classes in LOADABLE_CLASSES.items():
100 | library = importlib.import_module(library_name)
101 | for base_class, save_load_methods in library_classes.items():
102 | class_candidate = getattr(library, base_class)
103 | if issubclass(model_cls, class_candidate):
104 | # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
105 | save_method_name = save_load_methods[0]
106 | break
107 | if save_method_name is not None:
108 | break
109 |
110 | save_method = getattr(sub_model, save_method_name)
111 | save_method(os.path.join(save_directory, pipeline_component_name))
112 |
113 | @classmethod
114 | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
115 | r"""
116 | Instantiate a `DiffusionPipeline` from a local folder or a model repo on the Hugging Face Hub, downloading it if necessary and loading all of its registered modules.
117 | """
118 | cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
119 | resume_download = kwargs.pop("resume_download", False)
120 | proxies = kwargs.pop("proxies", None)
121 | local_files_only = kwargs.pop("local_files_only", False)
122 | use_auth_token = kwargs.pop("use_auth_token", None)
123 |
124 | # 1. Download the checkpoints and configs
125 | # use snapshot_download here so that from_pretrained can also work with Hub repo ids
126 | if not os.path.isdir(pretrained_model_name_or_path):
127 | cached_folder = snapshot_download(
128 | pretrained_model_name_or_path,
129 | cache_dir=cache_dir,
130 | resume_download=resume_download,
131 | proxies=proxies,
132 | local_files_only=local_files_only,
133 | use_auth_token=use_auth_token,
134 | )
135 | else:
136 | cached_folder = pretrained_model_name_or_path
137 |
138 | config_dict = cls.get_config_dict(cached_folder)
139 |
140 | # 2. Load the pipeline class, if using custom module then load it from the hub
141 | # if we load from explicit class, let's use it
142 | if cls != DiffusionPipeline:
143 | pipeline_class = cls
144 | else:
145 | diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
146 | pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
147 |
148 | init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
149 |
150 | init_kwargs = {}
151 |
152 | # import it here to avoid circular import
153 | from diffusers import pipelines
154 |
155 | # 3. Load each module in the pipeline
156 | for name, (library_name, class_name) in init_dict.items():
157 | is_pipeline_module = hasattr(pipelines, library_name)
158 | # if the model is in a pipeline module, then we load it from the pipeline
159 | if is_pipeline_module:
160 | pipeline_module = getattr(pipelines, library_name)
161 | class_obj = getattr(pipeline_module, class_name)
162 | importable_classes = ALL_IMPORTABLE_CLASSES
163 | class_candidates = {c: class_obj for c in importable_classes.keys()}
164 | else:
165 | # else we just import it from the library.
166 | library = importlib.import_module(library_name)
167 | class_obj = getattr(library, class_name)
168 | importable_classes = LOADABLE_CLASSES[library_name]
169 | class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
170 |
171 | load_method_name = None
172 | for class_name, class_candidate in class_candidates.items():
173 | if issubclass(class_obj, class_candidate):
174 | load_method_name = importable_classes[class_name][1]
175 |
176 | load_method = getattr(class_obj, load_method_name)
177 |
178 | # check if the module is in a subdirectory
179 | if os.path.isdir(os.path.join(cached_folder, name)):
180 | loaded_sub_model = load_method(os.path.join(cached_folder, name))
181 | else:
182 | # else load from the root directory
183 | loaded_sub_model = load_method(cached_folder)
184 |
185 | init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
186 |
187 | # 4. Instantiate the pipeline
188 | model = pipeline_class(**init_kwargs)
189 | return model
190 |
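A hedged sketch of the save/load round trip implemented above (the path below is a hypothetical local folder containing a previously saved pipeline, i.e. a `model_index.json` plus one subfolder per registered module):

```python
from diffusers import DiffusionPipeline

# Reads model_index.json, resolves the class recorded under "_class_name",
# and loads every registered module with its own from_pretrained/from_config.
pipeline = DiffusionPipeline.from_pretrained("./path/to/a-saved-pipeline")

# Writes the config and all modules back out in the same layout.
pipeline.save_pretrained("./my-pipeline-copy")
```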
--------------------------------------------------------------------------------
/src/diffusers/pipelines/README.md:
--------------------------------------------------------------------------------
1 | # Pipelines
2 |
3 | - Pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box
4 | - Pipelines should stay as close as possible to their original implementation
5 | - Pipelines can include components of other libraries, such as text encoders.
6 |
7 | ## API
8 |
9 | TODO(Patrick, Anton, Suraj)
10 |
11 | ## Examples
12 |
13 | - DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py).
14 | - DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py).
15 | - PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py).
16 | - Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py).
17 | - Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/glide/pipeline_glide.py).
18 | - BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/bddm/pipeline_bddm.py).
19 | - Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/grad_tts/pipeline_grad_tts.py).
20 |
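As a rough usage sketch (the checkpoint path below is a placeholder, not a published model):

```python
import torch

from diffusers import DDPMPipeline

# Load a saved unconditional DDPM pipeline (a UNet plus its noise scheduler) from a local folder or the Hub.
ddpm = DDPMPipeline.from_pretrained("path/to/a-saved-ddpm-checkpoint")

# Sampling: the pipeline starts from Gaussian noise and denoises it step by step.
generator = torch.manual_seed(0)
images = ddpm(batch_size=1, generator=generator)
```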
--------------------------------------------------------------------------------
/src/diffusers/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
2 | from .bddm import BDDMPipeline
3 | from .ddim import DDIMPipeline
4 | from .ddpm import DDPMPipeline
5 | from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
6 | from .pndm import PNDMPipeline
7 | from .score_sde_ve import ScoreSdeVePipeline
8 | from .score_sde_vp import ScoreSdeVpPipeline
9 |
10 |
11 | if is_transformers_available():
12 | from .glide import GlidePipeline
13 | from .latent_diffusion import LatentDiffusionPipeline
14 |
15 |
16 | if is_transformers_available() and is_unidecode_available() and is_inflect_available():
17 | from .grad_tts import GradTTSPipeline
18 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/bddm/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_bddm import BDDMPipeline, DiffWave
2 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/ddim/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_ddim import DDIMPipeline
2 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/ddim/pipeline_ddim.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 |
17 | import torch
18 |
19 | import tqdm
20 |
21 | from ...pipeline_utils import DiffusionPipeline
22 |
23 |
24 | class DDIMPipeline(DiffusionPipeline):
25 | def __init__(self, unet, noise_scheduler):
26 | super().__init__()
27 | noise_scheduler = noise_scheduler.set_format("pt")
28 | self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
29 |
30 | def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
31 | # eta corresponds to η in paper and should be between [0, 1]
32 | if torch_device is None:
33 | torch_device = "cuda" if torch.cuda.is_available() else "cpu"
34 |
35 | num_trained_timesteps = self.noise_scheduler.config.timesteps
36 | inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
37 |
38 | self.unet.to(torch_device)
39 |
40 | # Sample gaussian noise to begin loop
41 | image = torch.randn(
42 | (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
43 | generator=generator,
44 | )
45 | image = image.to(torch_device)
46 |
47 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
48 | # Ideally, read the DDIM paper for an in-depth understanding
49 |
50 | # Notation (<variable name> -> <name in paper>)
51 | # - pred_noise_t -> e_theta(x_t, t)
52 | # - pred_original_image -> f_theta(x_t, t) or x_0
53 | # - std_dev_t -> sigma_t
54 | # - eta -> η
55 | # - pred_image_direction -> "direction pointing to x_t"
56 | # - pred_prev_image -> "x_t-1"
57 | for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
58 | # 1. predict noise residual
59 | with torch.no_grad():
60 | residual = self.unet(image, inference_step_times[t])
61 |
62 | # 2. predict previous mean of image x_t-1
63 | pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
64 |
65 | # 3. optionally sample variance
66 | variance = 0
67 | if eta > 0:
68 | noise = torch.randn(image.shape, generator=generator).to(image.device)
69 | variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
70 |
71 | # 4. set current image to prev_image: x_t -> x_t-1
72 | image = pred_prev_image + variance
73 |
74 | return image
75 |
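For reference, the notation block in the sampling loop above maps onto DDIM Eq. (12); one step forms (a summary of the paper's formula, not a line-by-line description of `scheduling_ddim.py`)

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 \;+\; \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t, t) \;+\; \sigma_t z,\qquad \sigma_t = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}},$$

where $\hat{x}_0$ is the predicted original image, $\epsilon_\theta$ the predicted noise residual, and $z$ fresh Gaussian noise; the scheduler's `step` returns the deterministic part, and the loop above adds the $\sigma_t z$ term only when $\eta > 0$.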
--------------------------------------------------------------------------------
/src/diffusers/pipelines/ddpm/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_ddpm import DDPMPipeline
2 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/ddpm/pipeline_ddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 |
17 | import torch
18 |
19 | import tqdm
20 |
21 | from ...pipeline_utils import DiffusionPipeline
22 |
23 |
24 | class DDPMPipeline(DiffusionPipeline):
25 | def __init__(self, unet, noise_scheduler):
26 | super().__init__()
27 | noise_scheduler = noise_scheduler.set_format("pt")
28 | self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
29 |
30 | def __call__(self, batch_size=1, generator=None, torch_device=None):
31 | if torch_device is None:
32 | torch_device = "cuda" if torch.cuda.is_available() else "cpu"
33 |
34 | self.unet.to(torch_device)
35 |
36 | # Sample gaussian noise to begin loop
37 | image = torch.randn(
38 | (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
39 | generator=generator,
40 | )
41 | image = image.to(torch_device)
42 |
43 | num_prediction_steps = len(self.noise_scheduler)
44 | for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
45 | # 1. predict noise residual
46 | with torch.no_grad():
47 | residual = self.unet(image, t)
48 |
49 | # 2. predict previous mean of image x_t-1
50 | pred_prev_image = self.noise_scheduler.step(residual, image, t)
51 |
52 | # 3. optionally sample variance
53 | variance = 0
54 | if t > 0:
55 | noise = torch.randn(image.shape, generator=generator).to(image.device)
56 | variance = self.noise_scheduler.get_variance(t).sqrt() * noise
57 |
58 | # 4. set current image to prev_image: x_t -> x_t-1
59 | image = pred_prev_image + variance
60 |
61 | return image
62 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/glide/__init__.py:
--------------------------------------------------------------------------------
1 | from ...utils import is_transformers_available
2 |
3 |
4 | if is_transformers_available():
5 | from .pipeline_glide import CLIPTextModel, GlidePipeline
6 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/grad_tts/__init__.py:
--------------------------------------------------------------------------------
1 | from ...utils import is_inflect_available, is_transformers_available, is_unidecode_available
2 |
3 |
4 | if is_transformers_available() and is_unidecode_available() and is_inflect_available():
5 | from .grad_tts_utils import GradTTSTokenizer
6 | from .pipeline_grad_tts import GradTTSPipeline, TextEncoder
7 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/latent_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from ...utils import is_transformers_available
2 |
3 |
4 | if is_transformers_available():
5 | from .pipeline_latent_diffusion import LatentDiffusionPipeline, LDMBertModel
6 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/latent_diffusion_uncond/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
2 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import tqdm
4 |
5 | from ...pipeline_utils import DiffusionPipeline
6 |
7 |
8 | class LatentDiffusionUncondPipeline(DiffusionPipeline):
9 | def __init__(self, vqvae, unet, noise_scheduler):
10 | super().__init__()
11 | noise_scheduler = noise_scheduler.set_format("pt")
12 | self.register_modules(vqvae=vqvae, unet=unet, noise_scheduler=noise_scheduler)
13 |
14 | @torch.no_grad()
15 | def __call__(
16 | self,
17 | batch_size=1,
18 | generator=None,
19 | torch_device=None,
20 | eta=0.0,
21 | num_inference_steps=50,
22 | ):
23 | # eta corresponds to η in paper and should be between [0, 1]
24 |
25 | if torch_device is None:
26 | torch_device = "cuda" if torch.cuda.is_available() else "cpu"
27 |
28 | self.unet.to(torch_device)
29 | self.vqvae.to(torch_device)
30 |
31 | num_trained_timesteps = self.noise_scheduler.config.timesteps
32 | inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
33 |
34 | image = torch.randn(
35 | (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
36 | generator=generator,
37 | ).to(torch_device)
38 |
39 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
40 |         # Ideally, read the DDIM paper in detail to understand the following
41 |
42 |         # Notation (<variable name> -> <name in paper>)
43 | # - pred_noise_t -> e_theta(x_t, t)
44 | # - pred_original_image -> f_theta(x_t, t) or x_0
45 | # - std_dev_t -> sigma_t
46 | # - eta -> η
47 |         # - pred_image_direction -> "direction pointing to x_t"
48 | # - pred_prev_image -> "x_t-1"
49 | for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
50 | # 1. predict noise residual
51 | timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
52 | pred_noise_t = self.unet(image, timesteps)
53 |
54 | # 2. predict previous mean of image x_t-1
55 | pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
56 |
57 | # 3. optionally sample variance
58 | variance = 0
59 | if eta > 0:
60 | noise = torch.randn(image.shape, generator=generator).to(image.device)
61 | variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
62 |
63 | # 4. set current image to prev_image: x_t -> x_t-1
64 | image = pred_prev_image + variance
65 |
66 | # decode image with vae
67 | image = self.vqvae.decode(image)
68 | return image
69 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/pndm/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_pndm import PNDMPipeline
2 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/pndm/pipeline_pndm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 |
17 | import torch
18 |
19 | import tqdm
20 |
21 | from ...pipeline_utils import DiffusionPipeline
22 |
23 |
24 | class PNDMPipeline(DiffusionPipeline):
25 | def __init__(self, unet, noise_scheduler):
26 | super().__init__()
27 | noise_scheduler = noise_scheduler.set_format("pt")
28 | self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
29 |
30 | def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
31 | # For more information on the sampling method you can take a look at Algorithm 2 of
32 | # the official paper: https://arxiv.org/pdf/2202.09778.pdf
33 | if torch_device is None:
34 | torch_device = "cuda" if torch.cuda.is_available() else "cpu"
35 |
36 | self.unet.to(torch_device)
37 |
38 |         # Sample Gaussian noise to begin the loop
39 | image = torch.randn(
40 | (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
41 | generator=generator,
42 | )
43 | image = image.to(torch_device)
44 |
45 | prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
46 | for t in tqdm.tqdm(range(len(prk_time_steps))):
47 | t_orig = prk_time_steps[t]
48 | residual = self.unet(image, t_orig)
49 |
50 | image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
51 |
52 | timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
53 | for t in tqdm.tqdm(range(len(timesteps))):
54 | t_orig = timesteps[t]
55 | residual = self.unet(image, t_orig)
56 |
57 | image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
58 |
59 | return image
60 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/score_sde_ve/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_score_sde_ve import ScoreSdeVePipeline
2 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import torch
3 |
4 | from diffusers import DiffusionPipeline
5 |
6 |
7 | # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
8 | class ScoreSdeVePipeline(DiffusionPipeline):
9 | def __init__(self, model, scheduler):
10 | super().__init__()
11 | self.register_modules(model=model, scheduler=scheduler)
12 |
13 | def __call__(self, num_inference_steps=2000, generator=None):
14 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15 |
16 | img_size = self.model.config.image_size
17 | channels = self.model.config.num_channels
18 | shape = (1, channels, img_size, img_size)
19 |
20 | model = self.model.to(device)
21 |
22 | # TODO(Patrick) move to scheduler config
23 | n_steps = 1
24 |
25 | x = torch.randn(*shape) * self.scheduler.config.sigma_max
26 | x = x.to(device)
27 |
28 | self.scheduler.set_timesteps(num_inference_steps)
29 | self.scheduler.set_sigmas(num_inference_steps)
30 |
31 | for i, t in enumerate(self.scheduler.timesteps):
32 | sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
33 |
34 | for _ in range(n_steps):
35 | with torch.no_grad():
36 | result = self.model(x, sigma_t)
37 | x = self.scheduler.step_correct(result, x)
38 |
39 | with torch.no_grad():
40 | result = model(x, sigma_t)
41 |
42 | x, x_mean = self.scheduler.step_pred(result, x, t)
43 |
44 | return x_mean
45 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/score_sde_vp/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_score_sde_vp import ScoreSdeVpPipeline
2 |
--------------------------------------------------------------------------------
/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import torch
3 |
4 | from diffusers import DiffusionPipeline
5 |
6 |
7 | # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
8 | class ScoreSdeVpPipeline(DiffusionPipeline):
9 | def __init__(self, model, scheduler):
10 | super().__init__()
11 | self.register_modules(model=model, scheduler=scheduler)
12 |
13 | def __call__(self, num_inference_steps=1000, generator=None):
14 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15 |
16 | img_size = self.model.config.image_size
17 | channels = self.model.config.num_channels
18 | shape = (1, channels, img_size, img_size)
19 |
20 | model = self.model.to(device)
21 |
22 | x = torch.randn(*shape).to(device)
23 |
24 | self.scheduler.set_timesteps(num_inference_steps)
25 |
26 | for t in self.scheduler.timesteps:
27 | t = t * torch.ones(shape[0], device=device)
28 | scaled_t = t * (num_inference_steps - 1)
29 |
30 | with torch.no_grad():
31 | result = model(x, scaled_t)
32 |
33 | x, x_mean = self.scheduler.step_pred(result, x, t)
34 |
35 | x_mean = (x_mean + 1.0) / 2.0
36 |
37 | return x_mean
38 |
--------------------------------------------------------------------------------
/src/diffusers/schedulers/README.md:
--------------------------------------------------------------------------------
1 | # Schedulers
2 |
3 | - Schedulers are the algorithms used to run diffusion models at inference time as well as during training. They include the noise schedules and define algorithm-specific diffusion steps.
4 | - Schedulers can be used interchangeably between diffusion models at inference time to find the preferred trade-off between speed and generation quality.
5 | - Schedulers are implemented in NumPy, but can easily be converted to PyTorch.
6 |
7 | ## API
8 |
9 | - Schedulers should provide one or more `def step(...)` functions that are called iteratively to unroll the diffusion loop during
10 | the forward pass.
11 | - Schedulers should be framework-agnostic, but provide simple functionality to convert the scheduler into a specific framework, such as PyTorch,
12 | with a `set_format(...)` method.
13 |
14 | ## Examples
15 |
16 | - The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py).
17 | - The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py).
18 | - The PNDM scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py).
19 |
--------------------------------------------------------------------------------
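As a quick illustration of the API described in the README above, here is a minimal sampling-loop sketch driving the `DDPMScheduler` defined later in this folder. The stand-in `model` is purely illustrative: any callable mapping `(sample, t)` to a noise prediction of the same shape satisfies the contract, and a real loop would use one of the UNet models from `src/diffusers/models`.

```python
import torch

from diffusers.schedulers import DDPMScheduler


# Illustrative stand-in for a trained noise-prediction network.
def model(sample, t):
    return torch.zeros_like(sample)


scheduler = DDPMScheduler(timesteps=50).set_format("pt")  # convert internal NumPy arrays to torch

sample = torch.randn(1, 3, 32, 32)
for t in reversed(range(len(scheduler))):
    residual = model(sample, t)
    # predict x_{t-1} from the current sample and the predicted noise
    sample = scheduler.step(residual, sample, t)
    if t > 0:
        # optionally add the scheduler's variance, as the DDPM pipeline does
        sample = sample + scheduler.get_variance(t).sqrt() * torch.randn_like(sample)
```

Because every scheduler exposes the same `step(...)` / `set_format(...)` surface, swapping `DDPMScheduler` for `DDIMScheduler` or `PNDMScheduler` mostly changes which arguments are passed to `step`.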
/src/diffusers/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving other warnings, so don't check this module at all.
4 |
5 | # Copyright 2022 The HuggingFace Team. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | from .scheduling_ddim import DDIMScheduler
20 | from .scheduling_ddpm import DDPMScheduler
21 | from .scheduling_grad_tts import GradTTSScheduler
22 | from .scheduling_pndm import PNDMScheduler
23 | from .scheduling_sde_ve import ScoreSdeVeScheduler
24 | from .scheduling_sde_vp import ScoreSdeVpScheduler
25 | from .scheduling_utils import SchedulerMixin
26 |
--------------------------------------------------------------------------------
/src/diffusers/schedulers/scheduling_ddim.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16 | # and https://github.com/hojonathanho/diffusion
17 |
18 | import math
19 |
20 | import numpy as np
21 |
22 | from ..configuration_utils import ConfigMixin
23 | from .scheduling_utils import SchedulerMixin
24 |
25 |
26 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
27 | """
28 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
29 | (1-beta) over time from t = [0,1].
30 |
31 |     :param num_diffusion_timesteps: the number of betas to produce.
32 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product
33 |         of (1-beta) up to that part of the diffusion process.
34 |     :param max_beta: the maximum beta to use; use values lower than 1 to
35 |         prevent singularities.
36 | """
37 |
38 | def alpha_bar(time_step):
39 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
40 |
41 | betas = []
42 | for i in range(num_diffusion_timesteps):
43 | t1 = i / num_diffusion_timesteps
44 | t2 = (i + 1) / num_diffusion_timesteps
45 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
46 | return np.array(betas, dtype=np.float32)
47 |
48 |
49 | class DDIMScheduler(SchedulerMixin, ConfigMixin):
50 | def __init__(
51 | self,
52 | timesteps=1000,
53 | beta_start=0.0001,
54 | beta_end=0.02,
55 | beta_schedule="linear",
56 | trained_betas=None,
57 | timestep_values=None,
58 | clip_sample=True,
59 | tensor_format="np",
60 | ):
61 | super().__init__()
62 | self.register_to_config(
63 | timesteps=timesteps,
64 | beta_start=beta_start,
65 | beta_end=beta_end,
66 | beta_schedule=beta_schedule,
67 | trained_betas=trained_betas,
68 | timestep_values=timestep_values,
69 | clip_sample=clip_sample,
70 | )
71 |
72 | if beta_schedule == "linear":
73 | self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
74 | elif beta_schedule == "scaled_linear":
75 | # this schedule is very specific to the latent diffusion model.
76 | self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2
77 | elif beta_schedule == "squaredcos_cap_v2":
78 | # Glide cosine schedule
79 | self.betas = betas_for_alpha_bar(timesteps)
80 | else:
81 |             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
82 |
83 | self.alphas = 1.0 - self.betas
84 | self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
85 | self.one = np.array(1.0)
86 |
87 | self.set_format(tensor_format=tensor_format)
88 |
89 | def get_variance(self, t, num_inference_steps):
90 | orig_t = self.config.timesteps // num_inference_steps * t
91 | orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
92 |
93 | alpha_prod_t = self.alphas_cumprod[orig_t]
94 | alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
95 | beta_prod_t = 1 - alpha_prod_t
96 | beta_prod_t_prev = 1 - alpha_prod_t_prev
97 |
98 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
99 |
100 | return variance
101 |
102 | def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False):
103 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
104 |         # Ideally, read the DDIM paper in detail to understand the following
105 |
106 |         # Notation (<variable name> -> <name in paper>)
107 | # - pred_noise_t -> e_theta(x_t, t)
108 | # - pred_original_sample -> f_theta(x_t, t) or x_0
109 | # - std_dev_t -> sigma_t
110 | # - eta -> η
111 |         # - pred_sample_direction -> "direction pointing to x_t"
112 | # - pred_prev_sample -> "x_t-1"
113 |
114 | # 1. get actual t and t-1
115 | orig_t = self.config.timesteps // num_inference_steps * t
116 | orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
117 |
118 | # 2. compute alphas, betas
119 | alpha_prod_t = self.alphas_cumprod[orig_t]
120 | alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
121 | beta_prod_t = 1 - alpha_prod_t
122 |
123 | # 3. compute predicted original sample from predicted noise also called
124 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
125 | pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
126 |
127 | # 4. Clip "predicted x_0"
128 | if self.config.clip_sample:
129 | pred_original_sample = self.clip(pred_original_sample, -1, 1)
130 |
131 | # 5. compute variance: "sigma_t(η)" -> see formula (16)
132 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
133 | variance = self.get_variance(t, num_inference_steps)
134 | std_dev_t = eta * variance ** (0.5)
135 |
136 | if use_clipped_residual:
137 | # the residual is always re-derived from the clipped x_0 in Glide
138 | residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
139 |
140 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
141 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
142 |
143 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
144 | pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
145 |
146 | return pred_prev_sample
147 |
148 | def add_noise(self, original_samples, noise, timesteps):
149 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
150 | sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
151 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
152 | sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
153 |
154 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
155 | return noisy_samples
156 |
157 | def __len__(self):
158 | return self.config.timesteps
159 |
--------------------------------------------------------------------------------
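For reference, a minimal sketch of how `DDIMScheduler.step` and `get_variance` are driven with a reduced number of inference steps, mirroring the DDIM and unconditional latent-diffusion pipelines above. The stand-in `model` is illustrative.

```python
import torch

from diffusers.schedulers import DDIMScheduler


def model(sample, t):
    # illustrative stand-in for a trained noise-prediction network
    return torch.zeros_like(sample)


scheduler = DDIMScheduler(timesteps=1000).set_format("pt")

num_inference_steps, eta = 50, 0.0
step_times = range(0, 1000, 1000 // num_inference_steps)  # subsample the 1000 trained timesteps

sample = torch.randn(1, 3, 32, 32)
for t in reversed(range(num_inference_steps)):
    residual = model(sample, step_times[t])
    # deterministic DDIM update; note that `t` indexes the subsampled schedule
    sample = scheduler.step(residual, sample, t, num_inference_steps, eta)
    if eta > 0:
        # eta > 0 moves towards DDPM-style sampling by re-adding scaled variance noise
        sample = sample + eta * scheduler.get_variance(t, num_inference_steps).sqrt() * torch.randn_like(sample)
```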
/src/diffusers/schedulers/scheduling_ddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
16 |
17 | import math
18 |
19 | import numpy as np
20 |
21 | from ..configuration_utils import ConfigMixin
22 | from .scheduling_utils import SchedulerMixin
23 |
24 |
25 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
26 | """
27 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
28 | (1-beta) over time from t = [0,1].
29 |
30 |     :param num_diffusion_timesteps: the number of betas to produce.
31 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product
32 |         of (1-beta) up to that part of the diffusion process.
33 |     :param max_beta: the maximum beta to use; use values lower than 1 to
34 |         prevent singularities.
35 | """
36 |
37 | def alpha_bar(time_step):
38 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
39 |
40 | betas = []
41 | for i in range(num_diffusion_timesteps):
42 | t1 = i / num_diffusion_timesteps
43 | t2 = (i + 1) / num_diffusion_timesteps
44 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
45 | return np.array(betas, dtype=np.float32)
46 |
47 |
48 | class DDPMScheduler(SchedulerMixin, ConfigMixin):
49 | def __init__(
50 | self,
51 | timesteps=1000,
52 | beta_start=0.0001,
53 | beta_end=0.02,
54 | beta_schedule="linear",
55 | trained_betas=None,
56 | timestep_values=None,
57 | variance_type="fixed_small",
58 | clip_sample=True,
59 | tensor_format="np",
60 | ):
61 | super().__init__()
62 | self.register_to_config(
63 | timesteps=timesteps,
64 | beta_start=beta_start,
65 | beta_end=beta_end,
66 | beta_schedule=beta_schedule,
67 | trained_betas=trained_betas,
68 | timestep_values=timestep_values,
69 | variance_type=variance_type,
70 | clip_sample=clip_sample,
71 | )
72 |
73 | if trained_betas is not None:
74 | self.betas = np.asarray(trained_betas)
75 | elif beta_schedule == "linear":
76 | self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
77 | elif beta_schedule == "squaredcos_cap_v2":
78 | # Glide cosine schedule
79 | self.betas = betas_for_alpha_bar(timesteps)
80 | else:
81 |             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
82 |
83 | self.alphas = 1.0 - self.betas
84 | self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
85 | self.one = np.array(1.0)
86 |
87 | self.set_format(tensor_format=tensor_format)
88 |
89 | def get_variance(self, t, variance_type=None):
90 | alpha_prod_t = self.alphas_cumprod[t]
91 | alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
92 |
93 | # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
94 | # and sample from it to get previous sample
95 | # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
96 | variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
97 |
98 | if variance_type is None:
99 | variance_type = self.config.variance_type
100 |
101 | # hacks - were probs added for training stability
102 | if variance_type == "fixed_small":
103 | variance = self.clip(variance, min_value=1e-20)
104 | # for rl-diffuser https://arxiv.org/abs/2205.09991
105 | elif variance_type == "fixed_small_log":
106 | variance = self.log(self.clip(variance, min_value=1e-20))
107 | elif variance_type == "fixed_large":
108 | variance = self.betas[t]
109 | elif variance_type == "fixed_large_log":
110 | # Glide max_log
111 | variance = self.log(self.betas[t])
112 |
113 | return variance
114 |
115 | def step(self, residual, sample, t, predict_epsilon=True):
116 | # 1. compute alphas, betas
117 | alpha_prod_t = self.alphas_cumprod[t]
118 | alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
119 | beta_prod_t = 1 - alpha_prod_t
120 | beta_prod_t_prev = 1 - alpha_prod_t_prev
121 |
122 | # 2. compute predicted original sample from predicted noise also called
123 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
124 | if predict_epsilon:
125 | pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
126 | else:
127 | pred_original_sample = residual
128 |
129 | # 3. Clip "predicted x_0"
130 | if self.config.clip_sample:
131 | pred_original_sample = self.clip(pred_original_sample, -1, 1)
132 |
133 | # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
134 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
135 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
136 | current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
137 |
138 | # 5. Compute predicted previous sample µ_t
139 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
140 | pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
141 |
142 | return pred_prev_sample
143 |
144 | def add_noise(self, original_samples, noise, timesteps):
145 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
146 | sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
147 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
148 | sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
149 |
150 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
151 | return noisy_samples
152 |
153 | def __len__(self):
154 | return self.config.timesteps
155 |
--------------------------------------------------------------------------------
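The `add_noise` method is the training-side counterpart of `step`: given clean samples, noise, and sampled timesteps, it produces the noised inputs a model learns to denoise. A minimal sketch follows; the batch here is a random stand-in for real training data, and `model` in the final comment is hypothetical.

```python
import torch

from diffusers.schedulers import DDPMScheduler

scheduler = DDPMScheduler(timesteps=1000).set_format("pt")

clean_images = torch.randn(4, 3, 32, 32)  # stand-in for a training batch
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, len(scheduler), (clean_images.shape[0],)).long()

# forward-diffuse the batch: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

# a training step would then regress the model output on `noise`, e.g.
# loss = torch.nn.functional.mse_loss(model(noisy_images, timesteps), noise)
```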
/src/diffusers/schedulers/scheduling_grad_tts.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 |
17 | from ..configuration_utils import ConfigMixin
18 | from .scheduling_utils import SchedulerMixin
19 |
20 |
21 | class GradTTSScheduler(SchedulerMixin, ConfigMixin):
22 | def __init__(
23 | self,
24 | beta_start=0.05,
25 | beta_end=20,
26 | tensor_format="np",
27 | ):
28 | super().__init__()
29 | self.register_to_config(
30 | beta_start=beta_start,
31 | beta_end=beta_end,
32 | )
33 | self.set_format(tensor_format=tensor_format)
34 | self.betas = None
35 |
36 | def get_timesteps(self, num_inference_steps):
37 | return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)])
38 |
39 | def set_betas(self, num_inference_steps):
40 | timesteps = self.get_timesteps(num_inference_steps)
41 | self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps])
42 |
43 | def step(self, residual, sample, t, num_inference_steps):
44 | # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
45 | if self.betas is None:
46 | self.set_betas(num_inference_steps)
47 |
48 | beta_t = self.betas[t]
49 | beta_t_deriv = beta_t / num_inference_steps
50 |
51 | sample_deriv = residual * beta_t_deriv / 2
52 |
53 | sample = sample + sample_deriv
54 | return sample
55 |
--------------------------------------------------------------------------------
/src/diffusers/schedulers/scheduling_pndm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
16 |
17 | import math
18 |
19 | import numpy as np
20 |
21 | from ..configuration_utils import ConfigMixin
22 | from .scheduling_utils import SchedulerMixin
23 |
24 |
25 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
26 | """
27 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
28 | (1-beta) over time from t = [0,1].
29 |
30 |     :param num_diffusion_timesteps: the number of betas to produce.
31 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product
32 |         of (1-beta) up to that part of the diffusion process.
33 |     :param max_beta: the maximum beta to use; use values lower than 1 to
34 |         prevent singularities.
35 | """
36 |
37 | def alpha_bar(time_step):
38 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
39 |
40 | betas = []
41 | for i in range(num_diffusion_timesteps):
42 | t1 = i / num_diffusion_timesteps
43 | t2 = (i + 1) / num_diffusion_timesteps
44 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
45 | return np.array(betas, dtype=np.float32)
46 |
47 |
48 | class PNDMScheduler(SchedulerMixin, ConfigMixin):
49 | def __init__(
50 | self,
51 | timesteps=1000,
52 | beta_start=0.0001,
53 | beta_end=0.02,
54 | beta_schedule="linear",
55 | tensor_format="np",
56 | ):
57 | super().__init__()
58 | self.register_to_config(
59 | timesteps=timesteps,
60 | beta_start=beta_start,
61 | beta_end=beta_end,
62 | beta_schedule=beta_schedule,
63 | )
64 |
65 | if beta_schedule == "linear":
66 | self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
67 | elif beta_schedule == "squaredcos_cap_v2":
68 | # Glide cosine schedule
69 | self.betas = betas_for_alpha_bar(timesteps)
70 | else:
71 |             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
72 |
73 | self.alphas = 1.0 - self.betas
74 | self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
75 |
76 | self.one = np.array(1.0)
77 |
78 | self.set_format(tensor_format=tensor_format)
79 |
80 | # For now we only support F-PNDM, i.e. the runge-kutta method
81 | # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
82 | # mainly at formula (9), (12), (13) and the Algorithm 2.
83 | self.pndm_order = 4
84 |
85 | # running values
86 | self.cur_residual = 0
87 | self.cur_sample = None
88 | self.ets = []
89 | self.prk_time_steps = {}
90 | self.time_steps = {}
91 | self.set_prk_mode()
92 |
93 | def get_prk_time_steps(self, num_inference_steps):
94 | if num_inference_steps in self.prk_time_steps:
95 | return self.prk_time_steps[num_inference_steps]
96 |
97 | inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
98 |
99 | prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
100 | np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
101 | )
102 | self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
103 |
104 | return self.prk_time_steps[num_inference_steps]
105 |
106 | def get_time_steps(self, num_inference_steps):
107 | if num_inference_steps in self.time_steps:
108 | return self.time_steps[num_inference_steps]
109 |
110 | inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
111 | self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
112 |
113 | return self.time_steps[num_inference_steps]
114 |
115 | def set_prk_mode(self):
116 | self.mode = "prk"
117 |
118 | def set_plms_mode(self):
119 | self.mode = "plms"
120 |
121 | def step(self, *args, **kwargs):
122 | if self.mode == "prk":
123 | return self.step_prk(*args, **kwargs)
124 | if self.mode == "plms":
125 | return self.step_plms(*args, **kwargs)
126 |
127 | raise ValueError(f"mode {self.mode} does not exist.")
128 |
129 | def step_prk(self, residual, sample, t, num_inference_steps):
130 | prk_time_steps = self.get_prk_time_steps(num_inference_steps)
131 |
132 | t_orig = prk_time_steps[t // 4 * 4]
133 | t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
134 |
135 | if t % 4 == 0:
136 | self.cur_residual += 1 / 6 * residual
137 | self.ets.append(residual)
138 | self.cur_sample = sample
139 | elif (t - 1) % 4 == 0:
140 | self.cur_residual += 1 / 3 * residual
141 | elif (t - 2) % 4 == 0:
142 | self.cur_residual += 1 / 3 * residual
143 | elif (t - 3) % 4 == 0:
144 | residual = self.cur_residual + 1 / 6 * residual
145 | self.cur_residual = 0
146 |
147 | # cur_sample should not be `None`
148 | cur_sample = self.cur_sample if self.cur_sample is not None else sample
149 |
150 | return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)
151 |
152 | def step_plms(self, residual, sample, t, num_inference_steps):
153 | if len(self.ets) < 3:
154 | raise ValueError(
155 |                 f"{self.__class__} can only be run AFTER the scheduler has been run "
156 |                 "in 'prk' mode for at least 12 iterations. "
157 |                 "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py "
158 |                 "for more information."
159 | )
160 |
161 | timesteps = self.get_time_steps(num_inference_steps)
162 |
163 | t_orig = timesteps[t]
164 | t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
165 | self.ets.append(residual)
166 |
167 | residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
168 |
169 | return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
170 |
171 | def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
172 | # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
173 | # this function computes x_(t−δ) using the formula of (9)
174 | # Note that x_t needs to be added to both sides of the equation
175 |
176 | # Notation ( ->
177 | # alpha_prod_t -> α_t
178 | # alpha_prod_t_prev -> α_(t−δ)
179 | # beta_prod_t -> (1 - α_t)
180 | # beta_prod_t_prev -> (1 - α_(t−δ))
181 | # sample -> x_t
182 | # residual -> e_θ(x_t, t)
183 | # prev_sample -> x_(t−δ)
184 | alpha_prod_t = self.alphas_cumprod[t_orig + 1]
185 | alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
186 | beta_prod_t = 1 - alpha_prod_t
187 | beta_prod_t_prev = 1 - alpha_prod_t_prev
188 |
189 | # corresponds to (α_(t−δ) - α_t) divided by
190 | # denominator of x_t in formula (9) and plus 1
191 |         # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) =
192 |         # sqrt(α_(t−δ)) / sqrt(α_t)
193 | sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
194 |
195 | # corresponds to denominator of e_θ(x_t, t) in formula (9)
196 | residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
197 | alpha_prod_t * beta_prod_t * alpha_prod_t_prev
198 | ) ** (0.5)
199 |
200 | # full formula (9)
201 | prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff
202 |
203 | return prev_sample
204 |
205 | def __len__(self):
206 | return self.config.timesteps
207 |
--------------------------------------------------------------------------------
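Written out, `get_prev_sample` above implements the transfer rule of formula (9) exactly as coded, with ᾱ taken from `alphas_cumprod` and ε_θ the (possibly linear-multistep-combined) residual:

```latex
x_{t-\delta}
  = \sqrt{\frac{\bar\alpha_{t-\delta}}{\bar\alpha_t}}\, x_t
  \;-\; \frac{\bar\alpha_{t-\delta} - \bar\alpha_t}
             {\bar\alpha_t \sqrt{1-\bar\alpha_{t-\delta}}
              + \sqrt{\bar\alpha_t\,(1-\bar\alpha_t)\,\bar\alpha_{t-\delta}}}\;
    \epsilon_\theta(x_t, t)
```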
/src/diffusers/schedulers/scheduling_sde_ve.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
16 |
17 | # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
18 |
19 | import numpy as np
20 | import torch
21 |
22 | from ..configuration_utils import ConfigMixin
23 | from .scheduling_utils import SchedulerMixin
24 |
25 |
26 | class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
27 | def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
28 | super().__init__()
29 | self.register_to_config(
30 | snr=snr,
31 | sigma_min=sigma_min,
32 | sigma_max=sigma_max,
33 | sampling_eps=sampling_eps,
34 | )
35 |
36 | self.sigmas = None
37 | self.discrete_sigmas = None
38 | self.timesteps = None
39 |
40 | def set_timesteps(self, num_inference_steps):
41 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
42 |
43 | def set_sigmas(self, num_inference_steps):
44 | if self.timesteps is None:
45 | self.set_timesteps(num_inference_steps)
46 |
47 | self.discrete_sigmas = torch.exp(
48 | torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
49 | )
50 | self.sigmas = torch.tensor(
51 |             [self.config.sigma_min * (self.config.sigma_max / self.config.sigma_min) ** t for t in self.timesteps]
52 | )
53 |
54 | def step_pred(self, result, x, t):
55 | # TODO(Patrick) better comments + non-PyTorch
56 | t = t * torch.ones(x.shape[0], device=x.device)
57 | timestep = (t * (len(self.timesteps) - 1)).long()
58 |
59 | sigma = self.discrete_sigmas.to(t.device)[timestep]
60 | adjacent_sigma = torch.where(
61 | timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
62 | )
63 | f = torch.zeros_like(x)
64 | G = torch.sqrt(sigma**2 - adjacent_sigma**2)
65 |
66 | f = f - G[:, None, None, None] ** 2 * result
67 |
68 | z = torch.randn_like(x)
69 | x_mean = x - f
70 | x = x_mean + G[:, None, None, None] * z
71 | return x, x_mean
72 |
73 | def step_correct(self, result, x):
74 | # TODO(Patrick) better comments + non-PyTorch
75 | noise = torch.randn_like(x)
76 | grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
77 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
78 | step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
79 | step_size = step_size * torch.ones(x.shape[0], device=x.device)
80 | x_mean = x + step_size[:, None, None, None] * result
81 |
82 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
83 |
84 | return x
85 |
--------------------------------------------------------------------------------
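Spelled out, `step_correct` above is a Langevin-style corrector step: with score estimate s, fresh noise z, batch-averaged L2 norms, and the configured signal-to-noise ratio r (`snr`), it computes

```latex
\epsilon = 2\left(\frac{r\,\lVert z\rVert_2}{\lVert s\rVert_2}\right)^{2},\qquad
x_{\text{mean}} = x + \epsilon\, s,\qquad
x = x_{\text{mean}} + \sqrt{2\epsilon}\, z
```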
/src/diffusers/schedulers/scheduling_sde_vp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
16 |
17 | # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
18 |
19 | import numpy as np
20 | import torch
21 |
22 | from ..configuration_utils import ConfigMixin
23 | from .scheduling_utils import SchedulerMixin
24 |
25 |
26 | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
27 | def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
28 | super().__init__()
29 | self.register_to_config(
30 | beta_min=beta_min,
31 | beta_max=beta_max,
32 | sampling_eps=sampling_eps,
33 | )
34 |
35 | self.sigmas = None
36 | self.discrete_sigmas = None
37 | self.timesteps = None
38 |
39 | def set_timesteps(self, num_inference_steps):
40 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
41 |
42 | def step_pred(self, result, x, t):
43 | # TODO(Patrick) better comments + non-PyTorch
44 | # postprocess model result
45 | log_mean_coeff = (
46 | -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
47 | )
48 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
49 | result = -result / std[:, None, None, None]
50 |
51 | # compute
52 | dt = -1.0 / len(self.timesteps)
53 |
54 | beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
55 | drift = -0.5 * beta_t[:, None, None, None] * x
56 | diffusion = torch.sqrt(beta_t)
57 | drift = drift - diffusion[:, None, None, None] ** 2 * result
58 | x_mean = x + drift * dt
59 |
60 | # add noise
61 | z = torch.randn_like(x)
62 | x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
63 |
64 | return x, x_mean
65 |
--------------------------------------------------------------------------------
/src/diffusers/schedulers/scheduling_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Union
15 |
16 | import numpy as np
17 | import torch
18 |
19 |
20 | SCHEDULER_CONFIG_NAME = "scheduler_config.json"
21 |
22 |
23 | class SchedulerMixin:
24 |
25 | config_name = SCHEDULER_CONFIG_NAME
26 |
27 | def set_format(self, tensor_format="pt"):
28 | self.tensor_format = tensor_format
29 | if tensor_format == "pt":
30 | for key, value in vars(self).items():
31 | if isinstance(value, np.ndarray):
32 | setattr(self, key, torch.from_numpy(value))
33 |
34 | return self
35 |
36 | def clip(self, tensor, min_value=None, max_value=None):
37 | tensor_format = getattr(self, "tensor_format", "pt")
38 |
39 | if tensor_format == "np":
40 | return np.clip(tensor, min_value, max_value)
41 | elif tensor_format == "pt":
42 | return torch.clamp(tensor, min_value, max_value)
43 |
44 | raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
45 |
46 | def log(self, tensor):
47 | tensor_format = getattr(self, "tensor_format", "pt")
48 |
49 | if tensor_format == "np":
50 | return np.log(tensor)
51 | elif tensor_format == "pt":
52 | return torch.log(tensor)
53 |
54 | raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
55 |
56 | def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
57 | """
58 | Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
59 |
60 | Args:
61 |             values: an array or tensor of values to extract.
62 |             broadcast_array: an array with a larger shape of K dimensions with the batch
63 |                 dimension equal to the length of values.
64 | Returns:
65 | a tensor of shape [batch_size, 1, ...] where the shape has K dims.
66 | """
67 |
68 | tensor_format = getattr(self, "tensor_format", "pt")
69 | values = values.flatten()
70 |
71 | while len(values.shape) < len(broadcast_array.shape):
72 | values = values[..., None]
73 | if tensor_format == "pt":
74 | values = values.to(broadcast_array.device)
75 |
76 | return values
77 |
--------------------------------------------------------------------------------
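A quick illustration of `match_shape` when a scheduler is in `pt` format: a 1-D batch of per-sample values is reshaped so it broadcasts against an image batch. The scheduler instance and the values are arbitrary.

```python
import torch

from diffusers.schedulers import DDPMScheduler

scheduler = DDPMScheduler().set_format("pt")

values = torch.tensor([0.1, 0.2, 0.3, 0.4])  # one value per sample in the batch
images = torch.randn(4, 3, 32, 32)

broadcastable = scheduler.match_shape(values, images)
print(broadcastable.shape)  # torch.Size([4, 1, 1, 1]) -> multiplies cleanly against `images`
```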
/src/diffusers/testing_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import unittest
4 | from distutils.util import strtobool
5 |
6 | import torch
7 |
8 |
9 | global_rng = random.Random()
10 | torch_device = "cuda" if torch.cuda.is_available() else "cpu"
11 |
12 |
13 | def parse_flag_from_env(key, default=False):
14 | try:
15 | value = os.environ[key]
16 | except KeyError:
17 | # KEY isn't set, default to `default`.
18 | _value = default
19 | else:
20 | # KEY is set, convert it to True or False.
21 | try:
22 | _value = strtobool(value)
23 | except ValueError:
24 | # More values are supported, but let's keep the message simple.
25 | raise ValueError(f"If set, {key} must be yes or no.")
26 | return _value
27 |
28 |
29 | _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
30 |
31 |
32 | def floats_tensor(shape, scale=1.0, rng=None, name=None):
33 | """Creates a random float32 tensor"""
34 | if rng is None:
35 | rng = global_rng
36 |
37 | total_dims = 1
38 | for dim in shape:
39 | total_dims *= dim
40 |
41 | values = []
42 | for _ in range(total_dims):
43 | values.append(rng.random() * scale)
44 |
45 | return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
46 |
47 |
48 | def slow(test_case):
49 | """
50 | Decorator marking a test as slow.
51 |
52 | Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
53 |
54 | """
55 | return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
56 |
--------------------------------------------------------------------------------
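A short sketch of how these test helpers are typically combined; the test case itself is illustrative.

```python
import unittest

from diffusers.testing_utils import floats_tensor, slow, torch_device


class ExampleTest(unittest.TestCase):
    @slow  # skipped unless the RUN_SLOW environment variable is set to a truthy value
    def test_random_input(self):
        sample = floats_tensor((1, 3, 32, 32)).to(torch_device)
        self.assertEqual(tuple(sample.shape), (1, 3, 32, 32))
```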
/src/diffusers/training_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 |
5 |
6 | class EMAModel:
7 | """
8 |     Exponential Moving Average of model weights
9 | """
10 |
11 | def __init__(
12 | self,
13 | model,
14 | update_after_step=0,
15 | inv_gamma=1.0,
16 | power=2 / 3,
17 | min_value=0.0,
18 | max_value=0.9999,
19 | device=None,
20 | ):
21 | """
22 | @crowsonkb's notes on EMA Warmup:
23 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
24 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
25 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
26 | at 215.4k steps).
27 | Args:
28 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
29 | power (float): Exponential factor of EMA warmup. Default: 2/3.
30 | min_value (float): The minimum EMA decay rate. Default: 0.
31 | """
32 |
33 | self.averaged_model = copy.deepcopy(model).eval()
34 | self.averaged_model.requires_grad_(False)
35 |
36 | self.update_after_step = update_after_step
37 | self.inv_gamma = inv_gamma
38 | self.power = power
39 | self.min_value = min_value
40 | self.max_value = max_value
41 |
42 | if device is not None:
43 | self.averaged_model = self.averaged_model.to(device=device)
44 |
45 | self.decay = 0.0
46 | self.optimization_step = 0
47 |
48 | def get_decay(self, optimization_step):
49 | """
50 | Compute the decay factor for the exponential moving average.
51 | """
52 | step = max(0, optimization_step - self.update_after_step - 1)
53 | value = 1 - (1 + step / self.inv_gamma) ** -self.power
54 |
55 | if step <= 0:
56 | return 0.0
57 |
58 | return max(self.min_value, min(value, self.max_value))
59 |
60 | @torch.no_grad()
61 | def step(self, new_model):
62 | ema_state_dict = {}
63 | ema_params = self.averaged_model.state_dict()
64 |
65 | self.decay = self.get_decay(self.optimization_step)
66 |
67 | for key, param in new_model.named_parameters():
68 | if isinstance(param, dict):
69 | continue
70 | try:
71 | ema_param = ema_params[key]
72 | except KeyError:
73 | ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
74 | ema_params[key] = ema_param
75 |
76 | if not param.requires_grad:
77 | ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
78 | ema_param = ema_params[key]
79 | else:
80 | ema_param.mul_(self.decay)
81 | ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
82 |
83 | ema_state_dict[key] = ema_param
84 |
85 | for key, param in new_model.named_buffers():
86 | ema_state_dict[key] = param
87 |
88 | self.averaged_model.load_state_dict(ema_state_dict, strict=False)
89 | self.optimization_step += 1
90 |
--------------------------------------------------------------------------------
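A minimal sketch of how `EMAModel` is meant to be driven inside a training loop; the model and objective below are illustrative stand-ins.

```python
import torch

from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)  # stand-in for a diffusion model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
ema = EMAModel(model, power=3 / 4)  # faster decay warmup, suited to shorter runs

for _ in range(10):
    loss = model(torch.randn(2, 8)).pow(2).mean()  # stand-in objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model)  # update the averaged copy after every optimizer step

# evaluate or checkpoint `ema.averaged_model` rather than the raw weights
```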
/src/diffusers/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import importlib
15 | import os
16 | from collections import OrderedDict
17 |
18 | import importlib_metadata
19 | from requests.exceptions import HTTPError
20 |
21 | from .logging import get_logger
22 |
23 |
24 | logger = get_logger(__name__)
25 |
26 |
27 | hf_cache_home = os.path.expanduser(
28 | os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
29 | )
30 | default_cache_path = os.path.join(hf_cache_home, "diffusers")
31 |
32 |
33 | CONFIG_NAME = "config.json"
34 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
35 | DIFFUSERS_CACHE = default_cache_path
36 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
37 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
38 |
39 |
40 | _transformers_available = importlib.util.find_spec("transformers") is not None
41 | try:
42 | _transformers_version = importlib_metadata.version("transformers")
43 | logger.debug(f"Successfully imported transformers version {_transformers_version}")
44 | except importlib_metadata.PackageNotFoundError:
45 | _transformers_available = False
46 |
47 |
48 | _inflect_available = importlib.util.find_spec("inflect") is not None
49 | try:
50 | _inflect_version = importlib_metadata.version("inflect")
51 | logger.debug(f"Successfully imported inflect version {_inflect_version}")
52 | except importlib_metadata.PackageNotFoundError:
53 | _inflect_available = False
54 |
55 |
56 | _unidecode_available = importlib.util.find_spec("unidecode") is not None
57 | try:
58 | _unidecode_version = importlib_metadata.version("unidecode")
59 | logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
60 | except importlib_metadata.PackageNotFoundError:
61 | _unidecode_available = False
62 |
63 |
64 | def is_transformers_available():
65 | return _transformers_available
66 |
67 |
68 | def is_inflect_available():
69 | return _inflect_available
70 |
71 |
72 | def is_unidecode_available():
73 | return _unidecode_available
74 |
75 |
76 | class RepositoryNotFoundError(HTTPError):
77 | """
78 | Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
79 | not have access to.
80 | """
81 |
82 |
83 | class EntryNotFoundError(HTTPError):
84 | """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
85 |
86 |
87 | class RevisionNotFoundError(HTTPError):
88 | """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
89 |
90 |
91 | TRANSFORMERS_IMPORT_ERROR = """
92 | {0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
93 | install transformers`
94 | """
95 |
96 |
97 | UNIDECODE_IMPORT_ERROR = """
98 | {0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
99 | Unidecode`
100 | """
101 |
102 |
103 | INFLECT_IMPORT_ERROR = """
104 | {0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
105 | inflect`
106 | """
107 |
108 |
109 | BACKENDS_MAPPING = OrderedDict(
110 | [
111 | ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
112 | ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
113 | ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
114 | ]
115 | )
116 |
117 |
118 | def requires_backends(obj, backends):
119 | if not isinstance(backends, (list, tuple)):
120 | backends = [backends]
121 |
122 | name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
123 | checks = (BACKENDS_MAPPING[backend] for backend in backends)
124 | failed = [msg.format(name) for available, msg in checks if not available()]
125 | if failed:
126 | raise ImportError("".join(failed))
127 |
128 |
129 | class DummyObject(type):
130 | """
131 | Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
132 |     `requires_backends` each time a user tries to access any method of that class.
133 | """
134 |
135 | def __getattr__(cls, key):
136 | if key.startswith("_"):
137 | return super().__getattr__(cls, key)
138 | requires_backends(cls, cls._backends)
139 |
--------------------------------------------------------------------------------
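This soft-dependency machinery is what the pipeline `__init__.py` files above rely on: real classes are imported only when the backend is installed, and the autogenerated dummy objects below raise a helpful `ImportError` otherwise. A small sketch of the same pattern (the class name and the commented import are illustrative):

```python
from diffusers.utils import DummyObject, is_transformers_available, requires_backends

if is_transformers_available():
    # the real import would go here, e.g. `from .pipeline_glide import GlidePipeline`
    pass
else:
    class GlidePipeline(metaclass=DummyObject):
        _backends = ["transformers"]

        def __init__(self, *args, **kwargs):
            # raises ImportError with an install hint when transformers is missing
            requires_backends(self, ["transformers"])
```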
/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | # flake8: noqa
3 | from ..utils import DummyObject, requires_backends
4 |
5 |
6 | class GradTTSPipeline(metaclass=DummyObject):
7 | _backends = ["transformers", "inflect", "unidecode"]
8 |
9 | def __init__(self, *args, **kwargs):
10 | requires_backends(self, ["transformers", "inflect", "unidecode"])
11 |
--------------------------------------------------------------------------------
/src/diffusers/utils/dummy_transformers_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | # flake8: noqa
3 | from ..utils import DummyObject, requires_backends
4 |
5 |
6 | class GlideSuperResUNetModel(metaclass=DummyObject):
7 | _backends = ["transformers"]
8 |
9 | def __init__(self, *args, **kwargs):
10 | requires_backends(self, ["transformers"])
11 |
12 |
13 | class GlideTextToImageUNetModel(metaclass=DummyObject):
14 | _backends = ["transformers"]
15 |
16 | def __init__(self, *args, **kwargs):
17 | requires_backends(self, ["transformers"])
18 |
19 |
20 | class GlideUNetModel(metaclass=DummyObject):
21 | _backends = ["transformers"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["transformers"])
25 |
26 |
27 | class UNetGradTTSModel(metaclass=DummyObject):
28 | _backends = ["transformers"]
29 |
30 | def __init__(self, *args, **kwargs):
31 | requires_backends(self, ["transformers"])
32 |
33 |
34 | class GlidePipeline(metaclass=DummyObject):
35 | _backends = ["transformers"]
36 |
37 | def __init__(self, *args, **kwargs):
38 | requires_backends(self, ["transformers"])
39 |
40 |
41 | class LatentDiffusionPipeline(metaclass=DummyObject):
42 | _backends = ["transformers"]
43 |
44 | def __init__(self, *args, **kwargs):
45 | requires_backends(self, ["transformers"])
46 |
--------------------------------------------------------------------------------
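These dummy modules exist so the top-level `__init__.py` can always export the same names, whether or not the optional backends are installed. The actual `src/diffusers/__init__.py` is not reproduced in this dump, but the guarded-import pattern that `utils/check_dummies.py` (further below) scans for looks roughly like this hedged sketch:

```python
# Hypothetical excerpt of src/diffusers/__init__.py illustrating the pattern that
# read_init() in utils/check_dummies.py parses ("if is_xxx_available(): ... else: ...").
from .utils import is_transformers_available

if is_transformers_available():
    from .pipelines import GlidePipeline, LatentDiffusionPipeline
else:
    from .utils.dummy_transformers_objects import *  # noqa: F401,F403
```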
/src/diffusers/utils/logging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Optuna, Hugging Face
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Logging utilities."""
16 |
17 | import logging
18 | import os
19 | import sys
20 | import threading
21 | from logging import CRITICAL # NOQA
22 | from logging import DEBUG # NOQA
23 | from logging import ERROR # NOQA
24 | from logging import FATAL # NOQA
25 | from logging import INFO # NOQA
26 | from logging import NOTSET # NOQA
27 | from logging import WARN # NOQA
28 | from logging import WARNING # NOQA
29 | from typing import Optional
30 |
31 | from tqdm import auto as tqdm_lib
32 |
33 |
34 | _lock = threading.Lock()
35 | _default_handler: Optional[logging.Handler] = None
36 |
37 | log_levels = {
38 | "debug": logging.DEBUG,
39 | "info": logging.INFO,
40 | "warning": logging.WARNING,
41 | "error": logging.ERROR,
42 | "critical": logging.CRITICAL,
43 | }
44 |
45 | _default_log_level = logging.WARNING
46 |
47 | _tqdm_active = True
48 |
49 |
50 | def _get_default_logging_level():
51 | """
52 | If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
53 | not - fall back to `_default_log_level`
54 | """
55 | env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
56 | if env_level_str:
57 | if env_level_str in log_levels:
58 | return log_levels[env_level_str]
59 | else:
60 | logging.getLogger().warning(
61 | f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
62 | f"has to be one of: { ', '.join(log_levels.keys()) }"
63 | )
64 | return _default_log_level
65 |
66 |
67 | def _get_library_name() -> str:
68 |
69 | return __name__.split(".")[0]
70 |
71 |
72 | def _get_library_root_logger() -> logging.Logger:
73 |
74 | return logging.getLogger(_get_library_name())
75 |
76 |
77 | def _configure_library_root_logger() -> None:
78 |
79 | global _default_handler
80 |
81 | with _lock:
82 | if _default_handler:
83 | # This library has already configured the library root logger.
84 | return
85 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
86 | _default_handler.flush = sys.stderr.flush
87 |
88 | # Apply our default configuration to the library root logger.
89 | library_root_logger = _get_library_root_logger()
90 | library_root_logger.addHandler(_default_handler)
91 | library_root_logger.setLevel(_get_default_logging_level())
92 | library_root_logger.propagate = False
93 |
94 |
95 | def _reset_library_root_logger() -> None:
96 |
97 | global _default_handler
98 |
99 | with _lock:
100 | if not _default_handler:
101 | return
102 |
103 | library_root_logger = _get_library_root_logger()
104 | library_root_logger.removeHandler(_default_handler)
105 | library_root_logger.setLevel(logging.NOTSET)
106 | _default_handler = None
107 |
108 |
109 | def get_log_levels_dict():
110 | return log_levels
111 |
112 |
113 | def get_logger(name: Optional[str] = None) -> logging.Logger:
114 | """
115 | Return a logger with the specified name.
116 |
117 | This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
118 | """
119 |
120 | if name is None:
121 | name = _get_library_name()
122 |
123 | _configure_library_root_logger()
124 | return logging.getLogger(name)
125 |
126 |
127 | def get_verbosity() -> int:
128 | """
129 | Return the current level for the 🤗 Diffusers' root logger as an int.
130 |
131 | Returns:
132 | `int`: The logging level.
133 |
134 |
135 |
136 |     🤗 Diffusers has the following logging levels:
137 |
138 | - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
139 | - 40: `diffusers.logging.ERROR`
140 | - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
141 | - 20: `diffusers.logging.INFO`
142 | - 10: `diffusers.logging.DEBUG`
143 |
144 | """
145 |
146 | _configure_library_root_logger()
147 | return _get_library_root_logger().getEffectiveLevel()
148 |
149 |
150 | def set_verbosity(verbosity: int) -> None:
151 | """
152 | Set the verbosity level for the 🤗 Diffusers' root logger.
153 |
154 | Args:
155 | verbosity (`int`):
156 | Logging level, e.g., one of:
157 |
158 | - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
159 | - `diffusers.logging.ERROR`
160 | - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
161 | - `diffusers.logging.INFO`
162 | - `diffusers.logging.DEBUG`
163 | """
164 |
165 | _configure_library_root_logger()
166 | _get_library_root_logger().setLevel(verbosity)
167 |
168 |
169 | def set_verbosity_info():
170 | """Set the verbosity to the `INFO` level."""
171 | return set_verbosity(INFO)
172 |
173 |
174 | def set_verbosity_warning():
175 | """Set the verbosity to the `WARNING` level."""
176 | return set_verbosity(WARNING)
177 |
178 |
179 | def set_verbosity_debug():
180 | """Set the verbosity to the `DEBUG` level."""
181 | return set_verbosity(DEBUG)
182 |
183 |
184 | def set_verbosity_error():
185 | """Set the verbosity to the `ERROR` level."""
186 | return set_verbosity(ERROR)
187 |
188 |
189 | def disable_default_handler() -> None:
190 | """Disable the default handler of the HuggingFace Diffusers' root logger."""
191 |
192 | _configure_library_root_logger()
193 |
194 | assert _default_handler is not None
195 | _get_library_root_logger().removeHandler(_default_handler)
196 |
197 |
198 | def enable_default_handler() -> None:
199 | """Enable the default handler of the HuggingFace Diffusers' root logger."""
200 |
201 | _configure_library_root_logger()
202 |
203 | assert _default_handler is not None
204 | _get_library_root_logger().addHandler(_default_handler)
205 |
206 |
207 | def add_handler(handler: logging.Handler) -> None:
208 |     """Adds a handler to the HuggingFace Diffusers' root logger."""
209 |
210 | _configure_library_root_logger()
211 |
212 | assert handler is not None
213 | _get_library_root_logger().addHandler(handler)
214 |
215 |
216 | def remove_handler(handler: logging.Handler) -> None:
217 |     """Removes the given handler from the HuggingFace Diffusers' root logger."""
218 |
219 | _configure_library_root_logger()
220 |
221 |     assert handler is not None and handler in _get_library_root_logger().handlers
222 | _get_library_root_logger().removeHandler(handler)
223 |
224 |
225 | def disable_propagation() -> None:
226 | """
227 | Disable propagation of the library log outputs. Note that log propagation is disabled by default.
228 | """
229 |
230 | _configure_library_root_logger()
231 | _get_library_root_logger().propagate = False
232 |
233 |
234 | def enable_propagation() -> None:
235 | """
236 | Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
237 | double logging if the root logger has been configured.
238 | """
239 |
240 | _configure_library_root_logger()
241 | _get_library_root_logger().propagate = True
242 |
243 |
244 | def enable_explicit_format() -> None:
245 | """
246 | Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
247 | ```
248 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
249 | ```
250 | All handlers currently bound to the root logger are affected by this method.
251 | """
252 | handlers = _get_library_root_logger().handlers
253 |
254 | for handler in handlers:
255 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
256 | handler.setFormatter(formatter)
257 |
258 |
259 | def reset_format() -> None:
260 | """
261 | Resets the formatting for HuggingFace Diffusers' loggers.
262 |
263 | All handlers currently bound to the root logger are affected by this method.
264 | """
265 | handlers = _get_library_root_logger().handlers
266 |
267 | for handler in handlers:
268 | handler.setFormatter(None)
269 |
270 |
271 | def warning_advice(self, *args, **kwargs):
272 | """
273 |     This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
274 | warning will not be printed
275 | """
276 | no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
277 | if no_advisory_warnings:
278 | return
279 | self.warning(*args, **kwargs)
280 |
281 |
282 | logging.Logger.warning_advice = warning_advice
283 |
284 |
285 | class EmptyTqdm:
286 | """Dummy tqdm which doesn't do anything."""
287 |
288 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
289 | self._iterator = args[0] if args else None
290 |
291 | def __iter__(self):
292 | return iter(self._iterator)
293 |
294 | def __getattr__(self, _):
295 | """Return empty function."""
296 |
297 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
298 | return
299 |
300 | return empty_fn
301 |
302 | def __enter__(self):
303 | return self
304 |
305 | def __exit__(self, type_, value, traceback):
306 | return
307 |
308 |
309 | class _tqdm_cls:
310 | def __call__(self, *args, **kwargs):
311 | if _tqdm_active:
312 | return tqdm_lib.tqdm(*args, **kwargs)
313 | else:
314 | return EmptyTqdm(*args, **kwargs)
315 |
316 | def set_lock(self, *args, **kwargs):
317 | self._lock = None
318 | if _tqdm_active:
319 | return tqdm_lib.tqdm.set_lock(*args, **kwargs)
320 |
321 | def get_lock(self):
322 | if _tqdm_active:
323 | return tqdm_lib.tqdm.get_lock()
324 |
325 |
326 | tqdm = _tqdm_cls()
327 |
328 |
329 | def is_progress_bar_enabled() -> bool:
330 | """Return a boolean indicating whether tqdm progress bars are enabled."""
331 | global _tqdm_active
332 | return bool(_tqdm_active)
333 |
334 |
335 | def enable_progress_bar():
336 | """Enable tqdm progress bar."""
337 | global _tqdm_active
338 | _tqdm_active = True
339 |
340 |
341 | def disable_progress_bar():
342 | """Disable tqdm progress bar."""
343 | global _tqdm_active
344 | _tqdm_active = False
345 |
--------------------------------------------------------------------------------
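A short usage sketch for the logging helpers above (illustrative only; the messages are made up):

```python
# Illustrative usage of diffusers' logging utilities (not part of the repository).
from diffusers.utils import logging

logging.set_verbosity_info()          # same effect as logging.set_verbosity(logging.INFO)
logging.enable_explicit_format()      # switches to the [LEVELNAME|FILENAME:LINENO] TIME >> MESSAGE format

logger = logging.get_logger()         # the "diffusers" root logger
logger.info("current verbosity: %d", logging.get_verbosity())
logger.warning_advice("silenced when DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set")

logging.disable_progress_bar()        # logging.tqdm now yields a no-op EmptyTqdm
for _ in logging.tqdm(range(3)):
    pass
logging.enable_progress_bar()
```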
/src/diffusers/utils/model_card_template.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 |
7 |
8 | # {{ model_name | default("Diffusion Model") }}
9 |
10 | ## Model description
11 |
12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
13 | on the `{{ dataset_name }}` dataset.
14 |
15 | ## Intended uses & limitations
16 |
17 | #### How to use
18 |
19 | ```python
20 | # TODO: add an example code snippet for running this diffusion pipeline
21 | ```
22 |
23 | #### Limitations and bias
24 |
25 | [TODO: provide examples of latent issues and potential remediations]
26 |
27 | ## Training data
28 |
29 | [TODO: describe the data used to train the model]
30 |
31 | ### Training hyperparameters
32 |
33 | The following hyperparameters were used during training:
34 | - learning_rate: {{ learning_rate }}
35 | - train_batch_size: {{ train_batch_size }}
36 | - eval_batch_size: {{ eval_batch_size }}
37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }}
38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
39 | - lr_scheduler: {{ lr_scheduler }}
40 | - lr_warmup_steps: {{ lr_warmup_steps }}
41 | - ema_inv_gamma: {{ ema_inv_gamma }}
42 | - ema_power: {{ ema_power }}
43 | - ema_max_decay: {{ ema_max_decay }}
44 | - mixed_precision: {{ mixed_precision }}
45 |
46 | ### Training results
47 |
48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
49 |
50 |
51 |
--------------------------------------------------------------------------------
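The placeholders in this template use Jinja syntax (note the `default(...)` filter). Below is a hedged sketch of rendering it directly with `jinja2`; the exact mechanism diffusers uses to fill the card (presumably in `hub_utils.py`) is not shown in this dump, and every value passed here is made up.

```python
# Sketch only: render the model card template with jinja2 (assumed; all values are made up).
from jinja2 import Template

with open("src/diffusers/utils/model_card_template.md") as f:
    template = Template(f.read())

card = template.render(
    card_data="license: apache-2.0",
    model_name="my-diffusion-model",          # hypothetical
    dataset_name="my-dataset",                # hypothetical
    learning_rate=1e-4,
    train_batch_size=16,
    eval_batch_size=16,
    gradient_accumulation_steps=1,
    adam_beta1=0.95,
    adam_beta2=0.999,
    adam_weight_decay=1e-6,
    adam_epsilon=1e-8,
    lr_scheduler="cosine",
    lr_warmup_steps=500,
    ema_inv_gamma=1.0,
    ema_power=0.75,
    ema_max_decay=0.9999,
    mixed_precision="no",
    repo_name="user/my-diffusion-model",      # hypothetical
)
print(card.splitlines()[0])  # "---" of the YAML front matter
```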
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/diffusers_all/c8c0c0e846c8afc07602c44180278a2f7f15331d/tests/__init__.py
--------------------------------------------------------------------------------
/utils/check_config_docstrings.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import importlib
17 | import inspect
18 | import os
19 | import re
20 |
21 |
22 | # All paths are set with the intent you should run this script from the root of the repo with the command
23 | # python utils/check_config_docstrings.py
24 | PATH_TO_TRANSFORMERS = "src/transformers"
25 |
26 |
27 | # This is to make sure the transformers module imported is the one in the repo.
28 | spec = importlib.util.spec_from_file_location(
29 | "transformers",
30 | os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
31 | submodule_search_locations=[PATH_TO_TRANSFORMERS],
32 | )
33 | transformers = spec.loader.load_module()
34 |
35 | CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
36 |
37 | # Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
38 | # For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)`
39 | _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
40 |
41 |
42 | CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
43 | "CLIPConfigMixin",
44 | "DecisionTransformerConfigMixin",
45 | "EncoderDecoderConfigMixin",
46 | "RagConfigMixin",
47 | "SpeechEncoderDecoderConfigMixin",
48 | "VisionEncoderDecoderConfigMixin",
49 | "VisionTextDualEncoderConfigMixin",
50 | }
51 |
52 |
53 | def check_config_docstrings_have_checkpoints():
54 | configs_without_checkpoint = []
55 |
56 | for config_class in list(CONFIG_MAPPING.values()):
57 | checkpoint_found = False
58 |
59 | # source code of `config_class`
60 | config_source = inspect.getsource(config_class)
61 | checkpoints = _re_checkpoint.findall(config_source)
62 |
63 | for checkpoint in checkpoints:
64 | # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
65 | # For example, `('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')`
66 | ckpt_name, ckpt_link = checkpoint
67 |
68 | # verify the checkpoint name corresponds to the checkpoint link
69 | ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
70 | if ckpt_link == ckpt_link_from_name:
71 | checkpoint_found = True
72 | break
73 |
74 | name = config_class.__name__
75 | if not checkpoint_found and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
76 | configs_without_checkpoint.append(name)
77 |
78 | if len(configs_without_checkpoint) > 0:
79 | message = "\n".join(sorted(configs_without_checkpoint))
80 | raise ValueError(f"The following configurations don't contain any valid checkpoint:\n{message}")
81 |
82 |
83 | if __name__ == "__main__":
84 | check_config_docstrings_have_checkpoints()
85 |
--------------------------------------------------------------------------------
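As a quick illustration of what `_re_checkpoint` extracts, here is a toy run on a made-up docstring line (not part of the script):

```python
# Toy demonstration of the checkpoint-extraction regex used above.
import re

_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")

docstring_line = "defaults similar to the [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture."
print(_re_checkpoint.findall(docstring_line))
# [('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')]
```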
/utils/check_dummies.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import os
18 | import re
19 |
20 |
21 | # All paths are set with the intent you should run this script from the root of the repo with the command
22 | # python utils/check_dummies.py
23 | PATH_TO_DIFFUSERS = "src/diffusers"
24 |
25 | # Matches is_xxx_available()
26 | _re_backend = re.compile(r"is\_([a-z_]*)_available\(\)")
27 | # Matches from xxx import bla
28 | _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
29 |
30 |
31 | DUMMY_CONSTANT = """
32 | {0} = None
33 | """
34 |
35 | DUMMY_CLASS = """
36 | class {0}(metaclass=DummyObject):
37 | _backends = {1}
38 |
39 | def __init__(self, *args, **kwargs):
40 | requires_backends(self, {1})
41 | """
42 |
43 |
44 | DUMMY_FUNCTION = """
45 | def {0}(*args, **kwargs):
46 | requires_backends({0}, {1})
47 | """
48 |
49 |
50 | def find_backend(line):
51 | """Find one (or multiple) backend in a code line of the init."""
52 | backends = _re_backend.findall(line)
53 | if len(backends) == 0:
54 | return None
55 |
56 | return "_and_".join(backends)
57 |
58 |
59 | def read_init():
60 |     """Read the init and extract the objects that depend on optional backends."""
61 | with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
62 | lines = f.readlines()
63 |
64 | # Get to the point we do the actual imports for type checking
65 | line_index = 0
66 | backend_specific_objects = {}
67 | # Go through the end of the file
68 | while line_index < len(lines):
69 | # If the line is an if is_backend_available, we grab all objects associated.
70 | backend = find_backend(lines[line_index])
71 | if backend is not None:
72 | objects = []
73 | line_index += 1
74 | # Until we unindent, add backend objects to the list
75 | while not lines[line_index].startswith("else:"):
76 | line = lines[line_index]
77 | single_line_import_search = _re_single_line_import.search(line)
78 | if single_line_import_search is not None:
79 | objects.extend(single_line_import_search.groups()[0].split(", "))
80 | elif line.startswith(" " * 12):
81 | objects.append(line[12:-2])
82 | line_index += 1
83 |
84 | backend_specific_objects[backend] = objects
85 | else:
86 | line_index += 1
87 |
88 | return backend_specific_objects
89 |
90 |
91 | def create_dummy_object(name, backend_name):
92 | """Create the code for the dummy object corresponding to `name`."""
93 | if name.isupper():
94 | return DUMMY_CONSTANT.format(name)
95 | elif name.islower():
96 | return DUMMY_FUNCTION.format(name, backend_name)
97 | else:
98 | return DUMMY_CLASS.format(name, backend_name)
99 |
100 |
101 | def create_dummy_files():
102 | """Create the content of the dummy files."""
103 | backend_specific_objects = read_init()
104 |     # Build the content of one dummy file per backend (or backend combination).
105 | dummy_files = {}
106 |
107 | for backend, objects in backend_specific_objects.items():
108 | backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
109 | dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
110 | dummy_file += "# flake8: noqa\n"
111 | dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
112 | dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
113 | dummy_files[backend] = dummy_file
114 |
115 | return dummy_files
116 |
117 |
118 | def check_dummies(overwrite=False):
119 | """Check if the dummy files are up to date and maybe `overwrite` with the right content."""
120 | dummy_files = create_dummy_files()
121 |     # Special correspondence from backend name to the short name used in utils/dummy_xxx_objects.py filenames.
122 | short_names = {"torch": "pt"}
123 |
124 | # Locate actual dummy modules and read their content.
125 | path = os.path.join(PATH_TO_DIFFUSERS, "utils")
126 | dummy_file_paths = {
127 | backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
128 | for backend in dummy_files.keys()
129 | }
130 |
131 | actual_dummies = {}
132 | for backend, file_path in dummy_file_paths.items():
133 | if os.path.isfile(file_path):
134 | with open(file_path, "r", encoding="utf-8", newline="\n") as f:
135 | actual_dummies[backend] = f.read()
136 | else:
137 | actual_dummies[backend] = ""
138 |
139 | for backend in dummy_files.keys():
140 | if dummy_files[backend] != actual_dummies[backend]:
141 | if overwrite:
142 | print(
143 | f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
144 | "__init__ has new objects."
145 | )
146 | with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
147 | f.write(dummy_files[backend])
148 | else:
149 | raise ValueError(
150 | "The main __init__ has objects that are not present in "
151 | f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
152 | "to fix this."
153 | )
154 |
155 |
156 | if __name__ == "__main__":
157 | parser = argparse.ArgumentParser()
158 | parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
159 | args = parser.parse_args()
160 |
161 | check_dummies(args.fix_and_overwrite)
162 |
--------------------------------------------------------------------------------
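To see what the three `DUMMY_*` templates expand to, here is a small sketch run from the repository root (the constant and function names are hypothetical; `GradTTSPipeline` matches the dummy file shown earlier):

```python
# Illustrative only: what create_dummy_object() emits for each kind of name.
import sys

sys.path.append("utils")
from check_dummies import create_dummy_object  # noqa: E402

backend = '["transformers", "inflect", "unidecode"]'
print(create_dummy_object("SOME_CONSTANT", backend))    # hypothetical constant -> "SOME_CONSTANT = None"
print(create_dummy_object("some_function", backend))    # hypothetical function -> dummy def calling requires_backends
print(create_dummy_object("GradTTSPipeline", backend))  # class -> DummyObject-based dummy class
```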
/utils/check_table.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import collections
18 | import importlib.util
19 | import os
20 | import re
21 |
22 |
23 | # All paths are set with the intent you should run this script from the root of the repo with the command
24 | # python utils/check_table.py
25 | TRANSFORMERS_PATH = "src/diffusers"
26 | PATH_TO_DOCS = "docs/source/en"
27 | REPO_PATH = "."
28 |
29 |
30 | def _find_text_in_file(filename, start_prompt, end_prompt):
31 | """
32 | Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
33 | lines.
34 | """
35 | with open(filename, "r", encoding="utf-8", newline="\n") as f:
36 | lines = f.readlines()
37 | # Find the start prompt.
38 | start_index = 0
39 | while not lines[start_index].startswith(start_prompt):
40 | start_index += 1
41 | start_index += 1
42 |
43 | end_index = start_index
44 | while not lines[end_index].startswith(end_prompt):
45 | end_index += 1
46 | end_index -= 1
47 |
48 | while len(lines[start_index]) <= 1:
49 | start_index += 1
50 | while len(lines[end_index]) <= 1:
51 | end_index -= 1
52 | end_index += 1
53 | return "".join(lines[start_index:end_index]), start_index, end_index, lines
54 |
55 |
56 | # Add here suffixes that are used to identify models, separated by |
57 | ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
58 | # Regexes that match TF/Flax/PT model names.
59 | _re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
60 | _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
61 | # Will match any TF or Flax model too, so it needs to be in an else branch after the two previous regexes.
62 | _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
63 |
64 |
65 | # This is to make sure the diffusers module imported is the one in the repo.
66 | spec = importlib.util.spec_from_file_location(
67 | "diffusers",
68 | os.path.join(TRANSFORMERS_PATH, "__init__.py"),
69 | submodule_search_locations=[TRANSFORMERS_PATH],
70 | )
71 | diffusers_module = spec.loader.load_module()
72 |
73 |
74 | # Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
75 | def camel_case_split(identifier):
76 | "Split a camelcased `identifier` into words."
77 | matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
78 | return [m.group(0) for m in matches]
79 |
80 |
81 | def _center_text(text, width):
82 | text_length = 2 if text == "✅" or text == "❌" else len(text)
83 | left_indent = (width - text_length) // 2
84 | right_indent = width - text_length - left_indent
85 | return " " * left_indent + text + " " * right_indent
86 |
87 |
88 | def get_model_table_from_auto_modules():
89 | """Generates an up-to-date model table from the content of the auto modules."""
90 | # Dictionary model names to config.
91 | config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
92 | model_name_to_config = {
93 | name: config_maping_names[code]
94 | for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
95 | if code in config_maping_names
96 | }
97 | model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
98 |
99 | # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
100 | slow_tokenizers = collections.defaultdict(bool)
101 | fast_tokenizers = collections.defaultdict(bool)
102 | pt_models = collections.defaultdict(bool)
103 | tf_models = collections.defaultdict(bool)
104 | flax_models = collections.defaultdict(bool)
105 |
106 |     # Let's look through all diffusers objects (once).
107 | for attr_name in dir(diffusers_module):
108 | lookup_dict = None
109 | if attr_name.endswith("Tokenizer"):
110 | lookup_dict = slow_tokenizers
111 | attr_name = attr_name[:-9]
112 | elif attr_name.endswith("TokenizerFast"):
113 | lookup_dict = fast_tokenizers
114 | attr_name = attr_name[:-13]
115 | elif _re_tf_models.match(attr_name) is not None:
116 | lookup_dict = tf_models
117 | attr_name = _re_tf_models.match(attr_name).groups()[0]
118 | elif _re_flax_models.match(attr_name) is not None:
119 | lookup_dict = flax_models
120 | attr_name = _re_flax_models.match(attr_name).groups()[0]
121 | elif _re_pt_models.match(attr_name) is not None:
122 | lookup_dict = pt_models
123 | attr_name = _re_pt_models.match(attr_name).groups()[0]
124 |
125 | if lookup_dict is not None:
126 | while len(attr_name) > 0:
127 | if attr_name in model_name_to_prefix.values():
128 | lookup_dict[attr_name] = True
129 | break
130 | # Try again after removing the last word in the name
131 | attr_name = "".join(camel_case_split(attr_name)[:-1])
132 |
133 | # Let's build that table!
134 | model_names = list(model_name_to_config.keys())
135 | model_names.sort(key=str.lower)
136 | columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
137 | # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
138 | widths = [len(c) + 2 for c in columns]
139 | widths[0] = max([len(name) for name in model_names]) + 2
140 |
141 | # Build the table per se
142 | table = "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
143 |     # Use ":-----:" format to center-align table cell text
144 | table += "|" + "|".join([":" + "-" * (w - 2) + ":" for w in widths]) + "|\n"
145 |
146 | check = {True: "✅", False: "❌"}
147 | for name in model_names:
148 | prefix = model_name_to_prefix[name]
149 | line = [
150 | name,
151 | check[slow_tokenizers[prefix]],
152 | check[fast_tokenizers[prefix]],
153 | check[pt_models[prefix]],
154 | check[tf_models[prefix]],
155 | check[flax_models[prefix]],
156 | ]
157 | table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
158 | return table
159 |
160 |
161 | def check_model_table(overwrite=False):
162 |     """Check the model table in `index.mdx` is consistent with the state of the lib and maybe `overwrite`."""
163 | current_table, start_index, end_index, lines = _find_text_in_file(
164 | filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
165 | start_prompt="",
167 | )
168 | new_table = get_model_table_from_auto_modules()
169 |
170 | if current_table != new_table:
171 | if overwrite:
172 | with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
173 | f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
174 | else:
175 | raise ValueError(
176 | "The model table in the `index.mdx` has not been updated. Run `make fix-copies` to fix this."
177 | )
178 |
179 |
180 | if __name__ == "__main__":
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
183 | args = parser.parse_args()
184 |
185 | check_model_table(args.fix_and_overwrite)
186 |
--------------------------------------------------------------------------------
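A toy example of how `camel_case_split` breaks class names into words before the lookup loop strips trailing words (the helper is copied here so the snippet runs standalone):

```python
# Standalone copy of camel_case_split from utils/check_table.py, for illustration.
import re


def camel_case_split(identifier):
    matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
    return [m.group(0) for m in matches]


print(camel_case_split("LatentDiffusionPipeline"))
# ['Latent', 'Diffusion', 'Pipeline'] -- the lookup loop then retries with the last word removed.
```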
/utils/check_tf_ops.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import json
18 | import os
19 |
20 | from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
21 |
22 |
23 | # All paths are set with the intent you should run this script from the root of the repo with the command
24 | # python utils/check_copies.py
25 | REPO_PATH = "."
26 |
27 | # Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
28 | INTERNAL_OPS = [
29 | "Assert",
30 | "AssignVariableOp",
31 | "EmptyTensorList",
32 | "MergeV2Checkpoints",
33 | "ReadVariableOp",
34 | "ResourceGather",
35 | "RestoreV2",
36 | "SaveV2",
37 | "ShardedFilename",
38 | "StatefulPartitionedCall",
39 | "StaticRegexFullMatch",
40 | "VarHandleOp",
41 | ]
42 |
43 |
44 | def onnx_compliancy(saved_model_path, strict, opset):
45 | saved_model = SavedModel()
46 | onnx_ops = []
47 |
48 | with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
49 | onnx_opsets = json.load(f)["opsets"]
50 |
51 | for i in range(1, opset + 1):
52 | onnx_ops.extend(onnx_opsets[str(i)])
53 |
54 | with open(saved_model_path, "rb") as f:
55 | saved_model.ParseFromString(f.read())
56 |
57 | model_op_names = set()
58 |
59 | # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
60 | for meta_graph in saved_model.meta_graphs:
61 | # Add operations in the graph definition
62 | model_op_names.update(node.op for node in meta_graph.graph_def.node)
63 |
64 | # Go through the functions in the graph definition
65 | for func in meta_graph.graph_def.library.function:
66 | # Add operations in each function
67 | model_op_names.update(node.op for node in func.node_def)
68 |
69 |     # Convert to a sorted list for deterministic output
70 | model_op_names = sorted(model_op_names)
71 | incompatible_ops = []
72 |
73 | for op in model_op_names:
74 | if op not in onnx_ops and op not in INTERNAL_OPS:
75 | incompatible_ops.append(op)
76 |
77 | if strict and len(incompatible_ops) > 0:
78 |         raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + "\n".join(incompatible_ops))
79 | elif len(incompatible_ops) > 0:
80 | print(f"Found the following incompatible ops for the opset {opset}:")
81 | print(*incompatible_ops, sep="\n")
82 | else:
83 | print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
84 |
85 |
86 | if __name__ == "__main__":
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
89 | parser.add_argument(
90 | "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
91 | )
92 | parser.add_argument(
93 | "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
94 | )
95 | parser.add_argument(
96 | "--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
97 | )
98 | args = parser.parse_args()
99 |
100 | if args.framework == "onnx":
101 | onnx_compliancy(args.saved_model_path, args.strict, args.opset)
102 |
--------------------------------------------------------------------------------
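The script reads `utils/tf_ops/onnx.json`, which is not included in this dump. A hedged sketch of the shape `onnx_compliancy` expects (the op names are examples only), plus the intended invocation from the repo root:

```python
# Hypothetical shape of utils/tf_ops/onnx.json as consumed by onnx_compliancy():
# json.load(f)["opsets"][str(i)] must be a list of TF op names for every opset 1..N.
example_onnx_json = {
    "opsets": {
        "1": ["Add", "MatMul", "Relu"],   # example op names only
        "2": ["Softmax"],
        # ... one key per opset, up to the highest opset you test against
    }
}

# Intended usage (from the repo root):
#   python utils/check_tf_ops.py --saved_model_path path/to/saved_model.pb --opset 12 --strict
```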
/utils/custom_init_isort.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import os
18 | import re
19 |
20 |
21 | PATH_TO_TRANSFORMERS = "src/diffusers"
22 |
23 | # Pattern that looks at the indentation in a line.
24 | _re_indent = re.compile(r"^(\s*)\S")
25 | # Pattern that matches `"key":` and puts `key` in group 0.
26 | _re_direct_key = re.compile(r'^\s*"([^"]+)":')
27 | # Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
28 | _re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
29 | # Pattern that matches `"key",` and puts `key` in group 0.
30 | _re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
31 | # Pattern that matches any `[stuff]` and puts `stuff` in group 0.
32 | _re_bracket_content = re.compile(r"\[([^\]]+)\]")
33 |
34 |
35 | def get_indent(line):
36 | """Returns the indent in `line`."""
37 | search = _re_indent.search(line)
38 | return "" if search is None else search.groups()[0]
39 |
40 |
41 | def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
42 | """
43 | Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
44 | `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
45 | after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
46 | """
47 | # Let's split the code into lines and move to start_index.
48 | index = 0
49 | lines = code.split("\n")
50 | if start_prompt is not None:
51 | while not lines[index].startswith(start_prompt):
52 | index += 1
53 | blocks = ["\n".join(lines[:index])]
54 | else:
55 | blocks = []
56 |
57 | # We split into blocks until we get to the `end_prompt` (or the end of the block).
58 | current_block = [lines[index]]
59 | index += 1
60 | while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
61 | if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
62 | if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
63 | current_block.append(lines[index])
64 | blocks.append("\n".join(current_block))
65 | if index < len(lines) - 1:
66 | current_block = [lines[index + 1]]
67 | index += 1
68 | else:
69 | current_block = []
70 | else:
71 | blocks.append("\n".join(current_block))
72 | current_block = [lines[index]]
73 | else:
74 | current_block.append(lines[index])
75 | index += 1
76 |
77 | # Adds current block if it's nonempty.
78 | if len(current_block) > 0:
79 | blocks.append("\n".join(current_block))
80 |
81 | # Add final block after end_prompt if provided.
82 | if end_prompt is not None and index < len(lines):
83 | blocks.append("\n".join(lines[index:]))
84 |
85 | return blocks
86 |
87 |
88 | def ignore_underscore(key):
89 |     "Wraps a `key` function (that maps an object to a string) so comparisons lowercase it and ignore underscores."
90 |
91 | def _inner(x):
92 | return key(x).lower().replace("_", "")
93 |
94 | return _inner
95 |
96 |
97 | def sort_objects(objects, key=None):
98 | "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
99 | # If no key is provided, we use a noop.
100 | def noop(x):
101 | return x
102 |
103 | if key is None:
104 | key = noop
105 | # Constants are all uppercase, they go first.
106 | constants = [obj for obj in objects if key(obj).isupper()]
107 | # Classes are not all uppercase but start with a capital, they go second.
108 | classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
109 | # Functions begin with a lowercase, they go last.
110 | functions = [obj for obj in objects if not key(obj)[0].isupper()]
111 |
112 | key1 = ignore_underscore(key)
113 | return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
114 |
115 |
116 | def sort_objects_in_import(import_statement):
117 | """
118 | Return the same `import_statement` but with objects properly sorted.
119 | """
120 |     # This inner function sorts imports between [ ].
121 | def _replace(match):
122 | imports = match.groups()[0]
123 | if "," not in imports:
124 | return f"[{imports}]"
125 | keys = [part.strip().replace('"', "") for part in imports.split(",")]
126 | # We will have a final empty element if the line finished with a comma.
127 | if len(keys[-1]) == 0:
128 | keys = keys[:-1]
129 | return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"
130 |
131 | lines = import_statement.split("\n")
132 | if len(lines) > 3:
133 | # Here we have to sort internal imports that are on several lines (one per name):
134 | # key: [
135 | # "object1",
136 | # "object2",
137 | # ...
138 | # ]
139 |
140 | # We may have to ignore one or two lines on each side.
141 | idx = 2 if lines[1].strip() == "[" else 1
142 | keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
143 | sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
144 | sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
145 | return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
146 | elif len(lines) == 3:
147 | # Here we have to sort internal imports that are on one separate line:
148 | # key: [
149 | # "object1", "object2", ...
150 | # ]
151 | if _re_bracket_content.search(lines[1]) is not None:
152 | lines[1] = _re_bracket_content.sub(_replace, lines[1])
153 | else:
154 | keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
155 | # We will have a final empty element if the line finished with a comma.
156 | if len(keys[-1]) == 0:
157 | keys = keys[:-1]
158 | lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
159 | return "\n".join(lines)
160 | else:
161 | # Finally we have to deal with imports fitting on one line
162 | import_statement = _re_bracket_content.sub(_replace, import_statement)
163 | return import_statement
164 |
165 |
166 | def sort_imports(file, check_only=True):
167 | """
168 | Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.
169 | """
170 | with open(file, "r") as f:
171 | code = f.read()
172 |
173 | if "_import_structure" not in code:
174 | return
175 |
176 | # Blocks of indent level 0
177 | main_blocks = split_code_in_indented_blocks(
178 | code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
179 | )
180 |
181 |     # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
182 | for block_idx in range(1, len(main_blocks) - 1):
183 |         # Check if the block contains some `_import_structure` entries to sort.
184 | block = main_blocks[block_idx]
185 | block_lines = block.split("\n")
186 |
187 | # Get to the start of the imports.
188 | line_idx = 0
189 | while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
190 | # Skip dummy import blocks
191 | if "import dummy" in block_lines[line_idx]:
192 | line_idx = len(block_lines)
193 | else:
194 | line_idx += 1
195 | if line_idx >= len(block_lines):
196 | continue
197 |
198 | # Ignore beginning and last line: they don't contain anything.
199 | internal_block_code = "\n".join(block_lines[line_idx:-1])
200 | indent = get_indent(block_lines[1])
201 |         # Split the internal block into blocks of indent level 1.
202 | internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
203 |         # We have two categories of import key: list or _import_structure[key].append/extend
204 | pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
205 |         # Grab the keys, but there is a trap: some lines are empty or just comments.
206 | keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
207 | # We only sort the lines with a key.
208 | keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
209 | sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]
210 |
211 |         # We reorder the blocks, leaving empty lines/comments where they were and sorting the rest.
212 |         count = 0
213 |         reordered_blocks = []
214 |         for i in range(len(internal_blocks)):
215 |             if keys[i] is None:
216 |                 reordered_blocks.append(internal_blocks[i])
217 |             else:
218 |                 block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
219 |                 reordered_blocks.append(block)
220 |                 count += 1
221 |
222 |         # And we put our main block back together with its first and last line.
223 |         main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]])
224 |
225 | if code != "\n".join(main_blocks):
226 | if check_only:
227 | return True
228 | else:
229 | print(f"Overwriting {file}.")
230 | with open(file, "w") as f:
231 | f.write("\n".join(main_blocks))
232 |
233 |
234 | def sort_imports_in_all_inits(check_only=True):
235 | failures = []
236 | for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
237 | if "__init__.py" in files:
238 | result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
239 | if result:
240 |                 failures.append(os.path.join(root, "__init__.py"))
241 | if len(failures) > 0:
242 | raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
243 |
244 |
245 | if __name__ == "__main__":
246 | parser = argparse.ArgumentParser()
247 | parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
248 | args = parser.parse_args()
249 |
250 | sort_imports_in_all_inits(check_only=args.check_only)
251 |
--------------------------------------------------------------------------------
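To make the sorting rules concrete, here is a toy run of `sort_objects` and `sort_objects_in_import` from the repo root (the object names are examples; constants sort first, then classes, then functions, case-insensitively and ignoring underscores):

```python
# Illustrative only: the isort-like ordering enforced by the helpers above.
import sys

sys.path.append("utils")
from custom_init_isort import sort_objects, sort_objects_in_import  # noqa: E402

print(sort_objects(["get_scheduler", "DDPMScheduler", "SCHEDULER_CONFIG_NAME", "DDIMScheduler"]))
# ['SCHEDULER_CONFIG_NAME', 'DDIMScheduler', 'DDPMScheduler', 'get_scheduler']

print(sort_objects_in_import('    "schedulers": ["DDPMScheduler", "DDIMScheduler", "get_scheduler"],'))
#     "schedulers": ["DDIMScheduler", "DDPMScheduler", "get_scheduler"],
```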